feat: squash merge all changes from develop-0.1.75

Squash of 124 commits from the legacy develop branch (develop-0.1.75)
onto a clean v0.1.75 upstream base, to simplify future upstream merges.

Key changes included:
- Refactor scope-level rate limiting to model-level rate limiting
- Antigravity gateway service improvements (smart retry, error policy)
- Digest session store (flat cache replacing Trie-based store)
- Client disconnect detection during streaming
- Gemini messages compatibility service enhancements
- Scheduler shuffle for thundering herd prevention
- Session hash generation improvements
- Frontend customizations (WeChat service, HomeView, etc.)
- Ops monitoring scope cleanup
This commit is contained in:
erio
2026-02-09 12:32:35 +08:00
parent e4d74ae11d
commit 084e0adb34
63 changed files with 6396 additions and 2133 deletions

View File

@@ -17,6 +17,7 @@ jobs:
go-version-file: backend/go.mod
check-latest: false
cache: true
cache-dependency-path: backend/go.sum
- name: Verify Go version
run: |
go version | grep -q 'go1.25.7'
@@ -36,6 +37,7 @@ jobs:
go-version-file: backend/go.mod
check-latest: false
cache: true
cache-dependency-path: backend/go.sum
- name: Verify Go version
run: |
go version | grep -q 'go1.25.7'

1
.gitignore vendored
View File

@@ -78,6 +78,7 @@ Desktop.ini
# ===================
tmp/
temp/
logs/
*.tmp
*.temp
*.log

723
AGENTS.md Normal file
View File

@@ -0,0 +1,723 @@
# Sub2API 开发说明
## 版本管理策略
### 版本号规则
我们在官方版本号后面添加自己的小版本号:
- 官方版本:`v0.1.68`
- 我们的版本:`v0.1.68.1``v0.1.68.2`(递增)
### 分支策略
| 分支 | 说明 |
|------|------|
| `main` | 我们的主分支,包含所有定制功能 |
| `release/custom-X.Y.Z` | 基于官方 `vX.Y.Z` 的发布分支 |
| `upstream/main` | 上游官方仓库 |
---
## 发布流程(基于新官方版本)
当官方发布新版本(如 `v0.1.69`)时:
### 1. 同步上游并创建发布分支
```bash
# 获取上游最新代码
git fetch upstream --tags
# 基于官方标签创建新的发布分支
git checkout v0.1.69 -b release/custom-0.1.69
# 合并我们的 main 分支(包含所有定制功能)
git merge main --no-edit
# 解决可能的冲突后继续
```
### 2. 更新版本号并打标签
```bash
# 更新版本号文件
echo "0.1.69.1" > backend/cmd/server/VERSION
git add backend/cmd/server/VERSION
git commit -m "chore: bump version to 0.1.69.1"
# 打上我们自己的标签
git tag v0.1.69.1
# 推送分支和标签
git push origin release/custom-0.1.69
git push origin v0.1.69.1
```
### 3. 更新 main 分支
```bash
# 将发布分支合并回 main保持 main 包含最新定制功能
git checkout main
git merge release/custom-0.1.69
git push origin main
```
---
## 热修复发布(在现有版本上修复)
当需要在当前版本上发布修复时:
```bash
# 在当前发布分支上修复
git checkout release/custom-0.1.68
# ... 进行修复 ...
git commit -m "fix: 修复描述"
# 递增小版本号
echo "0.1.68.2" > backend/cmd/server/VERSION
git add backend/cmd/server/VERSION
git commit -m "chore: bump version to 0.1.68.2"
# 打标签并推送
git tag v0.1.68.2
git push origin release/custom-0.1.68
git push origin v0.1.68.2
# 同步修复到 main
git checkout main
git cherry-pick <fix-commit-hash>
git push origin main
```
---
## 服务器部署流程
### 前置条件
- 本地已配置 SSH 别名 `clicodeplus` 连接到服务器
- 服务器部署目录:`/root/sub2api`(正式)、`/root/sub2api-beta`(测试)
- 服务器使用 Docker Compose 部署
### 部署环境说明
| 环境 | 目录 | 端口 | 数据库 | 容器名 |
|------|------|------|--------|--------|
| 正式 | `/root/sub2api` | 8080 | `sub2api` | `sub2api` |
| Beta | `/root/sub2api-beta` | 8084 | `beta` | `sub2api-beta` |
### 外部数据库
正式和 Beta 环境**共用外部 PostgreSQL 数据库**(非容器内数据库),配置在 `.env` 文件中:
- `DATABASE_HOST`:外部数据库地址
- `DATABASE_SSLMODE`SSL 模式(通常为 `require`
- `POSTGRES_USER` / `POSTGRES_DB`:用户名和数据库名
#### 数据库操作命令
通过 SSH 在服务器上执行数据库操作:
```bash
# 正式环境 - 查询迁移记录
ssh clicodeplus "source /root/sub2api/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c 'SELECT * FROM schema_migrations ORDER BY applied_at DESC LIMIT 5;'"
# Beta 环境 - 查询迁移记录
ssh clicodeplus "source /root/sub2api-beta/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c 'SELECT * FROM schema_migrations ORDER BY applied_at DESC LIMIT 5;'"
# Beta 环境 - 清除指定迁移记录(重新执行迁移)
ssh clicodeplus "source /root/sub2api-beta/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c \"DELETE FROM schema_migrations WHERE filename LIKE '%049%';\""
# Beta 环境 - 更新账号数据
ssh clicodeplus "source /root/sub2api-beta/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c \"UPDATE accounts SET credentials = credentials - 'model_mapping' WHERE platform = 'antigravity';\""
```
> **注意**:使用 `source .env` 加载环境变量,避免在命令行中暴露密码。
### 部署步骤
**重要:每次部署都必须递增版本号!**
#### 0. 递增版本号(本地操作)
每次部署前,先在本地递增小版本号:
```bash
# 查看当前版本号
cat backend/cmd/server/VERSION
# 假设当前是 0.1.69.1
# 递增版本号
echo "0.1.69.2" > backend/cmd/server/VERSION
git add backend/cmd/server/VERSION
git commit -m "chore: bump version to 0.1.69.2"
git push origin release/custom-0.1.69
```
#### 1. 服务器拉取代码
```bash
ssh clicodeplus "cd /root/sub2api && git fetch fork && git checkout -B release/custom-0.1.69 fork/release/custom-0.1.69"
```
#### 2. 服务器构建镜像
```bash
ssh clicodeplus "cd /root/sub2api && docker build --no-cache -t sub2api:latest -f Dockerfile ."
```
#### 3. 更新镜像标签并重启服务
```bash
ssh clicodeplus "docker tag sub2api:latest weishaw/sub2api:latest"
ssh clicodeplus "cd /root/sub2api/deploy && docker compose up -d --force-recreate sub2api"
```
#### 4. 验证部署
```bash
# 查看启动日志
ssh clicodeplus "docker logs sub2api --tail 20"
# 确认版本号(必须与步骤 0 中设置的版本号一致)
ssh clicodeplus "cat /root/sub2api/backend/cmd/server/VERSION"
# 检查容器状态
ssh clicodeplus "docker ps | grep sub2api"
```
---
## Beta 并行部署(不影响现网)
目标:在同一台服务器上并行启动一个 beta 实例(例如端口 `8084`**严禁改动/重启**现网实例(默认目录 `/root/sub2api`)。
### 设计原则
- **新目录**beta 使用独立目录,例如 `/root/sub2api-beta`
- **敏感信息只放 `.env`**beta 的数据库密码、JWT_SECRET 等只写入 `/root/sub2api-beta/deploy/.env`,不要提交到 git。
- **独立 Compose Project**:通过 `docker compose -p sub2api-beta ...` 启动,确保 network/volume 隔离。
- **独立端口**:通过 `.env``SERVER_PORT` 映射宿主机端口(例如 `8084:8080`)。
### 前置检查
```bash
# 1) 确保 8084 未被占用
ssh clicodeplus "ss -ltnp | grep :8084 || echo '8084 is free'"
# 2) 确认现网容器还在(只读检查)
ssh clicodeplus "docker ps --format 'table {{.Names}}\t{{.Image}}\t{{.Ports}}' | sed -n '1,200p'"
```
### 首次部署步骤
```bash
# 0) 进入服务器
ssh clicodeplus
# 1) 克隆代码到新目录(示例使用你的 fork
cd /root
git clone https://github.com/touwaeriol/sub2api.git sub2api-beta
cd /root/sub2api-beta
git checkout release/custom-0.1.71
# 2) 准备 beta 的 .env敏感信息只写这里
cd /root/sub2api-beta/deploy
# 推荐:从现网 .env 复制,保证除 DB 名/用户/端口外完全一致
cp -f /root/sub2api/deploy/.env ./.env
# 仅修改以下三项(其他保持不变)
perl -pi -e 's/^SERVER_PORT=.*/SERVER_PORT=8084/' ./.env
perl -pi -e 's/^POSTGRES_USER=.*/POSTGRES_USER=beta/' ./.env
perl -pi -e 's/^POSTGRES_DB=.*/POSTGRES_DB=beta/' ./.env
# 3) 写 compose override避免与现网容器名冲突镜像使用本地构建的 sub2api:beta
cat > docker-compose.override.yml <<'YAML'
services:
sub2api:
image: sub2api:beta
container_name: sub2api-beta
redis:
container_name: sub2api-beta-redis
YAML
# 4) 构建 beta 镜像(基于当前代码)
cd /root/sub2api-beta
docker build -t sub2api:beta -f Dockerfile .
# 5) 启动 beta独立 project确保不影响现网
cd /root/sub2api-beta/deploy
docker compose -p sub2api-beta --env-file .env -f docker-compose.yml -f docker-compose.override.yml up -d
# 6) 验证 beta
curl -fsS http://127.0.0.1:8084/health
docker logs sub2api-beta --tail 50
```
### 数据库配置约定beta
- 数据库地址/SSL/密码:与现网一致(从现网 `.env` 复制即可)。
- 仅修改:
- `POSTGRES_USER=beta`
- `POSTGRES_DB=beta`
注意:需要数据库侧已存在 `beta` 用户与 `beta` 数据库,并授予权限;否则容器会启动失败并不断重启。
### 更新 beta拉代码 + 仅重建 beta 容器)
```bash
ssh clicodeplus "set -e; cd /root/sub2api-beta && git fetch --all --tags && git checkout -f release/custom-0.1.71 && git reset --hard origin/release/custom-0.1.71"
ssh clicodeplus "cd /root/sub2api-beta && docker build -t sub2api:beta -f Dockerfile ."
ssh clicodeplus "cd /root/sub2api-beta/deploy && docker compose -p sub2api-beta --env-file .env -f docker-compose.yml -f docker-compose.override.yml up -d --no-deps --force-recreate sub2api"
ssh clicodeplus "curl -fsS http://127.0.0.1:8084/health"
```
### 停止/回滚 beta只影响 beta
```bash
ssh clicodeplus "cd /root/sub2api-beta/deploy && docker compose -p sub2api-beta -f docker-compose.yml -f docker-compose.override.yml down"
```
---
## 服务器首次部署
### 1. 克隆代码并配置远程仓库
```bash
ssh clicodeplus
cd /root
git clone https://github.com/Wei-Shaw/sub2api.git
cd sub2api
# 添加 fork 仓库
git remote add fork https://github.com/touwaeriol/sub2api.git
```
### 2. 切换到定制分支并配置环境
```bash
git fetch fork
git checkout -B release/custom-0.1.69 fork/release/custom-0.1.69
cd deploy
cp .env.example .env
vim .env # 配置 DATABASE_URL, REDIS_URL, JWT_SECRET 等
```
### 3. 构建并启动
```bash
cd /root/sub2api
docker build -t sub2api:latest -f Dockerfile .
docker tag sub2api:latest weishaw/sub2api:latest
cd deploy && docker compose up -d
```
### 6. 启动服务
```bash
# 进入 deploy 目录
cd deploy
# 启动所有服务PostgreSQL、Redis、sub2api
docker compose up -d
# 查看服务状态
docker compose ps
```
### 7. 验证部署
```bash
# 查看应用日志
docker logs sub2api --tail 50
# 检查健康状态
curl http://localhost:8080/health
# 确认版本号
cat /root/sub2api/backend/cmd/server/VERSION
```
### 8. 常用运维命令
```bash
# 查看实时日志
docker logs -f sub2api
# 重启服务
docker compose restart sub2api
# 停止所有服务
docker compose down
# 停止并删除数据卷(慎用!会删除数据库数据)
docker compose down -v
# 查看资源使用情况
docker stats sub2api
```
---
## 定制功能说明
当前定制分支包含以下功能(相对于官方版本):
### UI/UX 定制
| 功能 | 说明 |
|------|------|
| 首页优化 | 面向用户的价值主张设计 |
| 移除 GitHub 链接 | 用户菜单中不显示 GitHub 导航 |
| 微信客服按钮 | 首页悬浮微信客服入口 |
| 限流时间精确显示 | 账号限流时间显示精确到秒 |
### Antigravity 平台增强
| 功能 | 说明 |
|------|------|
| Scope 级别限流 | 按配额域claude/gemini_text/gemini_image独立限流避免整个账号被锁定 |
| 模型级别限流 | 按具体模型(如 claude-opus-4-5独立限流更精细的限流控制 |
| 限流预检查 | 调度时预检查账号/模型限流状态,避免选中已限流账号 |
| 秒级冷却时间 | 支持 429 响应的秒级精确冷却时间 |
| 身份注入优化 | 模型身份信息注入 + 静默边界防止身份泄露 |
| thoughtSignature 修复 | Gemini 3 函数调用 400 错误修复 |
| max_tokens 自动修正 | 自动修正 max_tokens <= budget_tokens 导致的 400 错误 |
### 调度算法优化
| 功能 | 说明 |
|------|------|
| 分层过滤选择 | 调度算法从全排序改为分层过滤,提升性能 |
| LRU 随机选择 | 相同 LRU 时间时随机选择,避免账号集中 |
| 限流等待阈值配置化 | 可配置的限流等待阈值 |
### 运维增强
| 功能 | 说明 |
|------|------|
| Scope 限流统计 | 运维界面展示 Antigravity 账号 scope 级别限流统计 |
| 账号限流状态显示 | 账号列表显示 scope 和模型级别限流状态 |
| 清除限流按钮增强 | 有 scope/模型限流时也显示清除限流按钮 |
### 其他修复
| 功能 | 说明 |
|------|------|
| .gitattributes | 确保迁移文件使用 LF 换行符(解决 Windows 下 SQL 摘要不一致) |
| 部署配置优化 | DATABASE_HOST 和 DATABASE_SSLMODE 可通过 .env 配置 |
---
## 注意事项
1. **前端必须打包进镜像**:使用 `docker build` 在服务器上构建Dockerfile 会自动编译前端并 embed 到后端二进制中
2. **镜像标签**docker-compose.yml 使用 `weishaw/sub2api:latest`,本地构建后需要 `docker tag` 覆盖
3. **Windows 换行符问题**:已通过 `.gitattributes` 解决,确保 `*.sql` 文件始终使用 LF
4. **版本号管理**:每次发布必须更新 `backend/cmd/server/VERSION` 并打标签
5. **合并冲突**:合并上游新版本时,重点关注以下文件可能的冲突:
- `backend/internal/service/antigravity_gateway_service.go`
- `backend/internal/service/gateway_service.go`
- `backend/internal/pkg/antigravity/request_transformer.go`
---
## Go 代码规范
### 1. 函数设计
#### 单一职责原则
- **函数行数**:单个函数常规不应超过 **30 行**,超过时应拆分为子函数。若某段逻辑确实不可拆分(如复杂的状态机、协议解析等),可以例外,但需添加注释说明原因
- **嵌套层级**:避免超过 3 层嵌套,使用 early return 减少嵌套
```go
// ❌ 不推荐:深层嵌套
func process(data []Item) {
for _, item := range data {
if item.Valid {
if item.Type == "A" {
if item.Status == "active" {
// 业务逻辑...
}
}
}
}
}
// ✅ 推荐early return
func process(data []Item) {
for _, item := range data {
if !item.Valid {
continue
}
if item.Type != "A" {
continue
}
if item.Status != "active" {
continue
}
// 业务逻辑...
}
}
```
#### 复杂逻辑提取
将复杂的条件判断或处理逻辑提取为独立函数:
```go
// ❌ 不推荐:内联复杂逻辑
if resp.StatusCode == 429 || resp.StatusCode == 503 {
// 80+ 行处理逻辑...
}
// ✅ 推荐:提取为独立函数
result := handleRateLimitResponse(resp, params)
switch result.action {
case actionRetry:
continue
case actionBreak:
return result.resp, nil
}
```
### 2. 重复代码消除
#### 配置获取模式
将重复的配置获取逻辑提取为方法:
```go
// ❌ 不推荐:重复代码
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
maxBytes := 2048
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
}
// ✅ 推荐:提取为方法
func (s *Service) getLogConfig() (logBody bool, maxBytes int) {
maxBytes = 2048
if s.settingService == nil || s.settingService.cfg == nil {
return false, maxBytes
}
cfg := s.settingService.cfg.Gateway
if cfg.LogUpstreamErrorBodyMaxBytes > 0 {
maxBytes = cfg.LogUpstreamErrorBodyMaxBytes
}
return cfg.LogUpstreamErrorBody, maxBytes
}
```
### 3. 常量管理
#### 避免魔法数字
所有硬编码的数值都应定义为常量:
```go
// ❌ 不推荐
if retryDelay >= 10*time.Second {
resetAt := time.Now().Add(30 * time.Second)
}
// ✅ 推荐
const (
rateLimitThreshold = 10 * time.Second
defaultRateLimitDuration = 30 * time.Second
)
if retryDelay >= rateLimitThreshold {
resetAt := time.Now().Add(defaultRateLimitDuration)
}
```
#### 注释引用常量名
在注释中引用常量名而非硬编码值:
```go
// ❌ 不推荐
// < 10s: 等待后重试
// ✅ 推荐
// < rateLimitThreshold: 等待后重试
```
### 4. 错误处理
#### 使用结构化日志
优先使用 `slog` 进行结构化日志记录:
```go
// ❌ 不推荐
log.Printf("%s status=%d model_rate_limit_failed model=%s error=%v", prefix, statusCode, modelName, err)
// ✅ 推荐
slog.Error("failed to set model rate limit",
"prefix", prefix,
"status_code", statusCode,
"model", modelName,
"error", err,
)
```
### 5. 测试规范
#### Mock 函数签名同步
修改函数签名时,必须同步更新所有测试中的 mock 函数:
```go
// 如果修改了 handleError 签名
handleError func(..., groupID int64, sessionHash string) *Result
// 必须同步更新测试中的 mock
handleError: func(..., groupID int64, sessionHash string) *Result {
return nil
},
```
#### 测试构建标签
统一使用测试构建标签:
```go
//go:build unit
package service
```
### 6. 时间格式解析
#### 使用标准库
优先使用 `time.ParseDuration`,支持所有 Go duration 格式:
```go
// ❌ 不推荐:手动限制格式
if !strings.HasSuffix(delay, "s") || strings.Contains(delay, "m") {
continue
}
// ✅ 推荐:使用标准库
dur, err := time.ParseDuration(delay) // 支持 "0.5s", "4m50s", "1h30m" 等
```
### 7. 接口设计
#### 接口隔离原则
定义最小化接口,只包含必需的方法:
```go
// ❌ 不推荐:使用过于宽泛的接口
type AccountRepository interface {
// 20+ 个方法...
}
// ✅ 推荐:定义最小化接口
type ModelRateLimiter interface {
SetModelRateLimit(ctx context.Context, id int64, modelKey string, resetAt time.Time) error
}
```
### 8. 并发安全
#### 共享数据保护
访问可能被并发修改的数据时,确保线程安全:
```go
// 如果 Account.Extra 可能被并发修改
// 需要使用互斥锁或原子操作保护读取
func (a *Account) GetRateLimitRemainingTime(model string) time.Duration {
a.mu.RLock()
defer a.mu.RUnlock()
// 读取 Extra 字段...
}
```
### 9. 命名规范
#### 一致的命名风格
- 常量使用 camelCase`rateLimitThreshold`
- 类型使用 PascalCase`AntigravityQuotaScope`
- 同一概念使用统一命名:`Threshold``Limit`,不要混用
```go
// ❌ 不推荐:命名不一致
antigravitySmartRetryMinWait // 使用 Min
antigravityRateLimitThreshold // 使用 Threshold
// ✅ 推荐:统一风格
antigravityMinRetryWait
antigravityRateLimitThreshold
```
### 10. 代码审查清单
在提交代码前,检查以下项目:
- [ ] 函数是否超过 30 行?(不可拆分的逻辑除外,需注释说明)
- [ ] 嵌套是否超过 3 层?
- [ ] 是否有重复代码可以提取?
- [ ] 是否使用了魔法数字?
- [ ] Mock 函数签名是否与实际函数一致?
- [ ] 测试是否覆盖了新增逻辑?
- [ ] 日志是否包含足够的上下文信息?
- [ ] 是否考虑了并发安全?
---
## CI 检查与发布门禁
### GitHub Actions 检查项
本项目有 4 个 CI 任务,**任何代码推送或发布前都必须全部通过**
| Workflow | Job | 说明 | 本地验证命令 |
|----------|-----|------|-------------|
| CI | `test` | 单元测试 + 集成测试 | `cd backend && make test-unit && make test-integration` |
| CI | `golangci-lint` | Go 代码静态检查golangci-lint v2.7 | `cd backend && golangci-lint run --timeout=5m` |
| Security Scan | `backend-security` | govulncheck + gosec 安全扫描 | `cd backend && govulncheck ./... && gosec -severity high -confidence high ./...` |
| Security Scan | `frontend-security` | pnpm audit 前端依赖安全检查 | `cd frontend && pnpm audit --prod --audit-level=high` |
### 向上游提交 PR
PR 目标是上游官方仓库,**只包含通用功能改动**bug fix、新功能、性能优化等
**以下文件禁止出现在 PR 中**(属于我们 fork 的定制化内容):
- `CLAUDE.md``AGENTS.md` — 我们的开发文档
- `backend/cmd/server/VERSION` — 我们的版本号文件
- UI 定制改动GitHub 链接移除、微信客服按钮、首页定制等)
- 部署配置(`deploy/` 目录下的定制修改)
**PR 流程**
1.`develop` 创建功能分支,只包含要提交给上游的改动
2. 推送分支后,**等待 4 个 CI job 全部通过**
3. 确认通过后再创建 PR
4. 使用 `gh run list --repo touwaeriol/sub2api --branch <branch>` 检查状态
### 自有分支推送develop / main
推送到我们自己的 `develop``main` 分支时,包含所有改动(定制化 + 通用功能)。
**推送流程**
1. 本地运行 `cd backend && make test-unit` 确保单元测试通过
2. 本地运行 `cd backend && gofmt -l ./...` 确保格式正确
3. 推送后确认 CI 和 Security Scan 两个 workflow 的 4 个 job 全部绿色 ✅
4. 任何 job 失败必须立即修复,**禁止在 CI 未通过的状态下继续后续操作**
### 发布版本
1. 确保 `main` 分支最新提交的 4 个 CI job 全部通过
2. 递增 `backend/cmd/server/VERSION`,提交并推送
3. 打 tag 推送后,确认 tag 触发的 3 个 workflowCI、Security Scan、Release全部通过
4. **Release workflow 失败时禁止部署** — 必须先修复问题,删除旧 tag重新打 tag
5. 使用 `gh run list --repo touwaeriol/sub2api --limit 10` 确认状态
### 常见 CI 失败原因及修复
- **gofmt**struct 字段对齐不一致 → 运行 `gofmt -w <file>` 修复
- **golangci-lint**:未使用的变量/导入 → 删除或使用 `_` 忽略
- **test 失败**mock 函数签名不一致 → 同步更新 mock
- **gosec**:安全漏洞 → 根据提示修复或添加例外

723
CLAUDE.md Normal file
View File

@@ -0,0 +1,723 @@
# Sub2API 开发说明
## 版本管理策略
### 版本号规则
我们在官方版本号后面添加自己的小版本号:
- 官方版本:`v0.1.68`
- 我们的版本:`v0.1.68.1``v0.1.68.2`(递增)
### 分支策略
| 分支 | 说明 |
|------|------|
| `main` | 我们的主分支,包含所有定制功能 |
| `release/custom-X.Y.Z` | 基于官方 `vX.Y.Z` 的发布分支 |
| `upstream/main` | 上游官方仓库 |
---
## 发布流程(基于新官方版本)
当官方发布新版本(如 `v0.1.69`)时:
### 1. 同步上游并创建发布分支
```bash
# 获取上游最新代码
git fetch upstream --tags
# 基于官方标签创建新的发布分支
git checkout v0.1.69 -b release/custom-0.1.69
# 合并我们的 main 分支(包含所有定制功能)
git merge main --no-edit
# 解决可能的冲突后继续
```
### 2. 更新版本号并打标签
```bash
# 更新版本号文件
echo "0.1.69.1" > backend/cmd/server/VERSION
git add backend/cmd/server/VERSION
git commit -m "chore: bump version to 0.1.69.1"
# 打上我们自己的标签
git tag v0.1.69.1
# 推送分支和标签
git push origin release/custom-0.1.69
git push origin v0.1.69.1
```
### 3. 更新 main 分支
```bash
# 将发布分支合并回 main保持 main 包含最新定制功能
git checkout main
git merge release/custom-0.1.69
git push origin main
```
---
## 热修复发布(在现有版本上修复)
当需要在当前版本上发布修复时:
```bash
# 在当前发布分支上修复
git checkout release/custom-0.1.68
# ... 进行修复 ...
git commit -m "fix: 修复描述"
# 递增小版本号
echo "0.1.68.2" > backend/cmd/server/VERSION
git add backend/cmd/server/VERSION
git commit -m "chore: bump version to 0.1.68.2"
# 打标签并推送
git tag v0.1.68.2
git push origin release/custom-0.1.68
git push origin v0.1.68.2
# 同步修复到 main
git checkout main
git cherry-pick <fix-commit-hash>
git push origin main
```
---
## 服务器部署流程
### 前置条件
- 本地已配置 SSH 别名 `clicodeplus` 连接到服务器
- 服务器部署目录:`/root/sub2api`(正式)、`/root/sub2api-beta`(测试)
- 服务器使用 Docker Compose 部署
### 部署环境说明
| 环境 | 目录 | 端口 | 数据库 | 容器名 |
|------|------|------|--------|--------|
| 正式 | `/root/sub2api` | 8080 | `sub2api` | `sub2api` |
| Beta | `/root/sub2api-beta` | 8084 | `beta` | `sub2api-beta` |
### 外部数据库
正式和 Beta 环境**共用外部 PostgreSQL 数据库**(非容器内数据库),配置在 `.env` 文件中:
- `DATABASE_HOST`:外部数据库地址
- `DATABASE_SSLMODE`SSL 模式(通常为 `require`
- `POSTGRES_USER` / `POSTGRES_DB`:用户名和数据库名
#### 数据库操作命令
通过 SSH 在服务器上执行数据库操作:
```bash
# 正式环境 - 查询迁移记录
ssh clicodeplus "source /root/sub2api/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c 'SELECT * FROM schema_migrations ORDER BY applied_at DESC LIMIT 5;'"
# Beta 环境 - 查询迁移记录
ssh clicodeplus "source /root/sub2api-beta/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c 'SELECT * FROM schema_migrations ORDER BY applied_at DESC LIMIT 5;'"
# Beta 环境 - 清除指定迁移记录(重新执行迁移)
ssh clicodeplus "source /root/sub2api-beta/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c \"DELETE FROM schema_migrations WHERE filename LIKE '%049%';\""
# Beta 环境 - 更新账号数据
ssh clicodeplus "source /root/sub2api-beta/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c \"UPDATE accounts SET credentials = credentials - 'model_mapping' WHERE platform = 'antigravity';\""
```
> **注意**:使用 `source .env` 加载环境变量,避免在命令行中暴露密码。
### 部署步骤
**重要:每次部署都必须递增版本号!**
#### 0. 递增版本号(本地操作)
每次部署前,先在本地递增小版本号:
```bash
# 查看当前版本号
cat backend/cmd/server/VERSION
# 假设当前是 0.1.69.1
# 递增版本号
echo "0.1.69.2" > backend/cmd/server/VERSION
git add backend/cmd/server/VERSION
git commit -m "chore: bump version to 0.1.69.2"
git push origin release/custom-0.1.69
```
#### 1. 服务器拉取代码
```bash
ssh clicodeplus "cd /root/sub2api && git fetch fork && git checkout -B release/custom-0.1.69 fork/release/custom-0.1.69"
```
#### 2. 服务器构建镜像
```bash
ssh clicodeplus "cd /root/sub2api && docker build --no-cache -t sub2api:latest -f Dockerfile ."
```
#### 3. 更新镜像标签并重启服务
```bash
ssh clicodeplus "docker tag sub2api:latest weishaw/sub2api:latest"
ssh clicodeplus "cd /root/sub2api/deploy && docker compose up -d --force-recreate sub2api"
```
#### 4. 验证部署
```bash
# 查看启动日志
ssh clicodeplus "docker logs sub2api --tail 20"
# 确认版本号(必须与步骤 0 中设置的版本号一致)
ssh clicodeplus "cat /root/sub2api/backend/cmd/server/VERSION"
# 检查容器状态
ssh clicodeplus "docker ps | grep sub2api"
```
---
## Beta 并行部署(不影响现网)
目标:在同一台服务器上并行启动一个 beta 实例(例如端口 `8084`**严禁改动/重启**现网实例(默认目录 `/root/sub2api`)。
### 设计原则
- **新目录**beta 使用独立目录,例如 `/root/sub2api-beta`
- **敏感信息只放 `.env`**beta 的数据库密码、JWT_SECRET 等只写入 `/root/sub2api-beta/deploy/.env`,不要提交到 git。
- **独立 Compose Project**:通过 `docker compose -p sub2api-beta ...` 启动,确保 network/volume 隔离。
- **独立端口**:通过 `.env``SERVER_PORT` 映射宿主机端口(例如 `8084:8080`)。
### 前置检查
```bash
# 1) 确保 8084 未被占用
ssh clicodeplus "ss -ltnp | grep :8084 || echo '8084 is free'"
# 2) 确认现网容器还在(只读检查)
ssh clicodeplus "docker ps --format 'table {{.Names}}\t{{.Image}}\t{{.Ports}}' | sed -n '1,200p'"
```
### 首次部署步骤
```bash
# 0) 进入服务器
ssh clicodeplus
# 1) 克隆代码到新目录(示例使用你的 fork
cd /root
git clone https://github.com/touwaeriol/sub2api.git sub2api-beta
cd /root/sub2api-beta
git checkout release/custom-0.1.71
# 2) 准备 beta 的 .env敏感信息只写这里
cd /root/sub2api-beta/deploy
# 推荐:从现网 .env 复制,保证除 DB 名/用户/端口外完全一致
cp -f /root/sub2api/deploy/.env ./.env
# 仅修改以下三项(其他保持不变)
perl -pi -e 's/^SERVER_PORT=.*/SERVER_PORT=8084/' ./.env
perl -pi -e 's/^POSTGRES_USER=.*/POSTGRES_USER=beta/' ./.env
perl -pi -e 's/^POSTGRES_DB=.*/POSTGRES_DB=beta/' ./.env
# 3) 写 compose override避免与现网容器名冲突镜像使用本地构建的 sub2api:beta
cat > docker-compose.override.yml <<'YAML'
services:
sub2api:
image: sub2api:beta
container_name: sub2api-beta
redis:
container_name: sub2api-beta-redis
YAML
# 4) 构建 beta 镜像(基于当前代码)
cd /root/sub2api-beta
docker build -t sub2api:beta -f Dockerfile .
# 5) 启动 beta独立 project确保不影响现网
cd /root/sub2api-beta/deploy
docker compose -p sub2api-beta --env-file .env -f docker-compose.yml -f docker-compose.override.yml up -d
# 6) 验证 beta
curl -fsS http://127.0.0.1:8084/health
docker logs sub2api-beta --tail 50
```
### 数据库配置约定beta
- 数据库地址/SSL/密码:与现网一致(从现网 `.env` 复制即可)。
- 仅修改:
- `POSTGRES_USER=beta`
- `POSTGRES_DB=beta`
注意:需要数据库侧已存在 `beta` 用户与 `beta` 数据库,并授予权限;否则容器会启动失败并不断重启。
### 更新 beta拉代码 + 仅重建 beta 容器)
```bash
ssh clicodeplus "set -e; cd /root/sub2api-beta && git fetch --all --tags && git checkout -f release/custom-0.1.71 && git reset --hard origin/release/custom-0.1.71"
ssh clicodeplus "cd /root/sub2api-beta && docker build -t sub2api:beta -f Dockerfile ."
ssh clicodeplus "cd /root/sub2api-beta/deploy && docker compose -p sub2api-beta --env-file .env -f docker-compose.yml -f docker-compose.override.yml up -d --no-deps --force-recreate sub2api"
ssh clicodeplus "curl -fsS http://127.0.0.1:8084/health"
```
### 停止/回滚 beta只影响 beta
```bash
ssh clicodeplus "cd /root/sub2api-beta/deploy && docker compose -p sub2api-beta -f docker-compose.yml -f docker-compose.override.yml down"
```
---
## 服务器首次部署
### 1. 克隆代码并配置远程仓库
```bash
ssh clicodeplus
cd /root
git clone https://github.com/Wei-Shaw/sub2api.git
cd sub2api
# 添加 fork 仓库
git remote add fork https://github.com/touwaeriol/sub2api.git
```
### 2. 切换到定制分支并配置环境
```bash
git fetch fork
git checkout -B release/custom-0.1.69 fork/release/custom-0.1.69
cd deploy
cp .env.example .env
vim .env # 配置 DATABASE_URL, REDIS_URL, JWT_SECRET 等
```
### 3. 构建并启动
```bash
cd /root/sub2api
docker build -t sub2api:latest -f Dockerfile .
docker tag sub2api:latest weishaw/sub2api:latest
cd deploy && docker compose up -d
```
### 6. 启动服务
```bash
# 进入 deploy 目录
cd deploy
# 启动所有服务PostgreSQL、Redis、sub2api
docker compose up -d
# 查看服务状态
docker compose ps
```
### 7. 验证部署
```bash
# 查看应用日志
docker logs sub2api --tail 50
# 检查健康状态
curl http://localhost:8080/health
# 确认版本号
cat /root/sub2api/backend/cmd/server/VERSION
```
### 8. 常用运维命令
```bash
# 查看实时日志
docker logs -f sub2api
# 重启服务
docker compose restart sub2api
# 停止所有服务
docker compose down
# 停止并删除数据卷(慎用!会删除数据库数据)
docker compose down -v
# 查看资源使用情况
docker stats sub2api
```
---
## 定制功能说明
当前定制分支包含以下功能(相对于官方版本):
### UI/UX 定制
| 功能 | 说明 |
|------|------|
| 首页优化 | 面向用户的价值主张设计 |
| 移除 GitHub 链接 | 用户菜单中不显示 GitHub 导航 |
| 微信客服按钮 | 首页悬浮微信客服入口 |
| 限流时间精确显示 | 账号限流时间显示精确到秒 |
### Antigravity 平台增强
| 功能 | 说明 |
|------|------|
| Scope 级别限流 | 按配额域claude/gemini_text/gemini_image独立限流避免整个账号被锁定 |
| 模型级别限流 | 按具体模型(如 claude-opus-4-5独立限流更精细的限流控制 |
| 限流预检查 | 调度时预检查账号/模型限流状态,避免选中已限流账号 |
| 秒级冷却时间 | 支持 429 响应的秒级精确冷却时间 |
| 身份注入优化 | 模型身份信息注入 + 静默边界防止身份泄露 |
| thoughtSignature 修复 | Gemini 3 函数调用 400 错误修复 |
| max_tokens 自动修正 | 自动修正 max_tokens <= budget_tokens 导致的 400 错误 |
### 调度算法优化
| 功能 | 说明 |
|------|------|
| 分层过滤选择 | 调度算法从全排序改为分层过滤,提升性能 |
| LRU 随机选择 | 相同 LRU 时间时随机选择,避免账号集中 |
| 限流等待阈值配置化 | 可配置的限流等待阈值 |
### 运维增强
| 功能 | 说明 |
|------|------|
| Scope 限流统计 | 运维界面展示 Antigravity 账号 scope 级别限流统计 |
| 账号限流状态显示 | 账号列表显示 scope 和模型级别限流状态 |
| 清除限流按钮增强 | 有 scope/模型限流时也显示清除限流按钮 |
### 其他修复
| 功能 | 说明 |
|------|------|
| .gitattributes | 确保迁移文件使用 LF 换行符(解决 Windows 下 SQL 摘要不一致) |
| 部署配置优化 | DATABASE_HOST 和 DATABASE_SSLMODE 可通过 .env 配置 |
---
## 注意事项
1. **前端必须打包进镜像**:使用 `docker build` 在服务器上构建Dockerfile 会自动编译前端并 embed 到后端二进制中
2. **镜像标签**docker-compose.yml 使用 `weishaw/sub2api:latest`,本地构建后需要 `docker tag` 覆盖
3. **Windows 换行符问题**:已通过 `.gitattributes` 解决,确保 `*.sql` 文件始终使用 LF
4. **版本号管理**:每次发布必须更新 `backend/cmd/server/VERSION` 并打标签
5. **合并冲突**:合并上游新版本时,重点关注以下文件可能的冲突:
- `backend/internal/service/antigravity_gateway_service.go`
- `backend/internal/service/gateway_service.go`
- `backend/internal/pkg/antigravity/request_transformer.go`
---
## Go 代码规范
### 1. 函数设计
#### 单一职责原则
- **函数行数**:单个函数常规不应超过 **30 行**,超过时应拆分为子函数。若某段逻辑确实不可拆分(如复杂的状态机、协议解析等),可以例外,但需添加注释说明原因
- **嵌套层级**:避免超过 3 层嵌套,使用 early return 减少嵌套
```go
// ❌ 不推荐:深层嵌套
func process(data []Item) {
for _, item := range data {
if item.Valid {
if item.Type == "A" {
if item.Status == "active" {
// 业务逻辑...
}
}
}
}
}
// ✅ 推荐early return
func process(data []Item) {
for _, item := range data {
if !item.Valid {
continue
}
if item.Type != "A" {
continue
}
if item.Status != "active" {
continue
}
// 业务逻辑...
}
}
```
#### 复杂逻辑提取
将复杂的条件判断或处理逻辑提取为独立函数:
```go
// ❌ 不推荐:内联复杂逻辑
if resp.StatusCode == 429 || resp.StatusCode == 503 {
// 80+ 行处理逻辑...
}
// ✅ 推荐:提取为独立函数
result := handleRateLimitResponse(resp, params)
switch result.action {
case actionRetry:
continue
case actionBreak:
return result.resp, nil
}
```
### 2. 重复代码消除
#### 配置获取模式
将重复的配置获取逻辑提取为方法:
```go
// ❌ 不推荐:重复代码
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
maxBytes := 2048
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
}
// ✅ 推荐:提取为方法
func (s *Service) getLogConfig() (logBody bool, maxBytes int) {
maxBytes = 2048
if s.settingService == nil || s.settingService.cfg == nil {
return false, maxBytes
}
cfg := s.settingService.cfg.Gateway
if cfg.LogUpstreamErrorBodyMaxBytes > 0 {
maxBytes = cfg.LogUpstreamErrorBodyMaxBytes
}
return cfg.LogUpstreamErrorBody, maxBytes
}
```
### 3. 常量管理
#### 避免魔法数字
所有硬编码的数值都应定义为常量:
```go
// ❌ 不推荐
if retryDelay >= 10*time.Second {
resetAt := time.Now().Add(30 * time.Second)
}
// ✅ 推荐
const (
rateLimitThreshold = 10 * time.Second
defaultRateLimitDuration = 30 * time.Second
)
if retryDelay >= rateLimitThreshold {
resetAt := time.Now().Add(defaultRateLimitDuration)
}
```
#### 注释引用常量名
在注释中引用常量名而非硬编码值:
```go
// ❌ 不推荐
// < 10s: 等待后重试
// ✅ 推荐
// < rateLimitThreshold: 等待后重试
```
### 4. 错误处理
#### 使用结构化日志
优先使用 `slog` 进行结构化日志记录:
```go
// ❌ 不推荐
log.Printf("%s status=%d model_rate_limit_failed model=%s error=%v", prefix, statusCode, modelName, err)
// ✅ 推荐
slog.Error("failed to set model rate limit",
"prefix", prefix,
"status_code", statusCode,
"model", modelName,
"error", err,
)
```
### 5. 测试规范
#### Mock 函数签名同步
修改函数签名时,必须同步更新所有测试中的 mock 函数:
```go
// 如果修改了 handleError 签名
handleError func(..., groupID int64, sessionHash string) *Result
// 必须同步更新测试中的 mock
handleError: func(..., groupID int64, sessionHash string) *Result {
return nil
},
```
#### 测试构建标签
统一使用测试构建标签:
```go
//go:build unit
package service
```
### 6. 时间格式解析
#### 使用标准库
优先使用 `time.ParseDuration`,支持所有 Go duration 格式:
```go
// ❌ 不推荐:手动限制格式
if !strings.HasSuffix(delay, "s") || strings.Contains(delay, "m") {
continue
}
// ✅ 推荐:使用标准库
dur, err := time.ParseDuration(delay) // 支持 "0.5s", "4m50s", "1h30m" 等
```
### 7. 接口设计
#### 接口隔离原则
定义最小化接口,只包含必需的方法:
```go
// ❌ 不推荐:使用过于宽泛的接口
type AccountRepository interface {
// 20+ 个方法...
}
// ✅ 推荐:定义最小化接口
type ModelRateLimiter interface {
SetModelRateLimit(ctx context.Context, id int64, modelKey string, resetAt time.Time) error
}
```
### 8. 并发安全
#### 共享数据保护
访问可能被并发修改的数据时,确保线程安全:
```go
// 如果 Account.Extra 可能被并发修改
// 需要使用互斥锁或原子操作保护读取
func (a *Account) GetRateLimitRemainingTime(model string) time.Duration {
a.mu.RLock()
defer a.mu.RUnlock()
// 读取 Extra 字段...
}
```
### 9. 命名规范
#### 一致的命名风格
- 常量使用 camelCase`rateLimitThreshold`
- 类型使用 PascalCase`AntigravityQuotaScope`
- 同一概念使用统一命名:`Threshold``Limit`,不要混用
```go
// ❌ 不推荐:命名不一致
antigravitySmartRetryMinWait // 使用 Min
antigravityRateLimitThreshold // 使用 Threshold
// ✅ 推荐:统一风格
antigravityMinRetryWait
antigravityRateLimitThreshold
```
### 10. 代码审查清单
在提交代码前,检查以下项目:
- [ ] 函数是否超过 30 行?(不可拆分的逻辑除外,需注释说明)
- [ ] 嵌套是否超过 3 层?
- [ ] 是否有重复代码可以提取?
- [ ] 是否使用了魔法数字?
- [ ] Mock 函数签名是否与实际函数一致?
- [ ] 测试是否覆盖了新增逻辑?
- [ ] 日志是否包含足够的上下文信息?
- [ ] 是否考虑了并发安全?
---
## CI 检查与发布门禁
### GitHub Actions 检查项
本项目有 4 个 CI 任务,**任何代码推送或发布前都必须全部通过**
| Workflow | Job | 说明 | 本地验证命令 |
|----------|-----|------|-------------|
| CI | `test` | 单元测试 + 集成测试 | `cd backend && make test-unit && make test-integration` |
| CI | `golangci-lint` | Go 代码静态检查golangci-lint v2.7 | `cd backend && golangci-lint run --timeout=5m` |
| Security Scan | `backend-security` | govulncheck + gosec 安全扫描 | `cd backend && govulncheck ./... && gosec -severity high -confidence high ./...` |
| Security Scan | `frontend-security` | pnpm audit 前端依赖安全检查 | `cd frontend && pnpm audit --prod --audit-level=high` |
### 向上游提交 PR
PR 目标是上游官方仓库,**只包含通用功能改动**bug fix、新功能、性能优化等
**以下文件禁止出现在 PR 中**(属于我们 fork 的定制化内容):
- `CLAUDE.md``AGENTS.md` — 我们的开发文档
- `backend/cmd/server/VERSION` — 我们的版本号文件
- UI 定制改动GitHub 链接移除、微信客服按钮、首页定制等)
- 部署配置(`deploy/` 目录下的定制修改)
**PR 流程**
1.`develop` 创建功能分支,只包含要提交给上游的改动
2. 推送分支后,**等待 4 个 CI job 全部通过**
3. 确认通过后再创建 PR
4. 使用 `gh run list --repo touwaeriol/sub2api --branch <branch>` 检查状态
### 自有分支推送develop / main
推送到我们自己的 `develop``main` 分支时,包含所有改动(定制化 + 通用功能)。
**推送流程**
1. 本地运行 `cd backend && make test-unit` 确保单元测试通过
2. 本地运行 `cd backend && gofmt -l ./...` 确保格式正确
3. 推送后确认 CI 和 Security Scan 两个 workflow 的 4 个 job 全部绿色 ✅
4. 任何 job 失败必须立即修复,**禁止在 CI 未通过的状态下继续后续操作**
### 发布版本
1. 确保 `main` 分支最新提交的 4 个 CI job 全部通过
2. 递增 `backend/cmd/server/VERSION`,提交并推送
3. 打 tag 推送后,确认 tag 触发的 3 个 workflowCI、Security Scan、Release全部通过
4. **Release workflow 失败时禁止部署** — 必须先修复问题,删除旧 tag重新打 tag
5. 使用 `gh run list --repo touwaeriol/sub2api --limit 10` 确认状态
### 常见 CI 失败原因及修复
- **gofmt**struct 字段对齐不一致 → 运行 `gofmt -w <file>` 修复
- **golangci-lint**:未使用的变量/导入 → 删除或使用 `_` 忽略
- **test 失败**mock 函数签名不一致 → 同步更新 mock
- **gosec**:安全漏洞 → 根据提示修复或添加例外

View File

@@ -1 +1 @@
0.1.74.7
0.1.75.7

View File

@@ -154,7 +154,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache)
digestSessionStore := service.NewDigestSessionStore()
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, digestSessionStore)
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)

View File

@@ -103,6 +103,7 @@ require (
github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect

View File

@@ -207,6 +207,8 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=

View File

@@ -2,11 +2,6 @@ package dto
import "time"
type ScopeRateLimitInfo struct {
ResetAt time.Time `json:"reset_at"`
RemainingSec int64 `json:"remaining_sec"`
}
type User struct {
ID int64 `json:"id"`
Email string `json:"email"`
@@ -126,9 +121,6 @@ type Account struct {
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
OverloadUntil *time.Time `json:"overload_until"`
// Antigravity scope 级限流状态(从 extra 提取)
ScopeRateLimits map[string]ScopeRateLimitInfo `json:"scope_rate_limits,omitempty"`
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until"`
TempUnschedulableReason string `json:"temp_unschedulable_reason"`

View File

@@ -13,6 +13,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
@@ -114,7 +115,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
setOpsRequestContext(c, "", false, body)
parsedReq, err := service.ParseGatewayRequest(body)
parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
@@ -203,6 +204,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 计算粘性会话hash
parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c),
UserAgent: c.GetHeader("User-Agent"),
APIKeyID: apiKey.ID,
}
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context否则使用分组平台
@@ -335,7 +341,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if failoverErr.ForceCacheBilling {
if needForceCacheBilling(hasBoundSession, failoverErr) {
forceCacheBilling = true
}
if switchCount >= maxAccountSwitches {
@@ -344,6 +350,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
if account.Platform == service.PlatformAntigravity {
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
return
}
}
continue
}
// 错误响应已在Forward中处理这里只记录日志
@@ -530,7 +541,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if failoverErr.ForceCacheBilling {
if needForceCacheBilling(hasBoundSession, failoverErr) {
forceCacheBilling = true
}
if switchCount >= maxAccountSwitches {
@@ -539,6 +550,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
if account.Platform == service.PlatformAntigravity {
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
return
}
}
continue
}
// 错误响应已在Forward中处理这里只记录日志
@@ -801,6 +817,27 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
}
// needForceCacheBilling 判断 failover 时是否需要强制缓存计费
// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费
func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool {
return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling)
}
// sleepFailoverDelay 账号切换线性递增延时第1次0s、第2次1s、第3次2s…
// 返回 false 表示 context 已取消。
func sleepFailoverDelay(ctx context.Context, switchCount int) bool {
delay := time.Duration(switchCount-1) * time.Second
if delay <= 0 {
return true
}
select {
case <-ctx.Done():
return false
case <-time.After(delay):
return true
}
}
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
statusCode := failoverErr.StatusCode
responseBody := failoverErr.ResponseBody
@@ -934,7 +971,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
setOpsRequestContext(c, "", false, body)
parsedReq, err := service.ParseGatewayRequest(body)
parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
@@ -962,6 +999,11 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
}
// 计算粘性会话 hash
parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c),
UserAgent: c.GetHeader("User-Agent"),
APIKeyID: apiKey.ID,
}
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 选择支持该模型的账号

View File

@@ -14,6 +14,7 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
@@ -30,13 +31,6 @@ import (
// 匹配格式: /Users/xxx/.gemini/tmp/[64位十六进制哈希]
var geminiCLITmpDirRegex = regexp.MustCompile(`/\.gemini/tmp/([A-Fa-f0-9]{64})`)
func isGeminiCLIRequest(c *gin.Context, body []byte) bool {
if strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id")) != "" {
return true
}
return geminiCLITmpDirRegex.Match(body)
}
// GeminiV1BetaListModels proxies:
// GET /v1beta/models
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
@@ -239,7 +233,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
sessionHash := extractGeminiCLISessionHash(c, body)
if sessionHash == "" {
// Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端)
parsedReq, _ := service.ParseGatewayRequest(body)
parsedReq, _ := service.ParseGatewayRequest(body, domain.PlatformGemini)
if parsedReq != nil {
parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c),
UserAgent: c.GetHeader("User-Agent"),
APIKeyID: apiKey.ID,
}
}
sessionHash = h.gatewayService.GenerateSessionHash(parsedReq)
}
sessionKey := sessionHash
@@ -258,6 +259,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
var geminiDigestChain string
var geminiPrefixHash string
var geminiSessionUUID string
var matchedDigestChain string
useDigestFallback := sessionBoundAccountID == 0
if useDigestFallback {
@@ -284,13 +286,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
)
// 查找会话
foundUUID, foundAccountID, found := h.gatewayService.FindGeminiSession(
foundUUID, foundAccountID, foundMatchedChain, found := h.gatewayService.FindGeminiSession(
c.Request.Context(),
derefGroupID(apiKey.GroupID),
geminiPrefixHash,
geminiDigestChain,
)
if found {
matchedDigestChain = foundMatchedChain
sessionBoundAccountID = foundAccountID
geminiSessionUUID = foundUUID
log.Printf("[Gemini] Digest fallback matched: uuid=%s, accountID=%d, chain=%s",
@@ -316,7 +319,6 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
isCLI := isGeminiCLIRequest(c, body)
cleanedForUnknownBinding := false
maxAccountSwitches := h.maxAccountSwitchesGemini
@@ -344,10 +346,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID)
body = service.CleanGeminiNativeThoughtSignatures(body)
sessionBoundAccountID = account.ID
} else if sessionKey != "" && sessionBoundAccountID == 0 && isCLI && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
// 无缓存绑定但请求里已有 thoughtSignature常见于缓存丢失/TTL 过期后,CLI 继续携带旧签名。
} else if sessionKey != "" && sessionBoundAccountID == 0 && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
// 无缓存绑定但请求里已有 thoughtSignature常见于缓存丢失/TTL 过期后,客户端继续携带旧签名。
// 为避免第一次转发就 400这里做一次确定性清理让新账号重新生成签名链路。
log.Printf("[Gemini] Sticky session binding missing for CLI request, cleaning thoughtSignature proactively")
log.Printf("[Gemini] Sticky session binding missing, cleaning thoughtSignature proactively")
body = service.CleanGeminiNativeThoughtSignatures(body)
cleanedForUnknownBinding = true
sessionBoundAccountID = account.ID
@@ -422,7 +424,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
if failoverErr.ForceCacheBilling {
if needForceCacheBilling(hasBoundSession, failoverErr) {
forceCacheBilling = true
}
if switchCount >= maxAccountSwitches {
@@ -433,6 +435,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
lastFailoverErr = failoverErr
switchCount++
log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
if account.Platform == service.PlatformAntigravity {
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
return
}
}
continue
}
// ForwardNative already wrote the response
@@ -453,6 +460,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
geminiDigestChain,
geminiSessionUUID,
account.ID,
matchedDigestChain,
); err != nil {
log.Printf("[Gemini] Failed to save digest session: %v", err)
}

View File

@@ -798,53 +798,6 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
return nil
}
func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
now := time.Now().UTC()
payload := map[string]string{
"rate_limited_at": now.Format(time.RFC3339),
"rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339),
}
raw, err := json.Marshal(payload)
if err != nil {
return err
}
scopeKey := string(scope)
client := clientFromContext(ctx, r.client)
result, err := client.ExecContext(
ctx,
`UPDATE accounts SET
extra = jsonb_set(
jsonb_set(COALESCE(extra, '{}'::jsonb), '{antigravity_quota_scopes}'::text[], COALESCE(extra->'antigravity_quota_scopes', '{}'::jsonb), true),
ARRAY['antigravity_quota_scopes', $1]::text[],
$2::jsonb,
true
),
updated_at = NOW(),
last_used_at = NOW()
WHERE id = $3 AND deleted_at IS NULL`,
scopeKey,
raw,
id,
)
if err != nil {
return err
}
affected, err := result.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
return service.ErrAccountNotFound
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue quota scope failed: account=%d err=%v", id, err)
}
return nil
}
func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
if scope == "" {
return nil

View File

@@ -11,63 +11,6 @@ import (
const stickySessionPrefix = "sticky_session:"
// Gemini Trie Lua 脚本
const (
// geminiTrieFindScript 查找最长前缀匹配的 Lua 脚本
// KEYS[1] = trie key
// ARGV[1] = digestChain (如 "u:a-m:b-u:c-m:d")
// ARGV[2] = TTL seconds (用于刷新)
// 返回: 最长匹配的 value (uuid:accountID) 或 nil
// 查找成功时自动刷新 TTL防止活跃会话意外过期
geminiTrieFindScript = `
local chain = ARGV[1]
local ttl = tonumber(ARGV[2])
local lastMatch = nil
local path = ""
for part in string.gmatch(chain, "[^-]+") do
path = path == "" and part or path .. "-" .. part
local val = redis.call('HGET', KEYS[1], path)
if val and val ~= "" then
lastMatch = val
end
end
if lastMatch then
redis.call('EXPIRE', KEYS[1], ttl)
end
return lastMatch
`
// geminiTrieSaveScript 保存会话到 Trie 的 Lua 脚本
// KEYS[1] = trie key
// ARGV[1] = digestChain
// ARGV[2] = value (uuid:accountID)
// ARGV[3] = TTL seconds
geminiTrieSaveScript = `
local chain = ARGV[1]
local value = ARGV[2]
local ttl = tonumber(ARGV[3])
local path = ""
for part in string.gmatch(chain, "[^-]+") do
path = path == "" and part or path .. "-" .. part
end
redis.call('HSET', KEYS[1], path, value)
redis.call('EXPIRE', KEYS[1], ttl)
return "OK"
`
)
// 模型负载统计相关常量
const (
modelLoadKeyPrefix = "ag:model_load:" // 模型调用次数 key 前缀
modelLastUsedKeyPrefix = "ag:model_last_used:" // 模型最后调度时间 key 前缀
modelLoadTTL = 24 * time.Hour // 调用次数 TTL24 小时无调用后清零)
modelLastUsedTTL = 24 * time.Hour // 最后调度时间 TTL
)
type gatewayCache struct {
rdb *redis.Client
}
@@ -108,171 +51,3 @@ func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Del(ctx, key).Err()
}
// ============ Antigravity 模型负载统计方法 ============
// modelLoadKey 构建模型调用次数 key
// 格式: ag:model_load:{accountID}:{model}
func modelLoadKey(accountID int64, model string) string {
return fmt.Sprintf("%s%d:%s", modelLoadKeyPrefix, accountID, model)
}
// modelLastUsedKey 构建模型最后调度时间 key
// 格式: ag:model_last_used:{accountID}:{model}
func modelLastUsedKey(accountID int64, model string) string {
return fmt.Sprintf("%s%d:%s", modelLastUsedKeyPrefix, accountID, model)
}
// IncrModelCallCount 增加模型调用次数并更新最后调度时间
// 返回更新后的调用次数
func (c *gatewayCache) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
loadKey := modelLoadKey(accountID, model)
lastUsedKey := modelLastUsedKey(accountID, model)
pipe := c.rdb.Pipeline()
incrCmd := pipe.Incr(ctx, loadKey)
pipe.Expire(ctx, loadKey, modelLoadTTL) // 每次调用刷新 TTL
pipe.Set(ctx, lastUsedKey, time.Now().Unix(), modelLastUsedTTL)
if _, err := pipe.Exec(ctx); err != nil {
return 0, err
}
return incrCmd.Val(), nil
}
// GetModelLoadBatch 批量获取账号的模型负载信息
func (c *gatewayCache) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*service.ModelLoadInfo, error) {
if len(accountIDs) == 0 {
return make(map[int64]*service.ModelLoadInfo), nil
}
loadCmds, lastUsedCmds := c.pipelineModelLoadGet(ctx, accountIDs, model)
return c.parseModelLoadResults(accountIDs, loadCmds, lastUsedCmds), nil
}
// pipelineModelLoadGet 批量获取模型负载的 Pipeline 操作
func (c *gatewayCache) pipelineModelLoadGet(
ctx context.Context,
accountIDs []int64,
model string,
) (map[int64]*redis.StringCmd, map[int64]*redis.StringCmd) {
pipe := c.rdb.Pipeline()
loadCmds := make(map[int64]*redis.StringCmd, len(accountIDs))
lastUsedCmds := make(map[int64]*redis.StringCmd, len(accountIDs))
for _, id := range accountIDs {
loadCmds[id] = pipe.Get(ctx, modelLoadKey(id, model))
lastUsedCmds[id] = pipe.Get(ctx, modelLastUsedKey(id, model))
}
_, _ = pipe.Exec(ctx) // 忽略错误key 不存在是正常的
return loadCmds, lastUsedCmds
}
// parseModelLoadResults 解析 Pipeline 结果
func (c *gatewayCache) parseModelLoadResults(
accountIDs []int64,
loadCmds map[int64]*redis.StringCmd,
lastUsedCmds map[int64]*redis.StringCmd,
) map[int64]*service.ModelLoadInfo {
result := make(map[int64]*service.ModelLoadInfo, len(accountIDs))
for _, id := range accountIDs {
result[id] = &service.ModelLoadInfo{
CallCount: getInt64OrZero(loadCmds[id]),
LastUsedAt: getTimeOrZero(lastUsedCmds[id]),
}
}
return result
}
// getInt64OrZero 从 StringCmd 获取 int64 值,失败返回 0
func getInt64OrZero(cmd *redis.StringCmd) int64 {
val, _ := cmd.Int64()
return val
}
// getTimeOrZero 从 StringCmd 获取 time.Time失败返回零值
func getTimeOrZero(cmd *redis.StringCmd) time.Time {
val, err := cmd.Int64()
if err != nil {
return time.Time{}
}
return time.Unix(val, 0)
}
// ============ Gemini 会话 Fallback 方法 (Trie 实现) ============
// FindGeminiSession 查找 Gemini 会话(使用 Trie + Lua 脚本实现 O(L) 查询)
// 返回最长匹配的会话信息,匹配成功时自动刷新 TTL
func (c *gatewayCache) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
if digestChain == "" {
return "", 0, false
}
trieKey := service.BuildGeminiTrieKey(groupID, prefixHash)
ttlSeconds := int(service.GeminiSessionTTL().Seconds())
// 使用 Lua 脚本在 Redis 端执行 Trie 查找O(L) 次 HGET1 次网络往返
// 查找成功时自动刷新 TTL防止活跃会话意外过期
result, err := c.rdb.Eval(ctx, geminiTrieFindScript, []string{trieKey}, digestChain, ttlSeconds).Result()
if err != nil || result == nil {
return "", 0, false
}
value, ok := result.(string)
if !ok || value == "" {
return "", 0, false
}
uuid, accountID, ok = service.ParseGeminiSessionValue(value)
return uuid, accountID, ok
}
// SaveGeminiSession 保存 Gemini 会话(使用 Trie + Lua 脚本)
func (c *gatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
if digestChain == "" {
return nil
}
trieKey := service.BuildGeminiTrieKey(groupID, prefixHash)
value := service.FormatGeminiSessionValue(uuid, accountID)
ttlSeconds := int(service.GeminiSessionTTL().Seconds())
return c.rdb.Eval(ctx, geminiTrieSaveScript, []string{trieKey}, digestChain, value, ttlSeconds).Err()
}
// ============ Anthropic 会话 Fallback 方法 (复用 Trie 实现) ============
// FindAnthropicSession 查找 Anthropic 会话(复用 Gemini Trie Lua 脚本)
func (c *gatewayCache) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
if digestChain == "" {
return "", 0, false
}
trieKey := service.BuildAnthropicTrieKey(groupID, prefixHash)
ttlSeconds := int(service.AnthropicSessionTTL().Seconds())
result, err := c.rdb.Eval(ctx, geminiTrieFindScript, []string{trieKey}, digestChain, ttlSeconds).Result()
if err != nil || result == nil {
return "", 0, false
}
value, ok := result.(string)
if !ok || value == "" {
return "", 0, false
}
uuid, accountID, ok = service.ParseGeminiSessionValue(value)
return uuid, accountID, ok
}
// SaveAnthropicSession 保存 Anthropic 会话(复用 Gemini Trie Lua 脚本)
func (c *gatewayCache) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
if digestChain == "" {
return nil
}
trieKey := service.BuildAnthropicTrieKey(groupID, prefixHash)
value := service.FormatGeminiSessionValue(uuid, accountID)
ttlSeconds := int(service.AnthropicSessionTTL().Seconds())
return c.rdb.Eval(ctx, geminiTrieSaveScript, []string{trieKey}, digestChain, value, ttlSeconds).Err()
}

View File

@@ -104,157 +104,6 @@ func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
}
// ============ Gemini Trie 会话测试 ============
func (s *GatewayCacheSuite) TestGeminiSessionTrie_SaveAndFind() {
groupID := int64(1)
prefixHash := "testprefix"
digestChain := "u:hash1-m:hash2-u:hash3"
uuid := "test-uuid-123"
accountID := int64(42)
// 保存会话
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, digestChain, uuid, accountID)
require.NoError(s.T(), err, "SaveGeminiSession")
// 精确匹配查找
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, digestChain)
require.True(s.T(), found, "should find exact match")
require.Equal(s.T(), uuid, foundUUID)
require.Equal(s.T(), accountID, foundAccountID)
}
func (s *GatewayCacheSuite) TestGeminiSessionTrie_PrefixMatch() {
groupID := int64(1)
prefixHash := "prefixmatch"
shortChain := "u:a-m:b"
longChain := "u:a-m:b-u:c-m:d"
uuid := "uuid-prefix"
accountID := int64(100)
// 保存短链
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, shortChain, uuid, accountID)
require.NoError(s.T(), err)
// 用长链查找,应该匹配到短链(前缀匹配)
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, longChain)
require.True(s.T(), found, "should find prefix match")
require.Equal(s.T(), uuid, foundUUID)
require.Equal(s.T(), accountID, foundAccountID)
}
func (s *GatewayCacheSuite) TestGeminiSessionTrie_LongestPrefixMatch() {
groupID := int64(1)
prefixHash := "longestmatch"
// 保存多个不同长度的链
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a", "uuid-short", 1)
require.NoError(s.T(), err)
err = s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b", "uuid-medium", 2)
require.NoError(s.T(), err)
err = s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:c", "uuid-long", 3)
require.NoError(s.T(), err)
// 查找更长的链,应该匹配到最长的前缀
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:c-m:d-u:e")
require.True(s.T(), found, "should find longest prefix match")
require.Equal(s.T(), "uuid-long", foundUUID)
require.Equal(s.T(), int64(3), foundAccountID)
// 查找中等长度的链
foundUUID, foundAccountID, found = s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:x")
require.True(s.T(), found)
require.Equal(s.T(), "uuid-medium", foundUUID)
require.Equal(s.T(), int64(2), foundAccountID)
}
func (s *GatewayCacheSuite) TestGeminiSessionTrie_NoMatch() {
groupID := int64(1)
prefixHash := "nomatch"
digestChain := "u:a-m:b"
// 保存一个会话
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, digestChain, "uuid", 1)
require.NoError(s.T(), err)
// 用不同的链查找,应该找不到
_, _, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:x-m:y")
require.False(s.T(), found, "should not find non-matching chain")
}
func (s *GatewayCacheSuite) TestGeminiSessionTrie_DifferentPrefixHash() {
groupID := int64(1)
digestChain := "u:a-m:b"
// 保存到 prefixHash1
err := s.cache.SaveGeminiSession(s.ctx, groupID, "prefix1", digestChain, "uuid1", 1)
require.NoError(s.T(), err)
// 用 prefixHash2 查找,应该找不到(不同用户/客户端隔离)
_, _, found := s.cache.FindGeminiSession(s.ctx, groupID, "prefix2", digestChain)
require.False(s.T(), found, "different prefixHash should be isolated")
}
func (s *GatewayCacheSuite) TestGeminiSessionTrie_DifferentGroupID() {
prefixHash := "sameprefix"
digestChain := "u:a-m:b"
// 保存到 groupID 1
err := s.cache.SaveGeminiSession(s.ctx, 1, prefixHash, digestChain, "uuid1", 1)
require.NoError(s.T(), err)
// 用 groupID 2 查找,应该找不到(分组隔离)
_, _, found := s.cache.FindGeminiSession(s.ctx, 2, prefixHash, digestChain)
require.False(s.T(), found, "different groupID should be isolated")
}
func (s *GatewayCacheSuite) TestGeminiSessionTrie_EmptyDigestChain() {
groupID := int64(1)
prefixHash := "emptytest"
// 空链不应该保存
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "", "uuid", 1)
require.NoError(s.T(), err, "empty chain should not error")
// 空链查找应该返回 false
_, _, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "")
require.False(s.T(), found, "empty chain should not match")
}
func (s *GatewayCacheSuite) TestGeminiSessionTrie_MultipleSessions() {
groupID := int64(1)
prefixHash := "multisession"
// 保存多个不同会话(模拟 1000 个并发会话的场景)
sessions := []struct {
chain string
uuid string
accountID int64
}{
{"u:session1", "uuid-1", 1},
{"u:session2-m:reply2", "uuid-2", 2},
{"u:session3-m:reply3-u:msg3", "uuid-3", 3},
}
for _, sess := range sessions {
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, sess.chain, sess.uuid, sess.accountID)
require.NoError(s.T(), err)
}
// 验证每个会话都能正确查找
for _, sess := range sessions {
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, sess.chain)
require.True(s.T(), found, "should find session: %s", sess.chain)
require.Equal(s.T(), sess.uuid, foundUUID)
require.Equal(s.T(), sess.accountID, foundAccountID)
}
// 验证继续对话的场景
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:session2-m:reply2-u:newmsg")
require.True(s.T(), found)
require.Equal(s.T(), "uuid-2", foundUUID)
require.Equal(s.T(), int64(2), foundAccountID)
}
func TestGatewayCacheSuite(t *testing.T) {
suite.Run(t, new(GatewayCacheSuite))

View File

@@ -1,234 +0,0 @@
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
// ============ Gateway Cache 模型负载统计集成测试 ============
type GatewayCacheModelLoadSuite struct {
suite.Suite
}
func TestGatewayCacheModelLoadSuite(t *testing.T) {
suite.Run(t, new(GatewayCacheModelLoadSuite))
}
func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_Basic() {
t := s.T()
rdb := testRedis(t)
cache := &gatewayCache{rdb: rdb}
ctx := context.Background()
accountID := int64(123)
model := "claude-sonnet-4-20250514"
// 首次调用应返回 1
count1, err := cache.IncrModelCallCount(ctx, accountID, model)
require.NoError(t, err)
require.Equal(t, int64(1), count1)
// 第二次调用应返回 2
count2, err := cache.IncrModelCallCount(ctx, accountID, model)
require.NoError(t, err)
require.Equal(t, int64(2), count2)
// 第三次调用应返回 3
count3, err := cache.IncrModelCallCount(ctx, accountID, model)
require.NoError(t, err)
require.Equal(t, int64(3), count3)
}
func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_DifferentModels() {
t := s.T()
rdb := testRedis(t)
cache := &gatewayCache{rdb: rdb}
ctx := context.Background()
accountID := int64(456)
model1 := "claude-sonnet-4-20250514"
model2 := "claude-opus-4-5-20251101"
// 不同模型应该独立计数
count1, err := cache.IncrModelCallCount(ctx, accountID, model1)
require.NoError(t, err)
require.Equal(t, int64(1), count1)
count2, err := cache.IncrModelCallCount(ctx, accountID, model2)
require.NoError(t, err)
require.Equal(t, int64(1), count2)
count1Again, err := cache.IncrModelCallCount(ctx, accountID, model1)
require.NoError(t, err)
require.Equal(t, int64(2), count1Again)
}
func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_DifferentAccounts() {
t := s.T()
rdb := testRedis(t)
cache := &gatewayCache{rdb: rdb}
ctx := context.Background()
account1 := int64(111)
account2 := int64(222)
model := "gemini-2.5-pro"
// 不同账号应该独立计数
count1, err := cache.IncrModelCallCount(ctx, account1, model)
require.NoError(t, err)
require.Equal(t, int64(1), count1)
count2, err := cache.IncrModelCallCount(ctx, account2, model)
require.NoError(t, err)
require.Equal(t, int64(1), count2)
}
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_Empty() {
t := s.T()
rdb := testRedis(t)
cache := &gatewayCache{rdb: rdb}
ctx := context.Background()
result, err := cache.GetModelLoadBatch(ctx, []int64{}, "any-model")
require.NoError(t, err)
require.NotNil(t, result)
require.Empty(t, result)
}
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_NonExistent() {
t := s.T()
rdb := testRedis(t)
cache := &gatewayCache{rdb: rdb}
ctx := context.Background()
// 查询不存在的账号应返回零值
result, err := cache.GetModelLoadBatch(ctx, []int64{9999, 9998}, "claude-sonnet-4-20250514")
require.NoError(t, err)
require.Len(t, result, 2)
require.Equal(t, int64(0), result[9999].CallCount)
require.True(t, result[9999].LastUsedAt.IsZero())
require.Equal(t, int64(0), result[9998].CallCount)
require.True(t, result[9998].LastUsedAt.IsZero())
}
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_AfterIncrement() {
t := s.T()
rdb := testRedis(t)
cache := &gatewayCache{rdb: rdb}
ctx := context.Background()
accountID := int64(789)
model := "claude-sonnet-4-20250514"
// 先增加调用次数
beforeIncr := time.Now()
_, err := cache.IncrModelCallCount(ctx, accountID, model)
require.NoError(t, err)
_, err = cache.IncrModelCallCount(ctx, accountID, model)
require.NoError(t, err)
_, err = cache.IncrModelCallCount(ctx, accountID, model)
require.NoError(t, err)
afterIncr := time.Now()
// 获取负载信息
result, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model)
require.NoError(t, err)
require.Len(t, result, 1)
loadInfo := result[accountID]
require.NotNil(t, loadInfo)
require.Equal(t, int64(3), loadInfo.CallCount)
require.False(t, loadInfo.LastUsedAt.IsZero())
// LastUsedAt 应该在 beforeIncr 和 afterIncr 之间
require.True(t, loadInfo.LastUsedAt.After(beforeIncr.Add(-time.Second)) || loadInfo.LastUsedAt.Equal(beforeIncr))
require.True(t, loadInfo.LastUsedAt.Before(afterIncr.Add(time.Second)) || loadInfo.LastUsedAt.Equal(afterIncr))
}
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_MultipleAccounts() {
t := s.T()
rdb := testRedis(t)
cache := &gatewayCache{rdb: rdb}
ctx := context.Background()
model := "claude-opus-4-5-20251101"
account1 := int64(1001)
account2 := int64(1002)
account3 := int64(1003) // 不调用
// account1 调用 2 次
_, err := cache.IncrModelCallCount(ctx, account1, model)
require.NoError(t, err)
_, err = cache.IncrModelCallCount(ctx, account1, model)
require.NoError(t, err)
// account2 调用 5 次
for i := 0; i < 5; i++ {
_, err = cache.IncrModelCallCount(ctx, account2, model)
require.NoError(t, err)
}
// 批量获取
result, err := cache.GetModelLoadBatch(ctx, []int64{account1, account2, account3}, model)
require.NoError(t, err)
require.Len(t, result, 3)
require.Equal(t, int64(2), result[account1].CallCount)
require.False(t, result[account1].LastUsedAt.IsZero())
require.Equal(t, int64(5), result[account2].CallCount)
require.False(t, result[account2].LastUsedAt.IsZero())
require.Equal(t, int64(0), result[account3].CallCount)
require.True(t, result[account3].LastUsedAt.IsZero())
}
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_ModelIsolation() {
t := s.T()
rdb := testRedis(t)
cache := &gatewayCache{rdb: rdb}
ctx := context.Background()
accountID := int64(2001)
model1 := "claude-sonnet-4-20250514"
model2 := "gemini-2.5-pro"
// 对 model1 调用 3 次
for i := 0; i < 3; i++ {
_, err := cache.IncrModelCallCount(ctx, accountID, model1)
require.NoError(t, err)
}
// 获取 model1 的负载
result1, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model1)
require.NoError(t, err)
require.Equal(t, int64(3), result1[accountID].CallCount)
// 获取 model2 的负载(应该为 0
result2, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model2)
require.NoError(t, err)
require.Equal(t, int64(0), result2[accountID].CallCount)
}
// ============ 辅助函数测试 ============
func (s *GatewayCacheModelLoadSuite) TestModelLoadKey_Format() {
t := s.T()
key := modelLoadKey(123, "claude-sonnet-4")
require.Equal(t, "ag:model_load:123:claude-sonnet-4", key)
}
func (s *GatewayCacheModelLoadSuite) TestModelLastUsedKey_Format() {
t := s.T()
key := modelLastUsedKey(456, "gemini-2.5-pro")
require.Equal(t, "ag:model_last_used:456:gemini-2.5-pro", key)
}

View File

@@ -1004,10 +1004,6 @@ func (s *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt
return errors.New("not implemented")
}
func (s *stubAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
return errors.New("not implemented")
}

View File

@@ -50,7 +50,6 @@ type AccountRepository interface {
ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error
SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error

View File

@@ -143,10 +143,6 @@ func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt
panic("unexpected SetRateLimited call")
}
func (s *accountRepoStub) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
panic("unexpected SetAntigravityQuotaScopeLimit call")
}
func (s *accountRepoStub) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
panic("unexpected SetModelRateLimit call")
}

View File

@@ -2,7 +2,6 @@ package service
import (
"encoding/json"
"strconv"
"strings"
"time"
)
@@ -12,9 +11,6 @@ const (
// anthropicSessionTTLSeconds Anthropic 会话缓存 TTL5 分钟)
anthropicSessionTTLSeconds = 300
// anthropicTrieKeyPrefix Anthropic Trie 会话 key 前缀
anthropicTrieKeyPrefix = "anthropic:trie:"
// anthropicDigestSessionKeyPrefix Anthropic 摘要 fallback 会话 key 前缀
anthropicDigestSessionKeyPrefix = "anthropic:digest:"
)
@@ -68,12 +64,6 @@ func rolePrefix(role string) string {
}
}
// BuildAnthropicTrieKey 构建 Anthropic Trie Redis key
// 格式: anthropic:trie:{groupID}:{prefixHash}
func BuildAnthropicTrieKey(groupID int64, prefixHash string) string {
return anthropicTrieKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash
}
// GenerateAnthropicDigestSessionKey 生成 Anthropic 摘要 fallback 的 sessionKey
// 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey
func GenerateAnthropicDigestSessionKey(prefixHash, uuid string) string {

View File

@@ -236,43 +236,6 @@ func TestBuildAnthropicDigestChain_Deterministic(t *testing.T) {
}
}
func TestBuildAnthropicTrieKey(t *testing.T) {
tests := []struct {
name string
groupID int64
prefixHash string
want string
}{
{
name: "normal",
groupID: 123,
prefixHash: "abcdef12",
want: "anthropic:trie:123:abcdef12",
},
{
name: "zero group",
groupID: 0,
prefixHash: "xyz",
want: "anthropic:trie:0:xyz",
},
{
name: "empty prefix",
groupID: 1,
prefixHash: "",
want: "anthropic:trie:1:",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := BuildAnthropicTrieKey(tt.groupID, tt.prefixHash)
if got != tt.want {
t.Errorf("BuildAnthropicTrieKey(%d, %q) = %q, want %q", tt.groupID, tt.prefixHash, got, tt.want)
}
})
}
}
func TestGenerateAnthropicDigestSessionKey(t *testing.T) {
tests := []struct {
name string

File diff suppressed because it is too large Load Diff

View File

@@ -4,17 +4,42 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// antigravityFailingWriter 模拟客户端断开连接的 gin.ResponseWriter
type antigravityFailingWriter struct {
gin.ResponseWriter
failAfter int // 允许成功写入的次数,之后所有写入返回错误
writes int
}
func (w *antigravityFailingWriter) Write(p []byte) (int, error) {
if w.writes >= w.failAfter {
return 0, errors.New("write failed: client disconnected")
}
w.writes++
return w.ResponseWriter.Write(p)
}
// newAntigravityTestService 创建用于流式测试的 AntigravityGatewayService
func newAntigravityTestService(cfg *config.Config) *AntigravityGatewayService {
return &AntigravityGatewayService{
settingService: &SettingService{cfg: cfg},
}
}
func TestStripSignatureSensitiveBlocksFromClaudeRequest(t *testing.T) {
req := &antigravity.ClaudeRequest{
Model: "claude-sonnet-4-5",
@@ -337,8 +362,8 @@ func TestAntigravityGatewayService_Forward_StickySessionForceCacheBilling(t *tes
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
}
// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling
// 验证:ForwardGemini 粘性会话切换时UpstreamFailoverError.ForceCacheBilling 应为 true
// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling verifies
// that ForwardGemini sets ForceCacheBilling=true for sticky session switch.
func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
@@ -391,3 +416,438 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
}
// --- 流式 happy path 测试 ---
// TestStreamUpstreamResponse_NormalComplete
// 验证正常流式转发完成时数据正确透传、usage 正确收集、clientDisconnect=false
func TestStreamUpstreamResponse_NormalComplete(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
fmt.Fprintln(pw, `event: message_start`)
fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":10}}}`)
fmt.Fprintln(pw, "")
fmt.Fprintln(pw, `event: content_block_delta`)
fmt.Fprintln(pw, `data: {"type":"content_block_delta","delta":{"text":"hello"}}`)
fmt.Fprintln(pw, "")
fmt.Fprintln(pw, `event: message_delta`)
fmt.Fprintln(pw, `data: {"type":"message_delta","usage":{"output_tokens":5}}`)
fmt.Fprintln(pw, "")
}()
result := svc.streamUpstreamResponse(c, resp, time.Now())
_ = pr.Close()
require.NotNil(t, result)
require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect")
require.NotNil(t, result.usage)
require.Equal(t, 5, result.usage.OutputTokens, "should collect output_tokens from message_delta")
require.NotNil(t, result.firstTokenMs, "should record first token time")
// 验证数据被透传到客户端
body := rec.Body.String()
require.Contains(t, body, "event: message_start")
require.Contains(t, body, "content_block_delta")
require.Contains(t, body, "message_delta")
}
// TestHandleGeminiStreamingResponse_NormalComplete
// 验证:正常 Gemini 流式转发数据正确透传、usage 正确收集
func TestHandleGeminiStreamingResponse_NormalComplete(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
// 第一个 chunk部分内容
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":3}}`)
fmt.Fprintln(pw, "")
// 第二个 chunk最终内容+完整 usage
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":" world"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":8,"cachedContentTokenCount":2}}`)
fmt.Fprintln(pw, "")
}()
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect")
require.NotNil(t, result.usage)
// Gemini usage: promptTokenCount=10, candidatesTokenCount=8, cachedContentTokenCount=2
// → InputTokens=10-2=8, OutputTokens=8, CacheReadInputTokens=2
require.Equal(t, 8, result.usage.InputTokens)
require.Equal(t, 8, result.usage.OutputTokens)
require.Equal(t, 2, result.usage.CacheReadInputTokens)
require.NotNil(t, result.firstTokenMs, "should record first token time")
// 验证数据被透传到客户端
body := rec.Body.String()
require.Contains(t, body, "Hello")
require.Contains(t, body, "world")
// 不应包含错误事件
require.NotContains(t, body, "event: error")
}
// TestHandleClaudeStreamingResponse_NormalComplete
// 验证:正常 Claude 流式转发Gemini→Claude 转换),数据正确转换并输出
func TestHandleClaudeStreamingResponse_NormalComplete(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
// v1internal 包装格式Gemini 数据嵌套在 "response" 字段下
// ProcessLine 先尝试反序列化为 V1InternalResponse裸格式会导致 Response.UsageMetadata 为空
fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"Hi there"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":3}}}`)
fmt.Fprintln(pw, "")
}()
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect")
require.NotNil(t, result.usage)
// Gemini→Claude 转换的 usagepromptTokenCount=5→InputTokens=5, candidatesTokenCount=3→OutputTokens=3
require.Equal(t, 5, result.usage.InputTokens)
require.Equal(t, 3, result.usage.OutputTokens)
require.NotNil(t, result.firstTokenMs, "should record first token time")
// 验证输出是 Claude SSE 格式processor 会转换)
body := rec.Body.String()
require.Contains(t, body, "event: message_start", "should contain Claude message_start event")
require.Contains(t, body, "event: message_stop", "should contain Claude message_stop event")
// 不应包含错误事件
require.NotContains(t, body, "event: error")
}
// --- 流式客户端断开检测测试 ---
// TestStreamUpstreamResponse_ClientDisconnectDrainsUsage
// 验证客户端写入失败后streamUpstreamResponse 继续读取上游以收集 usage
func TestStreamUpstreamResponse_ClientDisconnectDrainsUsage(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
fmt.Fprintln(pw, `event: message_start`)
fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":10}}}`)
fmt.Fprintln(pw, "")
fmt.Fprintln(pw, `event: message_delta`)
fmt.Fprintln(pw, `data: {"type":"message_delta","usage":{"output_tokens":20}}`)
fmt.Fprintln(pw, "")
}()
result := svc.streamUpstreamResponse(c, resp, time.Now())
_ = pr.Close()
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
require.NotNil(t, result.usage)
require.Equal(t, 20, result.usage.OutputTokens)
}
// TestStreamUpstreamResponse_ContextCanceled
// 验证context 取消时返回 usage 且标记 clientDisconnect
func TestStreamUpstreamResponse_ContextCanceled(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
ctx, cancel := context.WithCancel(context.Background())
cancel()
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
result := svc.streamUpstreamResponse(c, resp, time.Now())
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
require.NotContains(t, rec.Body.String(), "event: error")
}
// TestStreamUpstreamResponse_Timeout
// 验证:上游超时时返回已收集的 usage
func TestStreamUpstreamResponse_Timeout(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{StreamDataIntervalTimeout: 1, MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
result := svc.streamUpstreamResponse(c, resp, time.Now())
_ = pw.Close()
_ = pr.Close()
require.NotNil(t, result)
require.False(t, result.clientDisconnect)
}
// TestStreamUpstreamResponse_TimeoutAfterClientDisconnect
// 验证:客户端断开后上游超时,返回 usage 并标记 clientDisconnect
func TestStreamUpstreamResponse_TimeoutAfterClientDisconnect(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{StreamDataIntervalTimeout: 1, MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":5}}}`)
fmt.Fprintln(pw, "")
// 不关闭 pw → 等待超时
}()
result := svc.streamUpstreamResponse(c, resp, time.Now())
_ = pw.Close()
_ = pr.Close()
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
}
// TestHandleGeminiStreamingResponse_ClientDisconnect
// 验证Gemini 流式转发中客户端断开后继续 drain 上游
func TestHandleGeminiStreamingResponse_ClientDisconnect(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"hi"}]}}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":10}}`)
fmt.Fprintln(pw, "")
}()
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
require.NotContains(t, rec.Body.String(), "write_failed")
}
// TestHandleGeminiStreamingResponse_ContextCanceled
// 验证context 取消时不注入错误事件
func TestHandleGeminiStreamingResponse_ContextCanceled(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
ctx, cancel := context.WithCancel(context.Background())
cancel()
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
require.NotContains(t, rec.Body.String(), "event: error")
}
// TestHandleClaudeStreamingResponse_ClientDisconnect
// 验证Claude 流式转发中客户端断开后继续 drain 上游
func TestHandleClaudeStreamingResponse_ClientDisconnect(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
// v1internal 包装格式
fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"hello"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":8,"candidatesTokenCount":15}}}`)
fmt.Fprintln(pw, "")
}()
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
}
// TestHandleClaudeStreamingResponse_ContextCanceled
// 验证context 取消时不注入错误事件
func TestHandleClaudeStreamingResponse_ContextCanceled(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
ctx, cancel := context.WithCancel(context.Background())
cancel()
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
require.NotContains(t, rec.Body.String(), "event: error")
}
// TestExtractSSEUsage 验证 extractSSEUsage 从 SSE data 行正确提取 usage
func TestExtractSSEUsage(t *testing.T) {
svc := &AntigravityGatewayService{}
tests := []struct {
name string
line string
expected ClaudeUsage
}{
{
name: "message_delta with output_tokens",
line: `data: {"type":"message_delta","usage":{"output_tokens":42}}`,
expected: ClaudeUsage{OutputTokens: 42},
},
{
name: "non-data line ignored",
line: `event: message_start`,
expected: ClaudeUsage{},
},
{
name: "top-level usage with all fields",
line: `data: {"usage":{"input_tokens":10,"output_tokens":20,"cache_read_input_tokens":5,"cache_creation_input_tokens":3}}`,
expected: ClaudeUsage{InputTokens: 10, OutputTokens: 20, CacheReadInputTokens: 5, CacheCreationInputTokens: 3},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
usage := &ClaudeUsage{}
svc.extractSSEUsage(tt.line, usage)
require.Equal(t, tt.expected, *usage)
})
}
}
// TestAntigravityClientWriter 验证 antigravityClientWriter 的断开检测
func TestAntigravityClientWriter(t *testing.T) {
t.Run("normal write succeeds", func(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
flusher, _ := c.Writer.(http.Flusher)
cw := newAntigravityClientWriter(c.Writer, flusher, "test")
ok := cw.Write([]byte("hello"))
require.True(t, ok)
require.False(t, cw.Disconnected())
require.Contains(t, rec.Body.String(), "hello")
})
t.Run("write failure marks disconnected", func(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
fw := &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
flusher, _ := c.Writer.(http.Flusher)
cw := newAntigravityClientWriter(fw, flusher, "test")
ok := cw.Write([]byte("hello"))
require.False(t, ok)
require.True(t, cw.Disconnected())
})
t.Run("subsequent writes are no-op", func(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
fw := &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
flusher, _ := c.Writer.(http.Flusher)
cw := newAntigravityClientWriter(fw, flusher, "test")
cw.Write([]byte("first"))
ok := cw.Fprintf("second %d", 2)
require.False(t, ok)
require.True(t, cw.Disconnected())
})
}

View File

@@ -2,63 +2,23 @@ package service
import (
"context"
"slices"
"strings"
"time"
)
const antigravityQuotaScopesKey = "antigravity_quota_scopes"
// AntigravityQuotaScope 表示 Antigravity 的配额域
type AntigravityQuotaScope string
const (
AntigravityQuotaScopeClaude AntigravityQuotaScope = "claude"
AntigravityQuotaScopeGeminiText AntigravityQuotaScope = "gemini_text"
AntigravityQuotaScopeGeminiImage AntigravityQuotaScope = "gemini_image"
)
// IsScopeSupported 检查给定的 scope 是否在分组支持的 scope 列表中
func IsScopeSupported(supportedScopes []string, scope AntigravityQuotaScope) bool {
if len(supportedScopes) == 0 {
// 未配置时默认全部支持
return true
}
supported := slices.Contains(supportedScopes, string(scope))
return supported
}
// ResolveAntigravityQuotaScope 根据模型名称解析配额域(导出版本)
func ResolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) {
return resolveAntigravityQuotaScope(requestedModel)
}
// resolveAntigravityQuotaScope 根据模型名称解析配额域
func resolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) {
model := normalizeAntigravityModelName(requestedModel)
if model == "" {
return "", false
}
switch {
case strings.HasPrefix(model, "claude-"):
return AntigravityQuotaScopeClaude, true
case strings.HasPrefix(model, "gemini-"):
if isImageGenerationModel(model) {
return AntigravityQuotaScopeGeminiImage, true
}
return AntigravityQuotaScopeGeminiText, true
default:
return "", false
}
}
func normalizeAntigravityModelName(model string) string {
normalized := strings.ToLower(strings.TrimSpace(model))
normalized = strings.TrimPrefix(normalized, "models/")
return normalized
}
// IsSchedulableForModel 结合 Antigravity 配额域限流判断是否可调度。
// resolveAntigravityModelKey 根据请求的模型名解析限流 key
// 返回空字符串表示无法解析
func resolveAntigravityModelKey(requestedModel string) string {
return normalizeAntigravityModelName(requestedModel)
}
// IsSchedulableForModel 结合模型级限流判断是否可调度。
// 保持旧签名以兼容既有调用方;默认使用 context.Background()。
func (a *Account) IsSchedulableForModel(requestedModel string) bool {
return a.IsSchedulableForModelWithContext(context.Background(), requestedModel)
@@ -74,107 +34,20 @@ func (a *Account) IsSchedulableForModelWithContext(ctx context.Context, requeste
if a.isModelRateLimitedWithContext(ctx, requestedModel) {
return false
}
if a.Platform != PlatformAntigravity {
return true
}
scope, ok := resolveAntigravityQuotaScope(requestedModel)
if !ok {
return true
}
resetAt := a.antigravityQuotaScopeResetAt(scope)
if resetAt == nil {
return true
}
now := time.Now()
return !now.Before(*resetAt)
return true
}
func (a *Account) antigravityQuotaScopeResetAt(scope AntigravityQuotaScope) *time.Time {
if a == nil || a.Extra == nil || scope == "" {
return nil
}
rawScopes, ok := a.Extra[antigravityQuotaScopesKey].(map[string]any)
if !ok {
return nil
}
rawScope, ok := rawScopes[string(scope)].(map[string]any)
if !ok {
return nil
}
resetAtRaw, ok := rawScope["rate_limit_reset_at"].(string)
if !ok || strings.TrimSpace(resetAtRaw) == "" {
return nil
}
resetAt, err := time.Parse(time.RFC3339, resetAtRaw)
if err != nil {
return nil
}
return &resetAt
}
var antigravityAllScopes = []AntigravityQuotaScope{
AntigravityQuotaScopeClaude,
AntigravityQuotaScopeGeminiText,
AntigravityQuotaScopeGeminiImage,
}
func (a *Account) GetAntigravityScopeRateLimits() map[string]int64 {
if a == nil || a.Platform != PlatformAntigravity {
return nil
}
now := time.Now()
result := make(map[string]int64)
for _, scope := range antigravityAllScopes {
resetAt := a.antigravityQuotaScopeResetAt(scope)
if resetAt != nil && now.Before(*resetAt) {
remainingSec := int64(time.Until(*resetAt).Seconds())
if remainingSec > 0 {
result[string(scope)] = remainingSec
}
}
}
if len(result) == 0 {
return nil
}
return result
}
// GetQuotaScopeRateLimitRemainingTime 获取模型域限流剩余时间
// 返回 0 表示未限流或已过期
func (a *Account) GetQuotaScopeRateLimitRemainingTime(requestedModel string) time.Duration {
if a == nil || a.Platform != PlatformAntigravity {
return 0
}
scope, ok := resolveAntigravityQuotaScope(requestedModel)
if !ok {
return 0
}
resetAt := a.antigravityQuotaScopeResetAt(scope)
if resetAt == nil {
return 0
}
if remaining := time.Until(*resetAt); remaining > 0 {
return remaining
}
return 0
}
// GetRateLimitRemainingTime 获取限流剩余时间(模型限流和模型域限流取最大值)
// GetRateLimitRemainingTime 获取限流剩余时间(模型级限流)
// 返回 0 表示未限流或已过期
func (a *Account) GetRateLimitRemainingTime(requestedModel string) time.Duration {
return a.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel)
}
// GetRateLimitRemainingTimeWithContext 获取限流剩余时间(模型限流和模型域限流取最大值
// GetRateLimitRemainingTimeWithContext 获取限流剩余时间(模型限流)
// 返回 0 表示未限流或已过期
func (a *Account) GetRateLimitRemainingTimeWithContext(ctx context.Context, requestedModel string) time.Duration {
if a == nil {
return 0
}
modelRemaining := a.GetModelRateLimitRemainingTimeWithContext(ctx, requestedModel)
scopeRemaining := a.GetQuotaScopeRateLimitRemainingTime(requestedModel)
if modelRemaining > scopeRemaining {
return modelRemaining
}
return scopeRemaining
return a.GetModelRateLimitRemainingTimeWithContext(ctx, requestedModel)
}

View File

@@ -59,12 +59,6 @@ func (s *stubAntigravityUpstream) DoWithTLS(req *http.Request, proxyURL string,
return s.Do(req, proxyURL, accountID, accountConcurrency)
}
type scopeLimitCall struct {
accountID int64
scope AntigravityQuotaScope
resetAt time.Time
}
type rateLimitCall struct {
accountID int64
resetAt time.Time
@@ -78,16 +72,10 @@ type modelRateLimitCall struct {
type stubAntigravityAccountRepo struct {
AccountRepository
scopeCalls []scopeLimitCall
rateCalls []rateLimitCall
modelRateLimitCalls []modelRateLimitCall
}
func (s *stubAntigravityAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
s.scopeCalls = append(s.scopeCalls, scopeLimitCall{accountID: id, scope: scope, resetAt: resetAt})
return nil
}
func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
s.rateCalls = append(s.rateCalls, rateLimitCall{accountID: id, resetAt: resetAt})
return nil
@@ -131,10 +119,9 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
quotaScope: AntigravityQuotaScopeClaude,
httpUpstream: upstream,
requestedModel: "claude-sonnet-4-5",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleErrorCalled = true
return nil
},
@@ -155,23 +142,6 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
require.Equal(t, base2, available[0])
}
func TestAntigravityHandleUpstreamError_UsesScopeLimit(t *testing.T) {
// 分区限流始终开启,不再支持通过环境变量关闭
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 9, Name: "acc-9", Platform: PlatformAntigravity}
body := buildGeminiRateLimitBody("3s")
svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false)
require.Len(t, repo.scopeCalls, 1)
require.Empty(t, repo.rateCalls)
call := repo.scopeCalls[0]
require.Equal(t, account.ID, call.accountID)
require.Equal(t, AntigravityQuotaScopeClaude, call.scope)
require.WithinDuration(t, time.Now().Add(3*time.Second), call.resetAt, 2*time.Second)
}
// TestHandleUpstreamError_429_ModelRateLimit 测试 429 模型限流场景
func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
@@ -189,7 +159,7 @@ func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) {
}
}`)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, "claude-sonnet-4-5", 0, "", false)
// 应该触发模型限流
require.NotNil(t, result)
@@ -200,22 +170,22 @@ func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) {
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
}
// TestHandleUpstreamError_429_NonModelRateLimit 测试 429 非模型限流场景(走 scope 限流
// TestHandleUpstreamError_429_NonModelRateLimit 测试 429 非模型限流场景(走模型级限流兜底
func TestHandleUpstreamError_429_NonModelRateLimit(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 2, Name: "acc-2", Platform: PlatformAntigravity}
// 429 + 普通限流响应(无 RATE_LIMIT_EXCEEDED reasonscope 限流
// 429 + 普通限流响应(无 RATE_LIMIT_EXCEEDED reason走模型级限流兜底
body := buildGeminiRateLimitBody("5s")
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, "claude-sonnet-4-5", 0, "", false)
// 不应该触发模型限流,应该走 scope 限流
// handleModelRateLimit 不会处理(因为没有 RATE_LIMIT_EXCEEDED
// 但 429 兜底逻辑会使用 requestedModel 设置模型级限流
require.Nil(t, result)
require.Empty(t, repo.modelRateLimitCalls)
require.Len(t, repo.scopeCalls, 1)
require.Equal(t, AntigravityQuotaScopeClaude, repo.scopeCalls[0].scope)
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
}
// TestHandleUpstreamError_503_ModelRateLimit 测试 503 模型限流场景
@@ -235,7 +205,7 @@ func TestHandleUpstreamError_503_ModelRateLimit(t *testing.T) {
}
}`)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false)
// 应该触发模型限流
require.NotNil(t, result)
@@ -263,12 +233,11 @@ func TestHandleUpstreamError_503_NonModelRateLimit(t *testing.T) {
}
}`)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false)
// 503 非模型限流不应该做任何处理
require.Nil(t, result)
require.Empty(t, repo.modelRateLimitCalls, "503 non-model rate limit should not trigger model rate limit")
require.Empty(t, repo.scopeCalls, "503 non-model rate limit should not trigger scope rate limit")
require.Empty(t, repo.rateCalls, "503 non-model rate limit should not trigger account rate limit")
}
@@ -281,12 +250,11 @@ func TestHandleUpstreamError_503_EmptyBody(t *testing.T) {
// 503 + 空响应体 → 不做任何处理
body := []byte(`{}`)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false)
// 503 空响应不应该做任何处理
require.Nil(t, result)
require.Empty(t, repo.modelRateLimitCalls)
require.Empty(t, repo.scopeCalls)
require.Empty(t, repo.rateCalls)
}
@@ -307,15 +275,7 @@ func TestAccountIsSchedulableForModel_AntigravityRateLimits(t *testing.T) {
require.False(t, account.IsSchedulableForModel("gemini-3-flash"))
account.RateLimitResetAt = nil
account.Extra = map[string]any{
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future.Format(time.RFC3339),
},
},
}
require.False(t, account.IsSchedulableForModel("claude-sonnet-4-5"))
require.True(t, account.IsSchedulableForModel("claude-sonnet-4-5"))
require.True(t, account.IsSchedulableForModel("gemini-3-flash"))
}
@@ -635,6 +595,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
}`,
expectedShouldRetry: false,
expectedShouldRateLimit: true,
minWait: 7 * time.Second,
modelName: "gemini-pro",
},
{
@@ -652,6 +613,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
}`,
expectedShouldRetry: false,
expectedShouldRateLimit: true,
minWait: 39 * time.Second,
modelName: "gemini-3-pro-high",
},
{
@@ -669,6 +631,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
}`,
expectedShouldRetry: false,
expectedShouldRateLimit: true,
minWait: 30 * time.Second,
modelName: "gemini-2.5-flash",
},
{
@@ -686,6 +649,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
}`,
expectedShouldRetry: false,
expectedShouldRateLimit: true,
minWait: 30 * time.Second,
modelName: "claude-sonnet-4-5",
},
}
@@ -704,6 +668,11 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
t.Errorf("wait = %v, want >= %v", wait, tt.minWait)
}
}
if shouldRateLimit && tt.minWait > 0 {
if wait < tt.minWait {
t.Errorf("rate limit wait = %v, want >= %v", wait, tt.minWait)
}
}
if (shouldRetry || shouldRateLimit) && model != tt.modelName {
t.Errorf("modelName = %q, want %q", model, tt.modelName)
}
@@ -832,7 +801,7 @@ func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRateLimited(t *testing.T) {
requestedModel: "claude-sonnet-4-5",
httpUpstream: upstream,
isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
@@ -875,7 +844,7 @@ func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemainingLong(t *testing.T) {
requestedModel: "claude-sonnet-4-5",
httpUpstream: upstream,
isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})

View File

@@ -75,7 +75,7 @@ func TestHandleSmartRetry_URLLevelRateLimit(t *testing.T) {
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -127,7 +127,7 @@ func TestHandleSmartRetry_LongDelay_ReturnsSwitchError(t *testing.T) {
body: []byte(`{"input":"test"}`),
accountRepo: repo,
isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -194,7 +194,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetrySuccess(t *testing.T) {
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -269,7 +269,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test
httpUpstream: upstream,
accountRepo: repo,
isStickySession: false,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -331,7 +331,7 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi
body: []byte(`{"input":"test"}`),
accountRepo: repo,
isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -387,7 +387,7 @@ func TestHandleSmartRetry_NonAntigravityAccount_ContinuesDefaultLogic(t *testing
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -436,7 +436,7 @@ func TestHandleSmartRetry_NonModelRateLimit_ContinuesDefaultLogic(t *testing.T)
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -487,7 +487,7 @@ func TestHandleSmartRetry_ExactlyAtThreshold_ReturnsSwitchError(t *testing.T) {
action: "generateContent",
body: []byte(`{"input":"test"}`),
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -548,7 +548,7 @@ func TestAntigravityRetryLoop_HandleSmartRetry_SwitchError_Propagates(t *testing
httpUpstream: upstream,
accountRepo: repo,
isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
@@ -604,7 +604,7 @@ func TestHandleSmartRetry_NetworkError_ExhaustsRetry(t *testing.T) {
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -662,7 +662,7 @@ func TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit(t *testing.T) {
body: []byte(`{"input":"test"}`),
accountRepo: repo,
isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -754,7 +754,7 @@ func TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_ClearsSession(t *
isStickySession: true,
groupID: 42,
sessionHash: "sticky-hash-abc",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -842,7 +842,7 @@ func TestHandleSmartRetry_ShortDelay_NonStickySession_FailedRetry_NoDeleteSessio
isStickySession: false,
groupID: 42,
sessionHash: "", // 非粘性会话sessionHash 为空
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -918,7 +918,7 @@ func TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_NilCache_NoPanic(
isStickySession: true,
groupID: 42,
sessionHash: "sticky-hash-nil-cache",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -983,7 +983,7 @@ func TestHandleSmartRetry_ShortDelay_StickySession_SuccessRetry_NoDeleteSession(
isStickySession: true,
groupID: 42,
sessionHash: "sticky-hash-success",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -1043,7 +1043,7 @@ func TestHandleSmartRetry_LongDelay_StickySession_NoDeleteInHandleSmartRetry(t *
isStickySession: true,
groupID: 42,
sessionHash: "sticky-hash-long-delay",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -1108,7 +1108,7 @@ func TestHandleSmartRetry_ShortDelay_NetworkError_StickySession_ClearsSession(t
isStickySession: true,
groupID: 99,
sessionHash: "sticky-net-error",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -1188,7 +1188,7 @@ func TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession
isStickySession: true,
groupID: 77,
sessionHash: "sticky-503-short",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -1278,7 +1278,7 @@ func TestAntigravityRetryLoop_SmartRetryFailed_StickySession_SwitchErrorPropagat
isStickySession: true,
groupID: 55,
sessionHash: "sticky-loop-test",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
@@ -1296,4 +1296,4 @@ func TestAntigravityRetryLoop_SmartRetryFailed_StickySession_SwitchErrorPropagat
require.Len(t, cache.deleteCalls, 1, "should clear sticky session in handleSmartRetry")
require.Equal(t, int64(55), cache.deleteCalls[0].groupID)
require.Equal(t, "sticky-loop-test", cache.deleteCalls[0].sessionHash)
}
}

View File

@@ -0,0 +1,69 @@
package service
import (
"strconv"
"strings"
"time"
gocache "github.com/patrickmn/go-cache"
)
// digestSessionTTL 摘要会话默认 TTL
const digestSessionTTL = 5 * time.Minute
// sessionEntry flat cache 条目
type sessionEntry struct {
uuid string
accountID int64
}
// DigestSessionStore 内存摘要会话存储flat cache 实现)
// key: "{groupID}:{prefixHash}|{digestChain}" → *sessionEntry
type DigestSessionStore struct {
cache *gocache.Cache
}
// NewDigestSessionStore 创建内存摘要会话存储
func NewDigestSessionStore() *DigestSessionStore {
return &DigestSessionStore{
cache: gocache.New(digestSessionTTL, time.Minute),
}
}
// Save 保存摘要会话。oldDigestChain 为 Find 返回的 matchedChain用于删旧 key。
func (s *DigestSessionStore) Save(groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) {
if digestChain == "" {
return
}
ns := buildNS(groupID, prefixHash)
s.cache.Set(ns+digestChain, &sessionEntry{uuid: uuid, accountID: accountID}, gocache.DefaultExpiration)
if oldDigestChain != "" && oldDigestChain != digestChain {
s.cache.Delete(ns + oldDigestChain)
}
}
// Find 查找摘要会话,从完整 chain 逐段截断,返回最长匹配及对应 matchedChain。
func (s *DigestSessionStore) Find(groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) {
if digestChain == "" {
return "", 0, "", false
}
ns := buildNS(groupID, prefixHash)
chain := digestChain
for {
if val, ok := s.cache.Get(ns + chain); ok {
if e, ok := val.(*sessionEntry); ok {
return e.uuid, e.accountID, chain, true
}
}
i := strings.LastIndex(chain, "-")
if i < 0 {
return "", 0, "", false
}
chain = chain[:i]
}
}
// buildNS 构建 namespace 前缀
func buildNS(groupID int64, prefixHash string) string {
return strconv.FormatInt(groupID, 10) + ":" + prefixHash + "|"
}

View File

@@ -0,0 +1,312 @@
//go:build unit
package service
import (
"fmt"
"sync"
"testing"
"time"
gocache "github.com/patrickmn/go-cache"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDigestSessionStore_SaveAndFind(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix", "s:a1-u:b2-m:c3", "uuid-1", 100, "")
uuid, accountID, _, found := store.Find(1, "prefix", "s:a1-u:b2-m:c3")
require.True(t, found)
assert.Equal(t, "uuid-1", uuid)
assert.Equal(t, int64(100), accountID)
}
func TestDigestSessionStore_PrefixMatch(t *testing.T) {
store := NewDigestSessionStore()
// 保存短链
store.Save(1, "prefix", "u:a-m:b", "uuid-short", 10, "")
// 用长链查找,应前缀匹配到短链
uuid, accountID, matchedChain, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d")
require.True(t, found)
assert.Equal(t, "uuid-short", uuid)
assert.Equal(t, int64(10), accountID)
assert.Equal(t, "u:a-m:b", matchedChain)
}
func TestDigestSessionStore_LongestPrefixMatch(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix", "u:a", "uuid-1", 1, "")
store.Save(1, "prefix", "u:a-m:b", "uuid-2", 2, "")
store.Save(1, "prefix", "u:a-m:b-u:c", "uuid-3", 3, "")
// 应匹配最深的 "u:a-m:b-u:c"(从完整 chain 逐段截断,先命中最长的)
uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d-u:e")
require.True(t, found)
assert.Equal(t, "uuid-3", uuid)
assert.Equal(t, int64(3), accountID)
// 查找中等长度,应匹配到 "u:a-m:b"
uuid, accountID, _, found = store.Find(1, "prefix", "u:a-m:b-u:x")
require.True(t, found)
assert.Equal(t, "uuid-2", uuid)
assert.Equal(t, int64(2), accountID)
}
func TestDigestSessionStore_SaveDeletesOldChain(t *testing.T) {
store := NewDigestSessionStore()
// 第一轮:保存 "u:a-m:b"
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
// 第二轮:同一 uuid 保存更长的链,传入旧 chain
store.Save(1, "prefix", "u:a-m:b-u:c-m:d", "uuid-1", 100, "u:a-m:b")
// 旧链 "u:a-m:b" 应已被删除
_, _, _, found := store.Find(1, "prefix", "u:a-m:b")
assert.False(t, found, "old chain should be deleted")
// 新链应能找到
uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d")
require.True(t, found)
assert.Equal(t, "uuid-1", uuid)
assert.Equal(t, int64(100), accountID)
}
func TestDigestSessionStore_DifferentSessionsNoInterference(t *testing.T) {
store := NewDigestSessionStore()
// 相同系统提示词,不同用户提示词
store.Save(1, "prefix", "s:sys-u:user1", "uuid-1", 100, "")
store.Save(1, "prefix", "s:sys-u:user2", "uuid-2", 200, "")
uuid, accountID, _, found := store.Find(1, "prefix", "s:sys-u:user1-m:reply1")
require.True(t, found)
assert.Equal(t, "uuid-1", uuid)
assert.Equal(t, int64(100), accountID)
uuid, accountID, _, found = store.Find(1, "prefix", "s:sys-u:user2-m:reply2")
require.True(t, found)
assert.Equal(t, "uuid-2", uuid)
assert.Equal(t, int64(200), accountID)
}
func TestDigestSessionStore_NoMatch(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
// 完全不同的 chain
_, _, _, found := store.Find(1, "prefix", "u:x-m:y")
assert.False(t, found)
}
func TestDigestSessionStore_DifferentPrefixHash(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix1", "u:a-m:b", "uuid-1", 100, "")
// 不同 prefixHash 应隔离
_, _, _, found := store.Find(1, "prefix2", "u:a-m:b")
assert.False(t, found)
}
func TestDigestSessionStore_DifferentGroupID(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
// 不同 groupID 应隔离
_, _, _, found := store.Find(2, "prefix", "u:a-m:b")
assert.False(t, found)
}
func TestDigestSessionStore_EmptyDigestChain(t *testing.T) {
store := NewDigestSessionStore()
// 空链不应保存
store.Save(1, "prefix", "", "uuid-1", 100, "")
_, _, _, found := store.Find(1, "prefix", "")
assert.False(t, found)
}
func TestDigestSessionStore_TTLExpiration(t *testing.T) {
store := &DigestSessionStore{
cache: gocache.New(100*time.Millisecond, 50*time.Millisecond),
}
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
// 立即应该能找到
_, _, _, found := store.Find(1, "prefix", "u:a-m:b")
require.True(t, found)
// 等待过期 + 清理周期
time.Sleep(300 * time.Millisecond)
// 过期后应找不到
_, _, _, found = store.Find(1, "prefix", "u:a-m:b")
assert.False(t, found)
}
func TestDigestSessionStore_ConcurrentSafety(t *testing.T) {
store := NewDigestSessionStore()
var wg sync.WaitGroup
const goroutines = 50
const operations = 100
wg.Add(goroutines)
for g := 0; g < goroutines; g++ {
go func(id int) {
defer wg.Done()
prefix := fmt.Sprintf("prefix-%d", id%5)
for i := 0; i < operations; i++ {
chain := fmt.Sprintf("u:%d-m:%d", id, i)
uuid := fmt.Sprintf("uuid-%d-%d", id, i)
store.Save(1, prefix, chain, uuid, int64(id), "")
store.Find(1, prefix, chain)
}
}(g)
}
wg.Wait()
}
func TestDigestSessionStore_MultipleSessions(t *testing.T) {
store := NewDigestSessionStore()
sessions := []struct {
chain string
uuid string
accountID int64
}{
{"u:session1", "uuid-1", 1},
{"u:session2-m:reply2", "uuid-2", 2},
{"u:session3-m:reply3-u:msg3", "uuid-3", 3},
}
for _, sess := range sessions {
store.Save(1, "prefix", sess.chain, sess.uuid, sess.accountID, "")
}
// 验证每个会话都能正确查找
for _, sess := range sessions {
uuid, accountID, _, found := store.Find(1, "prefix", sess.chain)
require.True(t, found, "should find session: %s", sess.chain)
assert.Equal(t, sess.uuid, uuid)
assert.Equal(t, sess.accountID, accountID)
}
// 验证继续对话的场景
uuid, accountID, _, found := store.Find(1, "prefix", "u:session2-m:reply2-u:newmsg")
require.True(t, found)
assert.Equal(t, "uuid-2", uuid)
assert.Equal(t, int64(2), accountID)
}
func TestDigestSessionStore_Performance1000Sessions(t *testing.T) {
store := NewDigestSessionStore()
// 插入 1000 个会话
for i := 0; i < 1000; i++ {
chain := fmt.Sprintf("s:sys-u:user%d-m:reply%d", i, i)
store.Save(1, "prefix", chain, fmt.Sprintf("uuid-%d", i), int64(i), "")
}
// 查找性能测试
start := time.Now()
const lookups = 10000
for i := 0; i < lookups; i++ {
idx := i % 1000
chain := fmt.Sprintf("s:sys-u:user%d-m:reply%d-u:newmsg", idx, idx)
_, _, _, found := store.Find(1, "prefix", chain)
assert.True(t, found)
}
elapsed := time.Since(start)
t.Logf("%d lookups in %v (%.0f ns/op)", lookups, elapsed, float64(elapsed.Nanoseconds())/lookups)
}
func TestDigestSessionStore_FindReturnsMatchedChain(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix", "u:a-m:b-u:c", "uuid-1", 100, "")
// 精确匹配
_, _, matchedChain, found := store.Find(1, "prefix", "u:a-m:b-u:c")
require.True(t, found)
assert.Equal(t, "u:a-m:b-u:c", matchedChain)
// 前缀匹配(截断后命中)
_, _, matchedChain, found = store.Find(1, "prefix", "u:a-m:b-u:c-m:d-u:e")
require.True(t, found)
assert.Equal(t, "u:a-m:b-u:c", matchedChain)
}
func TestDigestSessionStore_CacheItemCountStable(t *testing.T) {
store := NewDigestSessionStore()
// 模拟 100 个独立会话,每个进行 10 轮对话
// 正确传递 oldDigestChain 时,每个会话始终只保留 1 个 key
for conv := 0; conv < 100; conv++ {
var prevMatchedChain string
for round := 0; round < 10; round++ {
chain := fmt.Sprintf("s:sys-u:user%d", conv)
for r := 0; r < round; r++ {
chain += fmt.Sprintf("-m:a%d-u:q%d", r, r+1)
}
uuid := fmt.Sprintf("uuid-conv%d", conv)
_, _, matched, _ := store.Find(1, "prefix", chain)
store.Save(1, "prefix", chain, uuid, int64(conv), matched)
prevMatchedChain = matched
_ = prevMatchedChain
}
}
// 100 个会话 × 1 key/会话 = 应该 ≤ 100 个 key
// 允许少量并发残留,但绝不能接近 100×10=1000
itemCount := store.cache.ItemCount()
assert.LessOrEqual(t, itemCount, 100, "cache should have at most 100 items (1 per conversation), got %d", itemCount)
t.Logf("Cache item count after 100 conversations × 10 rounds: %d", itemCount)
}
func TestDigestSessionStore_TTLPreventsUnboundedGrowth(t *testing.T) {
// 使用极短 TTL 验证大量写入后 cache 能被清理
store := &DigestSessionStore{
cache: gocache.New(100*time.Millisecond, 50*time.Millisecond),
}
// 插入 500 个不同的 key无 oldDigestChain模拟最坏场景全是新会话首轮
for i := 0; i < 500; i++ {
chain := fmt.Sprintf("u:user%d", i)
store.Save(1, "prefix", chain, fmt.Sprintf("uuid-%d", i), int64(i), "")
}
assert.Equal(t, 500, store.cache.ItemCount())
// 等待 TTL + 清理周期
time.Sleep(300 * time.Millisecond)
assert.Equal(t, 0, store.cache.ItemCount(), "all items should be expired and cleaned up")
}
func TestDigestSessionStore_SaveSameChainNoDelete(t *testing.T) {
store := NewDigestSessionStore()
// 保存 chain
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
// 用户重发相同消息oldDigestChain == digestChain不应删掉刚设置的 key
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "u:a-m:b")
// 仍然能找到
uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b")
require.True(t, found)
assert.Equal(t, "uuid-1", uuid)
assert.Equal(t, int64(100), accountID)
}

View File

@@ -0,0 +1,366 @@
//go:build unit
package service
import (
"context"
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// Mocks (scoped to this file by naming convention)
// ---------------------------------------------------------------------------
// epFixedUpstream returns a fixed response for every request.
type epFixedUpstream struct {
statusCode int
body string
calls int
}
func (u *epFixedUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
u.calls++
return &http.Response{
StatusCode: u.statusCode,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(u.body)),
}, nil
}
func (u *epFixedUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
return u.Do(req, proxyURL, accountID, accountConcurrency)
}
// epAccountRepo records SetTempUnschedulable / SetError calls.
type epAccountRepo struct {
mockAccountRepoForGemini
tempCalls int
setErrCalls int
}
func (r *epAccountRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error {
r.tempCalls++
return nil
}
func (r *epAccountRepo) SetError(_ context.Context, _ int64, _ string) error {
r.setErrCalls++
return nil
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
func saveAndSetBaseURLs(t *testing.T) {
t.Helper()
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
oldAvail := antigravity.DefaultURLAvailability
antigravity.BaseURLs = []string{"https://ep-test.example"}
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
t.Cleanup(func() {
antigravity.BaseURLs = oldBaseURLs
antigravity.DefaultURLAvailability = oldAvail
})
}
func newRetryParams(account *Account, upstream HTTPUpstream, handleError func(context.Context, string, *Account, int, http.Header, []byte, string, int64, string, bool) *handleModelRateLimitResult) antigravityRetryLoopParams {
return antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[ep-test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
requestedModel: "claude-sonnet-4-5",
handleError: handleError,
}
}
// ---------------------------------------------------------------------------
// TestRetryLoop_ErrorPolicy_CustomErrorCodes
// ---------------------------------------------------------------------------
func TestRetryLoop_ErrorPolicy_CustomErrorCodes(t *testing.T) {
tests := []struct {
name string
upstreamStatus int
upstreamBody string
customCodes []any
expectHandleError int
expectUpstream int
expectStatusCode int
}{
{
name: "429_in_custom_codes_matched",
upstreamStatus: 429,
upstreamBody: `{"error":"rate limited"}`,
customCodes: []any{float64(429)},
expectHandleError: 1,
expectUpstream: 1,
expectStatusCode: 429,
},
{
name: "429_not_in_custom_codes_skipped",
upstreamStatus: 429,
upstreamBody: `{"error":"rate limited"}`,
customCodes: []any{float64(500)},
expectHandleError: 0,
expectUpstream: 1,
expectStatusCode: 429,
},
{
name: "500_in_custom_codes_matched",
upstreamStatus: 500,
upstreamBody: `{"error":"internal"}`,
customCodes: []any{float64(500)},
expectHandleError: 1,
expectUpstream: 1,
expectStatusCode: 500,
},
{
name: "500_not_in_custom_codes_skipped",
upstreamStatus: 500,
upstreamBody: `{"error":"internal"}`,
customCodes: []any{float64(429)},
expectHandleError: 0,
expectUpstream: 1,
expectStatusCode: 500,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{statusCode: tt.upstreamStatus, body: tt.upstreamBody}
repo := &epAccountRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 100,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": tt.customCodes,
},
}
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
var handleErrorCount int
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
handleErrorCount++
return nil
})
result, err := svc.antigravityRetryLoop(p)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.resp)
defer func() { _ = result.resp.Body.Close() }()
require.Equal(t, tt.expectStatusCode, result.resp.StatusCode)
require.Equal(t, tt.expectHandleError, handleErrorCount, "handleError call count")
require.Equal(t, tt.expectUpstream, upstream.calls, "upstream call count")
})
}
}
// ---------------------------------------------------------------------------
// TestRetryLoop_ErrorPolicy_TempUnschedulable
// ---------------------------------------------------------------------------
func TestRetryLoop_ErrorPolicy_TempUnschedulable(t *testing.T) {
tempRulesAccount := func(rules []any) *Account {
return &Account{
ID: 200,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": rules,
},
}
}
overloadedRule := map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
}
rateLimitRule := map[string]any{
"error_code": float64(429),
"keywords": []any{"rate limited keyword"},
"duration_minutes": float64(5),
}
t.Run("503_overloaded_matches_rule", func(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{statusCode: 503, body: `overloaded`}
repo := &epAccountRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
account := tempRulesAccount([]any{overloadedRule})
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
t.Error("handleError should not be called for temp unschedulable")
return nil
})
result, err := svc.antigravityRetryLoop(p)
require.Nil(t, result)
var switchErr *AntigravityAccountSwitchError
require.ErrorAs(t, err, &switchErr)
require.Equal(t, account.ID, switchErr.OriginalAccountID)
require.Equal(t, 1, upstream.calls, "should not retry")
})
t.Run("429_rate_limited_keyword_matches_rule", func(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{statusCode: 429, body: `rate limited keyword`}
repo := &epAccountRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
account := tempRulesAccount([]any{rateLimitRule})
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
t.Error("handleError should not be called for temp unschedulable")
return nil
})
result, err := svc.antigravityRetryLoop(p)
require.Nil(t, result)
var switchErr *AntigravityAccountSwitchError
require.ErrorAs(t, err, &switchErr)
require.Equal(t, account.ID, switchErr.OriginalAccountID)
require.Equal(t, 1, upstream.calls, "should not retry")
})
t.Run("503_body_no_match_continues_default_retry", func(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{statusCode: 503, body: `random`}
repo := &epAccountRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
account := tempRulesAccount([]any{overloadedRule})
// Use a short-lived context: the backoff sleep (~1s) will be
// interrupted, proving the code entered the default retry path
// instead of breaking early via error policy.
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
return nil
})
p.ctx = ctx
result, err := svc.antigravityRetryLoop(p)
// Context cancellation during backoff proves default retry was entered
require.Nil(t, result)
require.ErrorIs(t, err, context.DeadlineExceeded)
require.GreaterOrEqual(t, upstream.calls, 1, "should have called upstream at least once")
})
}
// ---------------------------------------------------------------------------
// TestRetryLoop_ErrorPolicy_NilRateLimitService
// ---------------------------------------------------------------------------
func TestRetryLoop_ErrorPolicy_NilRateLimitService(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{statusCode: 429, body: `{"error":"rate limited"}`}
// rateLimitService is nil — must not panic
svc := &AntigravityGatewayService{rateLimitService: nil}
account := &Account{
ID: 300,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
}
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
return nil
})
p.ctx = ctx
// Should not panic; enters the default retry path (eventually times out)
result, err := svc.antigravityRetryLoop(p)
require.Nil(t, result)
require.ErrorIs(t, err, context.DeadlineExceeded)
require.GreaterOrEqual(t, upstream.calls, 1)
}
// ---------------------------------------------------------------------------
// TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior
// ---------------------------------------------------------------------------
func TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{statusCode: 429, body: `{"error":"rate limited"}`}
repo := &epAccountRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
// Plain OAuth account with no error policy configured
account := &Account{
ID: 400,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
}
var handleErrorCount int
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
handleErrorCount++
return nil
})
result, err := svc.antigravityRetryLoop(p)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.resp)
defer func() { _ = result.resp.Body.Close() }()
require.Equal(t, http.StatusTooManyRequests, result.resp.StatusCode)
require.Equal(t, antigravityMaxRetries, upstream.calls, "should exhaust all retries")
require.Equal(t, 1, handleErrorCount, "handleError should be called once after retries exhausted")
}

View File

@@ -0,0 +1,289 @@
//go:build unit
package service
import (
"context"
"net/http"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// TestCheckErrorPolicy — 6 table-driven cases for the pure logic function
// ---------------------------------------------------------------------------
func TestCheckErrorPolicy(t *testing.T) {
tests := []struct {
name string
account *Account
statusCode int
body []byte
expected ErrorPolicyResult
}{
{
name: "no_policy_oauth_returns_none",
account: &Account{
ID: 1,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
// no custom error codes, no temp rules
},
statusCode: 500,
body: []byte(`"error"`),
expected: ErrorPolicyNone,
},
{
name: "custom_error_codes_hit_returns_matched",
account: &Account{
ID: 2,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429), float64(500)},
},
},
statusCode: 500,
body: []byte(`"error"`),
expected: ErrorPolicyMatched,
},
{
name: "custom_error_codes_miss_returns_skipped",
account: &Account{
ID: 3,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429), float64(500)},
},
},
statusCode: 503,
body: []byte(`"error"`),
expected: ErrorPolicySkipped,
},
{
name: "temp_unschedulable_hit_returns_temp_unscheduled",
account: &Account{
ID: 4,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
"description": "overloaded rule",
},
},
},
},
statusCode: 503,
body: []byte(`overloaded service`),
expected: ErrorPolicyTempUnscheduled,
},
{
name: "temp_unschedulable_body_miss_returns_none",
account: &Account{
ID: 5,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
"description": "overloaded rule",
},
},
},
},
statusCode: 503,
body: []byte(`random msg`),
expected: ErrorPolicyNone,
},
{
name: "custom_error_codes_override_temp_unschedulable",
account: &Account{
ID: 6,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(503)},
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
"description": "overloaded rule",
},
},
},
},
statusCode: 503,
body: []byte(`overloaded`),
expected: ErrorPolicyMatched, // custom codes take precedence
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &errorPolicyRepoStub{}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
result := svc.CheckErrorPolicy(context.Background(), tt.account, tt.statusCode, tt.body)
require.Equal(t, tt.expected, result, "unexpected ErrorPolicyResult")
})
}
}
// ---------------------------------------------------------------------------
// TestApplyErrorPolicy — 4 table-driven cases for the wrapper method
// ---------------------------------------------------------------------------
func TestApplyErrorPolicy(t *testing.T) {
tests := []struct {
name string
account *Account
statusCode int
body []byte
expectedHandled bool
expectedSwitchErr bool // expect *AntigravityAccountSwitchError
handleErrorCalls int
}{
{
name: "none_not_handled",
account: &Account{
ID: 10,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
},
statusCode: 500,
body: []byte(`"error"`),
expectedHandled: false,
handleErrorCalls: 0,
},
{
name: "skipped_handled_no_handleError",
account: &Account{
ID: 11,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429)},
},
},
statusCode: 500, // not in custom codes
body: []byte(`"error"`),
expectedHandled: true,
handleErrorCalls: 0,
},
{
name: "matched_handled_calls_handleError",
account: &Account{
ID: 12,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(500)},
},
},
statusCode: 500,
body: []byte(`"error"`),
expectedHandled: true,
handleErrorCalls: 1,
},
{
name: "temp_unscheduled_returns_switch_error",
account: &Account{
ID: 13,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
},
},
},
},
statusCode: 503,
body: []byte(`overloaded`),
expectedHandled: true,
expectedSwitchErr: true,
handleErrorCalls: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &errorPolicyRepoStub{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &AntigravityGatewayService{
rateLimitService: rlSvc,
}
var handleErrorCount int
p := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: tt.account,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleErrorCount++
return nil
},
isStickySession: true,
}
handled, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body)
require.Equal(t, tt.expectedHandled, handled, "handled mismatch")
require.Equal(t, tt.handleErrorCalls, handleErrorCount, "handleError call count mismatch")
if tt.expectedSwitchErr {
var switchErr *AntigravityAccountSwitchError
require.ErrorAs(t, retErr, &switchErr)
require.Equal(t, tt.account.ID, switchErr.OriginalAccountID)
} else {
require.NoError(t, retErr)
}
})
}
}
// ---------------------------------------------------------------------------
// errorPolicyRepoStub — minimal AccountRepository stub for error policy tests
// ---------------------------------------------------------------------------
type errorPolicyRepoStub struct {
mockAccountRepoForGemini
tempCalls int
setErrCalls int
lastErrorMsg string
}
func (r *errorPolicyRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
r.tempCalls++
return nil
}
func (r *errorPolicyRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
r.setErrCalls++
r.lastErrorMsg = errorMsg
return nil
}

View File

@@ -142,9 +142,6 @@ func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx co
func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForPlatform) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForPlatform) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
return nil
}
@@ -216,29 +213,6 @@ func (m *mockGatewayCacheForPlatform) DeleteSessionAccountID(ctx context.Context
return nil
}
func (m *mockGatewayCacheForPlatform) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
return 0, nil
}
func (m *mockGatewayCacheForPlatform) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) {
return nil, nil
}
func (m *mockGatewayCacheForPlatform) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (m *mockGatewayCacheForPlatform) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
func (m *mockGatewayCacheForPlatform) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (m *mockGatewayCacheForPlatform) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
type mockGroupRepoForGateway struct {
groups map[int64]*Group

View File

@@ -6,9 +6,19 @@ import (
"fmt"
"math"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// SessionContext 粘性会话上下文,用于区分不同来源的请求。
// 仅在 GenerateSessionHash 第 3 级 fallback消息内容 hash时混入
// 避免不同用户发送相同消息产生相同 hash 导致账号集中。
type SessionContext struct {
ClientIP string
UserAgent string
APIKeyID int64
}
// ParsedRequest 保存网关请求的预解析结果
//
// 性能优化说明:
@@ -22,20 +32,22 @@ import (
// 2. 将解析结果 ParsedRequest 传递给 Service 层
// 3. 避免重复 json.Unmarshal减少 CPU 和内存开销
type ParsedRequest struct {
Body []byte // 原始请求体(保留用于转发)
Model string // 请求的模型名称
Stream bool // 是否为流式请求
MetadataUserID string // metadata.user_id用于会话亲和
System any // system 字段内容
Messages []any // messages 数组
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
ThinkingEnabled bool // 是否开启 thinking部分平台会影响最终模型名
MaxTokens int // max_tokens 值(用于探测请求拦截)
Body []byte // 原始请求体(保留用于转发)
Model string // 请求的模型名称
Stream bool // 是否为流式请求
MetadataUserID string // metadata.user_id用于会话亲和
System any // system 字段内容
Messages []any // messages 数组
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
ThinkingEnabled bool // 是否开启 thinking部分平台会影响最终模型名
MaxTokens int // max_tokens 值(用于探测请求拦截)
SessionContext *SessionContext // 可选请求上下文区分因子nil 时行为不变)
}
// ParseGatewayRequest 解析网关请求体并返回结构化结果
// 性能优化:一次解析提取所有需要的字段,避免重复 Unmarshal
func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
// ParseGatewayRequest 解析网关请求体并返回结构化结果
// protocol 指定请求协议格式domain.PlatformAnthropic / domain.PlatformGemini
// 不同协议使用不同的 system/messages 字段名。
func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return nil, err
@@ -64,14 +76,29 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
parsed.MetadataUserID = userID
}
}
// system 字段只要存在就视为显式提供(即使为 null
// 以避免客户端传 null 时被默认 system 误注入。
if system, ok := req["system"]; ok {
parsed.HasSystem = true
parsed.System = system
}
if messages, ok := req["messages"].([]any); ok {
parsed.Messages = messages
switch protocol {
case domain.PlatformGemini:
// Gemini 原生格式: systemInstruction.parts / contents
if sysInst, ok := req["systemInstruction"].(map[string]any); ok {
if parts, ok := sysInst["parts"].([]any); ok {
parsed.System = parts
}
}
if contents, ok := req["contents"].([]any); ok {
parsed.Messages = contents
}
default:
// Anthropic / OpenAI 格式: system / messages
// system 字段只要存在就视为显式提供(即使为 null
// 以避免客户端传 null 时被默认 system 误注入。
if system, ok := req["system"]; ok {
parsed.HasSystem = true
parsed.System = system
}
if messages, ok := req["messages"].([]any); ok {
parsed.Messages = messages
}
}
// thinking: {type: "enabled"}

View File

@@ -4,12 +4,13 @@ import (
"encoding/json"
"testing"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/stretchr/testify/require"
)
func TestParseGatewayRequest(t *testing.T) {
body := []byte(`{"model":"claude-3-7-sonnet","stream":true,"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"system":[{"type":"text","text":"hello","cache_control":{"type":"ephemeral"}}],"messages":[{"content":"hi"}]}`)
parsed, err := ParseGatewayRequest(body)
parsed, err := ParseGatewayRequest(body, "")
require.NoError(t, err)
require.Equal(t, "claude-3-7-sonnet", parsed.Model)
require.True(t, parsed.Stream)
@@ -22,7 +23,7 @@ func TestParseGatewayRequest(t *testing.T) {
func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) {
body := []byte(`{"model":"claude-sonnet-4-5","thinking":{"type":"enabled"},"messages":[{"content":"hi"}]}`)
parsed, err := ParseGatewayRequest(body)
parsed, err := ParseGatewayRequest(body, "")
require.NoError(t, err)
require.Equal(t, "claude-sonnet-4-5", parsed.Model)
require.True(t, parsed.ThinkingEnabled)
@@ -30,21 +31,21 @@ func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) {
func TestParseGatewayRequest_MaxTokens(t *testing.T) {
body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1}`)
parsed, err := ParseGatewayRequest(body)
parsed, err := ParseGatewayRequest(body, "")
require.NoError(t, err)
require.Equal(t, 1, parsed.MaxTokens)
}
func TestParseGatewayRequest_MaxTokensNonIntegralIgnored(t *testing.T) {
body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1.5}`)
parsed, err := ParseGatewayRequest(body)
parsed, err := ParseGatewayRequest(body, "")
require.NoError(t, err)
require.Equal(t, 0, parsed.MaxTokens)
}
func TestParseGatewayRequest_SystemNull(t *testing.T) {
body := []byte(`{"model":"claude-3","system":null}`)
parsed, err := ParseGatewayRequest(body)
parsed, err := ParseGatewayRequest(body, "")
require.NoError(t, err)
// 显式传入 system:null 也应视为“字段已存在”,避免默认 system 被注入。
require.True(t, parsed.HasSystem)
@@ -53,16 +54,112 @@ func TestParseGatewayRequest_SystemNull(t *testing.T) {
func TestParseGatewayRequest_InvalidModelType(t *testing.T) {
body := []byte(`{"model":123}`)
_, err := ParseGatewayRequest(body)
_, err := ParseGatewayRequest(body, "")
require.Error(t, err)
}
func TestParseGatewayRequest_InvalidStreamType(t *testing.T) {
body := []byte(`{"stream":"true"}`)
_, err := ParseGatewayRequest(body)
_, err := ParseGatewayRequest(body, "")
require.Error(t, err)
}
// ============ Gemini 原生格式解析测试 ============
func TestParseGatewayRequest_GeminiContents(t *testing.T) {
body := []byte(`{
"contents": [
{"role": "user", "parts": [{"text": "Hello"}]},
{"role": "model", "parts": [{"text": "Hi there"}]},
{"role": "user", "parts": [{"text": "How are you?"}]}
]
}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
require.NoError(t, err)
require.Len(t, parsed.Messages, 3, "should parse contents as Messages")
require.False(t, parsed.HasSystem, "Gemini format should not set HasSystem")
require.Nil(t, parsed.System, "no systemInstruction means nil System")
}
func TestParseGatewayRequest_GeminiSystemInstruction(t *testing.T) {
body := []byte(`{
"systemInstruction": {
"parts": [{"text": "You are a helpful assistant."}]
},
"contents": [
{"role": "user", "parts": [{"text": "Hello"}]}
]
}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
require.NoError(t, err)
require.NotNil(t, parsed.System, "should parse systemInstruction.parts as System")
parts, ok := parsed.System.([]any)
require.True(t, ok)
require.Len(t, parts, 1)
partMap, ok := parts[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "You are a helpful assistant.", partMap["text"])
require.Len(t, parsed.Messages, 1)
}
func TestParseGatewayRequest_GeminiWithModel(t *testing.T) {
body := []byte(`{
"model": "gemini-2.5-pro",
"contents": [{"role": "user", "parts": [{"text": "test"}]}]
}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
require.NoError(t, err)
require.Equal(t, "gemini-2.5-pro", parsed.Model)
require.Len(t, parsed.Messages, 1)
}
func TestParseGatewayRequest_GeminiIgnoresAnthropicFields(t *testing.T) {
// Gemini 格式下 system/messages 字段应被忽略
body := []byte(`{
"system": "should be ignored",
"messages": [{"role": "user", "content": "ignored"}],
"contents": [{"role": "user", "parts": [{"text": "real content"}]}]
}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
require.NoError(t, err)
require.False(t, parsed.HasSystem, "Gemini protocol should not parse Anthropic system field")
require.Nil(t, parsed.System, "no systemInstruction = nil System")
require.Len(t, parsed.Messages, 1, "should use contents, not messages")
}
func TestParseGatewayRequest_GeminiEmptyContents(t *testing.T) {
body := []byte(`{"contents": []}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
require.NoError(t, err)
require.Empty(t, parsed.Messages)
}
func TestParseGatewayRequest_GeminiNoContents(t *testing.T) {
body := []byte(`{"model": "gemini-2.5-flash"}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
require.NoError(t, err)
require.Nil(t, parsed.Messages)
require.Equal(t, "gemini-2.5-flash", parsed.Model)
}
func TestParseGatewayRequest_AnthropicIgnoresGeminiFields(t *testing.T) {
// Anthropic 格式下 contents/systemInstruction 字段应被忽略
body := []byte(`{
"system": "real system",
"messages": [{"role": "user", "content": "real content"}],
"contents": [{"role": "user", "parts": [{"text": "ignored"}]}],
"systemInstruction": {"parts": [{"text": "ignored"}]}
}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformAnthropic)
require.NoError(t, err)
require.True(t, parsed.HasSystem)
require.Equal(t, "real system", parsed.System)
require.Len(t, parsed.Messages, 1)
msg, ok := parsed.Messages[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "real content", msg["content"])
}
func TestFilterThinkingBlocks(t *testing.T) {
containsThinkingBlock := func(body []byte) bool {
var req map[string]any

View File

@@ -5,7 +5,6 @@ import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
@@ -17,6 +16,7 @@ import (
"os"
"regexp"
"sort"
"strconv"
"strings"
"sync/atomic"
"time"
@@ -26,6 +26,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/cespare/xxhash/v2"
"github.com/google/uuid"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
@@ -245,9 +246,6 @@ var (
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients")
// ErrModelScopeNotSupported 表示请求的模型系列不在分组支持的范围内
var ErrModelScopeNotSupported = errors.New("model scope not supported by this group")
// allowedHeaders 白名单headers参考CRS项目
var allowedHeaders = map[string]bool{
"accept": true,
@@ -273,13 +271,6 @@ var allowedHeaders = map[string]bool{
// GatewayCache 定义网关服务的缓存操作接口。
// 提供粘性会话Sticky Session的存储、查询、刷新和删除功能。
//
// ModelLoadInfo 模型负载信息(用于 Antigravity 调度)
// Model load info for Antigravity scheduling
type ModelLoadInfo struct {
CallCount int64 // 当前分钟调用次数 / Call count in current minute
LastUsedAt time.Time // 最后调度时间(零值表示未调度过)/ Last scheduling time (zero means never scheduled)
}
// GatewayCache defines cache operations for gateway service.
// Provides sticky session storage, retrieval, refresh and deletion capabilities.
type GatewayCache interface {
@@ -295,32 +286,6 @@ type GatewayCache interface {
// DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理
// Delete sticky session binding, used to proactively clean up when account becomes unavailable
DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error
// IncrModelCallCount 增加模型调用次数并更新最后调度时间Antigravity 专用)
// Increment model call count and update last scheduling time (Antigravity only)
// 返回更新后的调用次数
IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error)
// GetModelLoadBatch 批量获取账号的模型负载信息Antigravity 专用)
// Batch get model load info for accounts (Antigravity only)
GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error)
// FindGeminiSession 查找 Gemini 会话MGET 倒序匹配)
// Find Gemini session using MGET reverse order matching
// 返回最长匹配的会话信息uuid, accountID
FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool)
// SaveGeminiSession 保存 Gemini 会话
// Save Gemini session binding
SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error
// FindAnthropicSession 查找 Anthropic 会话Trie 匹配)
// Find Anthropic session using Trie matching
FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool)
// SaveAnthropicSession 保存 Anthropic 会话
// Save Anthropic session binding
SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error
}
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
@@ -415,6 +380,7 @@ type GatewayService struct {
userSubRepo UserSubscriptionRepository
userGroupRateRepo UserGroupRateRepository
cache GatewayCache
digestStore *DigestSessionStore
cfg *config.Config
schedulerSnapshot *SchedulerSnapshotService
billingService *BillingService
@@ -448,6 +414,7 @@ func NewGatewayService(
deferredService *DeferredService,
claudeTokenProvider *ClaudeTokenProvider,
sessionLimitCache SessionLimitCache,
digestStore *DigestSessionStore,
) *GatewayService {
return &GatewayService{
accountRepo: accountRepo,
@@ -457,6 +424,7 @@ func NewGatewayService(
userSubRepo: userSubRepo,
userGroupRateRepo: userGroupRateRepo,
cache: cache,
digestStore: digestStore,
cfg: cfg,
schedulerSnapshot: schedulerSnapshot,
concurrencyService: concurrencyService,
@@ -490,8 +458,17 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
return s.hashContent(cacheableContent)
}
// 3. 最后 fallback: 使用 system + 所有消息的完整摘要串
// 3. 最后 fallback: 使用 session上下文 + system + 所有消息的完整摘要串
var combined strings.Builder
// 混入请求上下文区分因子,避免不同用户相同消息产生相同 hash
if parsed.SessionContext != nil {
_, _ = combined.WriteString(parsed.SessionContext.ClientIP)
_, _ = combined.WriteString(":")
_, _ = combined.WriteString(parsed.SessionContext.UserAgent)
_, _ = combined.WriteString(":")
_, _ = combined.WriteString(strconv.FormatInt(parsed.SessionContext.APIKeyID, 10))
_, _ = combined.WriteString("|")
}
if parsed.System != nil {
systemText := s.extractTextFromSystem(parsed.System)
if systemText != "" {
@@ -500,9 +477,20 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
}
for _, msg := range parsed.Messages {
if m, ok := msg.(map[string]any); ok {
msgText := s.extractTextFromContent(m["content"])
if msgText != "" {
_, _ = combined.WriteString(msgText)
if content, exists := m["content"]; exists {
// Anthropic: messages[].content
if msgText := s.extractTextFromContent(content); msgText != "" {
_, _ = combined.WriteString(msgText)
}
} else if parts, ok := m["parts"].([]any); ok {
// Gemini: contents[].parts[].text
for _, part := range parts {
if partMap, ok := part.(map[string]any); ok {
if text, ok := partMap["text"].(string); ok {
_, _ = combined.WriteString(text)
}
}
}
}
}
}
@@ -536,35 +524,37 @@ func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID
// FindGeminiSession 查找 Gemini 会话(基于内容摘要链的 Fallback 匹配)
// 返回最长匹配的会话信息uuid, accountID
func (s *GatewayService) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
if digestChain == "" || s.cache == nil {
return "", 0, false
func (s *GatewayService) FindGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) {
if digestChain == "" || s.digestStore == nil {
return "", 0, "", false
}
return s.cache.FindGeminiSession(ctx, groupID, prefixHash, digestChain)
return s.digestStore.Find(groupID, prefixHash, digestChain)
}
// SaveGeminiSession 保存 Gemini 会话
func (s *GatewayService) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
if digestChain == "" || s.cache == nil {
// SaveGeminiSession 保存 Gemini 会话。oldDigestChain 为 Find 返回的 matchedChain用于删旧 key。
func (s *GatewayService) SaveGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error {
if digestChain == "" || s.digestStore == nil {
return nil
}
return s.cache.SaveGeminiSession(ctx, groupID, prefixHash, digestChain, uuid, accountID)
s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain)
return nil
}
// FindAnthropicSession 查找 Anthropic 会话(基于内容摘要链的 Fallback 匹配)
func (s *GatewayService) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
if digestChain == "" || s.cache == nil {
return "", 0, false
func (s *GatewayService) FindAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) {
if digestChain == "" || s.digestStore == nil {
return "", 0, "", false
}
return s.cache.FindAnthropicSession(ctx, groupID, prefixHash, digestChain)
return s.digestStore.Find(groupID, prefixHash, digestChain)
}
// SaveAnthropicSession 保存 Anthropic 会话
func (s *GatewayService) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
if digestChain == "" || s.cache == nil {
func (s *GatewayService) SaveAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error {
if digestChain == "" || s.digestStore == nil {
return nil
}
return s.cache.SaveAnthropicSession(ctx, groupID, prefixHash, digestChain, uuid, accountID)
s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain)
return nil
}
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
@@ -649,8 +639,8 @@ func (s *GatewayService) extractTextFromContent(content any) string {
}
func (s *GatewayService) hashContent(content string) string {
hash := sha256.Sum256([]byte(content))
return hex.EncodeToString(hash[:16]) // 32字符
h := xxhash.Sum64String(content)
return strconv.FormatUint(h, 36)
}
// replaceModelInBody 替换请求体中的model字段
@@ -1009,13 +999,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
log.Printf("[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform)
}
// Antigravity 模型系列检查(在账号选择前检查,确保所有代码路径都经过此检查)
if platform == PlatformAntigravity && groupID != nil && requestedModel != "" {
if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil {
return nil, err
}
}
accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err != nil {
return nil, err
@@ -1209,6 +1192,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
}
})
shuffleWithinSortGroups(routingAvailable)
// 4. 尝试获取槽位
for _, item := range routingAvailable {
@@ -1362,10 +1346,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return result, nil
}
} else {
// Antigravity 平台:获取模型负载信息
var modelLoadMap map[int64]*ModelLoadInfo
isAntigravity := platform == PlatformAntigravity
var available []accountWithLoad
for _, acc := range candidates {
loadInfo := loadMap[acc.ID]
@@ -1380,109 +1360,44 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
}
// Antigravity 平台:按账号实际映射后的模型名获取模型负载(与 Forward 的统计保持一致)
if isAntigravity && requestedModel != "" && s.cache != nil && len(available) > 0 {
modelLoadMap = make(map[int64]*ModelLoadInfo, len(available))
modelToAccountIDs := make(map[string][]int64)
for _, item := range available {
mappedModel := mapAntigravityModel(item.account, requestedModel)
if mappedModel == "" {
continue
}
modelToAccountIDs[mappedModel] = append(modelToAccountIDs[mappedModel], item.account.ID)
// 分层过滤选择:优先级 → 负载率 → LRU
for len(available) > 0 {
// 1. 取优先级最小的集合
candidates := filterByMinPriority(available)
// 2. 取负载率最低的集合
candidates = filterByMinLoadRate(candidates)
// 3. LRU 选择最久未用的账号
selected := selectByLRU(candidates, preferOAuth)
if selected == nil {
break
}
for model, ids := range modelToAccountIDs {
batch, err := s.cache.GetModelLoadBatch(ctx, ids, model)
if err != nil {
continue
}
for id, info := range batch {
modelLoadMap[id] = info
}
}
if len(modelLoadMap) == 0 {
modelLoadMap = nil
}
}
// Antigravity 平台:优先级硬过滤 →(同优先级内)按调用次数选择(最少优先,新账号用平均值)
// 其他平台:分层过滤选择:优先级 → 负载率 → LRU
if isAntigravity {
for len(available) > 0 {
// 1. 取优先级最小的集合(硬过滤)
candidates := filterByMinPriority(available)
// 2. 同优先级内按调用次数选择(调用次数最少优先,新账号使用平均值)
selected := selectByCallCount(candidates, modelLoadMap, preferOAuth)
if selected == nil {
break
}
result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
} else {
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: selected.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
} else {
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: selected.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
// 移除已尝试的账号,重新选择
selectedID := selected.account.ID
newAvailable := make([]accountWithLoad, 0, len(available)-1)
for _, acc := range available {
if acc.account.ID != selectedID {
newAvailable = append(newAvailable, acc)
}
}
available = newAvailable
}
} else {
for len(available) > 0 {
// 1. 取优先级最小的集合
candidates := filterByMinPriority(available)
// 2. 取负载率最低的集合
candidates = filterByMinLoadRate(candidates)
// 3. LRU 选择最久未用的账号
selected := selectByLRU(candidates, preferOAuth)
if selected == nil {
break
}
result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
} else {
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: selected.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
// 移除已尝试的账号,重新进行分层过滤
selectedID := selected.account.ID
newAvailable := make([]accountWithLoad, 0, len(available)-1)
for _, acc := range available {
if acc.account.ID != selectedID {
newAvailable = append(newAvailable, acc)
}
// 移除已尝试的账号,重新进行分层过滤
selectedID := selected.account.ID
newAvailable := make([]accountWithLoad, 0, len(available)-1)
for _, acc := range available {
if acc.account.ID != selectedID {
newAvailable = append(newAvailable, acc)
}
}
available = newAvailable
}
available = newAvailable
}
}
@@ -2018,87 +1933,79 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
return a.LastUsedAt.Before(*b.LastUsedAt)
}
})
shuffleWithinPriorityAndLastUsed(accounts)
}
// selectByCallCount 从候选账号中选择调用次数最少的账号Antigravity 专用)
// 新账号CallCount=0使用平均调用次数作为虚拟值避免冷启动被猛调
// 如果有多个账号具有相同的最小调用次数,则随机选择一个
func selectByCallCount(accounts []accountWithLoad, modelLoadMap map[int64]*ModelLoadInfo, preferOAuth bool) *accountWithLoad {
if len(accounts) == 0 {
return nil
// shuffleWithinSortGroups 对排序后的 accountWithLoad 切片,按 (Priority, LoadRate, LastUsedAt) 分组后组内随机打乱。
// 防止并发请求读取同一快照时,确定性排序导致所有请求命中相同账号。
func shuffleWithinSortGroups(accounts []accountWithLoad) {
if len(accounts) <= 1 {
return
}
if len(accounts) == 1 {
return &accounts[0]
}
// 如果没有负载信息,回退到 LRU
if modelLoadMap == nil {
return selectByLRU(accounts, preferOAuth)
}
// 1. 计算平均调用次数(用于新账号冷启动)
var totalCallCount int64
var countWithCalls int
for _, acc := range accounts {
if info := modelLoadMap[acc.account.ID]; info != nil && info.CallCount > 0 {
totalCallCount += info.CallCount
countWithCalls++
i := 0
for i < len(accounts) {
j := i + 1
for j < len(accounts) && sameAccountWithLoadGroup(accounts[i], accounts[j]) {
j++
}
}
var avgCallCount int64
if countWithCalls > 0 {
avgCallCount = totalCallCount / int64(countWithCalls)
}
// 2. 获取每个账号的有效调用次数
getEffectiveCallCount := func(acc accountWithLoad) int64 {
if acc.account == nil {
return 0
if j-i > 1 {
mathrand.Shuffle(j-i, func(a, b int) {
accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a]
})
}
info := modelLoadMap[acc.account.ID]
if info == nil || info.CallCount == 0 {
return avgCallCount // 新账号使用平均值
}
return info.CallCount
i = j
}
}
// 3. 找到最小调用次数
minCount := getEffectiveCallCount(accounts[0])
for _, acc := range accounts[1:] {
if c := getEffectiveCallCount(acc); c < minCount {
minCount = c
}
// sameAccountWithLoadGroup 判断两个 accountWithLoad 是否属于同一排序组
func sameAccountWithLoadGroup(a, b accountWithLoad) bool {
if a.account.Priority != b.account.Priority {
return false
}
// 4. 收集所有具有最小调用次数的账号
var candidateIdxs []int
for i, acc := range accounts {
if getEffectiveCallCount(acc) == minCount {
candidateIdxs = append(candidateIdxs, i)
}
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
return false
}
return sameLastUsedAt(a.account.LastUsedAt, b.account.LastUsedAt)
}
// 5. 如果只有一个候选,直接返回
if len(candidateIdxs) == 1 {
return &accounts[candidateIdxs[0]]
// shuffleWithinPriorityAndLastUsed 对排序后的 []*Account 切片,按 (Priority, LastUsedAt) 分组后组内随机打乱。
func shuffleWithinPriorityAndLastUsed(accounts []*Account) {
if len(accounts) <= 1 {
return
}
// 6. preferOAuth 处理
if preferOAuth {
var oauthIdxs []int
for _, idx := range candidateIdxs {
if accounts[idx].account.Type == AccountTypeOAuth {
oauthIdxs = append(oauthIdxs, idx)
}
i := 0
for i < len(accounts) {
j := i + 1
for j < len(accounts) && sameAccountGroup(accounts[i], accounts[j]) {
j++
}
if len(oauthIdxs) > 0 {
candidateIdxs = oauthIdxs
if j-i > 1 {
mathrand.Shuffle(j-i, func(a, b int) {
accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a]
})
}
i = j
}
}
// 7. 随机选择
return &accounts[candidateIdxs[mathrand.Intn(len(candidateIdxs))]]
// sameAccountGroup 判断两个 Account 是否属于同一排序组Priority + LastUsedAt
func sameAccountGroup(a, b *Account) bool {
if a.Priority != b.Priority {
return false
}
return sameLastUsedAt(a.LastUsedAt, b.LastUsedAt)
}
// sameLastUsedAt 判断两个 LastUsedAt 是否相同(精度到秒)
func sameLastUsedAt(a, b *time.Time) bool {
switch {
case a == nil && b == nil:
return true
case a == nil || b == nil:
return false
default:
return a.Unix() == b.Unix()
}
}
// sortCandidatesForFallback 根据配置选择排序策略
@@ -2153,13 +2060,6 @@ func shuffleWithinPriority(accounts []*Account) {
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
// 对 Antigravity 平台,检查请求的模型系列是否在分组支持范围内
if platform == PlatformAntigravity && groupID != nil && requestedModel != "" {
if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil {
return nil, err
}
}
preferOAuth := platform == PlatformGemini
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
@@ -5171,27 +5071,6 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
return normalized, nil
}
// checkAntigravityModelScope 检查 Antigravity 平台的模型系列是否在分组支持范围内
func (s *GatewayService) checkAntigravityModelScope(ctx context.Context, groupID int64, requestedModel string) error {
scope, ok := ResolveAntigravityQuotaScope(requestedModel)
if !ok {
return nil // 无法解析 scope跳过检查
}
group, err := s.resolveGroupByID(ctx, groupID)
if err != nil {
return nil // 查询失败时放行
}
if group == nil {
return nil // 分组不存在时放行
}
if !IsScopeSupported(group.SupportedModelScopes, scope) {
return ErrModelScopeNotSupported
}
return nil
}
// GetAvailableModels returns the list of models available for a group
// It aggregates model_mapping keys from all schedulable accounts in the group
func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string {

View File

@@ -14,7 +14,7 @@ func BenchmarkGenerateSessionHash_Metadata(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
parsed, err := ParseGatewayRequest(body)
parsed, err := ParseGatewayRequest(body, "")
if err != nil {
b.Fatalf("解析请求失败: %v", err)
}

View File

@@ -0,0 +1,384 @@
//go:build unit
package service
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// TestShouldFailoverGeminiUpstreamError — verifies the failover decision
// for the ErrorPolicyNone path (original logic preserved).
// ---------------------------------------------------------------------------
func TestShouldFailoverGeminiUpstreamError(t *testing.T) {
svc := &GeminiMessagesCompatService{}
tests := []struct {
name string
statusCode int
expected bool
}{
{"401_failover", 401, true},
{"403_failover", 403, true},
{"429_failover", 429, true},
{"529_failover", 529, true},
{"500_failover", 500, true},
{"502_failover", 502, true},
{"503_failover", 503, true},
{"400_no_failover", 400, false},
{"404_no_failover", 404, false},
{"422_no_failover", 422, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := svc.shouldFailoverGeminiUpstreamError(tt.statusCode)
require.Equal(t, tt.expected, got)
})
}
}
// ---------------------------------------------------------------------------
// TestCheckErrorPolicy_GeminiAccounts — verifies CheckErrorPolicy works
// correctly for Gemini platform accounts (API Key type).
// ---------------------------------------------------------------------------
func TestCheckErrorPolicy_GeminiAccounts(t *testing.T) {
tests := []struct {
name string
account *Account
statusCode int
body []byte
expected ErrorPolicyResult
}{
{
name: "gemini_apikey_custom_codes_hit",
account: &Account{
ID: 100,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429), float64(500)},
},
},
statusCode: 429,
body: []byte(`{"error":"rate limited"}`),
expected: ErrorPolicyMatched,
},
{
name: "gemini_apikey_custom_codes_miss",
account: &Account{
ID: 101,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429)},
},
},
statusCode: 500,
body: []byte(`{"error":"internal"}`),
expected: ErrorPolicySkipped,
},
{
name: "gemini_apikey_no_custom_codes_returns_none",
account: &Account{
ID: 102,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
},
statusCode: 500,
body: []byte(`{"error":"internal"}`),
expected: ErrorPolicyNone,
},
{
name: "gemini_apikey_temp_unschedulable_hit",
account: &Account{
ID: 103,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
},
},
},
},
statusCode: 503,
body: []byte(`overloaded service`),
expected: ErrorPolicyTempUnscheduled,
},
{
name: "gemini_custom_codes_override_temp_unschedulable",
account: &Account{
ID: 104,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(503)},
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
},
},
},
},
statusCode: 503,
body: []byte(`overloaded`),
expected: ErrorPolicyMatched, // custom codes take precedence
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &errorPolicyRepoStub{}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
result := svc.CheckErrorPolicy(context.Background(), tt.account, tt.statusCode, tt.body)
require.Equal(t, tt.expected, result)
})
}
}
// ---------------------------------------------------------------------------
// TestGeminiErrorPolicyIntegration — verifies the Gemini error handling
// paths produce the correct behavior for each ErrorPolicyResult.
//
// These tests simulate the inline error policy switch in handleClaudeCompat
// and forwardNativeGemini by calling the same methods in the same order.
// ---------------------------------------------------------------------------
func TestGeminiErrorPolicyIntegration(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
account *Account
statusCode int
respBody []byte
expectFailover bool // expect UpstreamFailoverError
expectHandleError bool // expect handleGeminiUpstreamError to be called
expectShouldFailover bool // for None path, whether shouldFailover triggers
}{
{
name: "custom_codes_matched_429_failover",
account: &Account{
ID: 200,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429)},
},
},
statusCode: 429,
respBody: []byte(`{"error":"rate limited"}`),
expectFailover: true,
expectHandleError: true,
},
{
name: "custom_codes_skipped_500_no_failover",
account: &Account{
ID: 201,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429)},
},
},
statusCode: 500,
respBody: []byte(`{"error":"internal"}`),
expectFailover: false,
expectHandleError: false,
},
{
name: "temp_unschedulable_matched_failover",
account: &Account{
ID: 202,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
},
},
},
},
statusCode: 503,
respBody: []byte(`overloaded`),
expectFailover: true,
expectHandleError: true,
},
{
name: "no_policy_429_failover_via_shouldFailover",
account: &Account{
ID: 203,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
},
statusCode: 429,
respBody: []byte(`{"error":"rate limited"}`),
expectFailover: true,
expectHandleError: true,
expectShouldFailover: true,
},
{
name: "no_policy_400_no_failover",
account: &Account{
ID: 204,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
},
statusCode: 400,
respBody: []byte(`{"error":"bad request"}`),
expectFailover: false,
expectHandleError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &geminiErrorPolicyRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &GeminiMessagesCompatService{
accountRepo: repo,
rateLimitService: rlSvc,
}
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
// Simulate the Claude compat error handling path (same logic as native).
// This mirrors the inline switch in handleClaudeCompat.
var handleErrorCalled bool
var gotFailover bool
ctx := context.Background()
statusCode := tt.statusCode
respBody := tt.respBody
account := tt.account
headers := http.Header{}
if svc.rateLimitService != nil {
switch svc.rateLimitService.CheckErrorPolicy(ctx, account, statusCode, respBody) {
case ErrorPolicySkipped:
// Skipped → return error directly (no handleGeminiUpstreamError, no failover)
gotFailover = false
handleErrorCalled = false
goto verify
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
svc.handleGeminiUpstreamError(ctx, account, statusCode, headers, respBody)
handleErrorCalled = true
gotFailover = true
goto verify
}
}
// ErrorPolicyNone → original logic
svc.handleGeminiUpstreamError(ctx, account, statusCode, headers, respBody)
handleErrorCalled = true
if svc.shouldFailoverGeminiUpstreamError(statusCode) {
gotFailover = true
}
verify:
require.Equal(t, tt.expectFailover, gotFailover, "failover mismatch")
require.Equal(t, tt.expectHandleError, handleErrorCalled, "handleGeminiUpstreamError call mismatch")
if tt.expectShouldFailover {
require.True(t, svc.shouldFailoverGeminiUpstreamError(statusCode),
"shouldFailoverGeminiUpstreamError should return true for status %d", statusCode)
}
})
}
}
// ---------------------------------------------------------------------------
// TestGeminiErrorPolicy_NilRateLimitService — verifies nil safety
// ---------------------------------------------------------------------------
func TestGeminiErrorPolicy_NilRateLimitService(t *testing.T) {
svc := &GeminiMessagesCompatService{
rateLimitService: nil,
}
// When rateLimitService is nil, error policy is skipped → falls through to
// shouldFailoverGeminiUpstreamError (original logic).
// Verify this doesn't panic and follows expected behavior.
ctx := context.Background()
account := &Account{
ID: 300,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429)},
},
}
// The nil check should prevent CheckErrorPolicy from being called
if svc.rateLimitService != nil {
t.Fatal("rateLimitService should be nil for this test")
}
// shouldFailoverGeminiUpstreamError still works
require.True(t, svc.shouldFailoverGeminiUpstreamError(429))
require.False(t, svc.shouldFailoverGeminiUpstreamError(400))
// handleGeminiUpstreamError should not panic with nil rateLimitService
require.NotPanics(t, func() {
svc.handleGeminiUpstreamError(ctx, account, 500, http.Header{}, []byte(`error`))
})
}
// ---------------------------------------------------------------------------
// geminiErrorPolicyRepo — minimal AccountRepository stub for Gemini error
// policy tests. Embeds mockAccountRepoForGemini and adds tracking.
// ---------------------------------------------------------------------------
type geminiErrorPolicyRepo struct {
mockAccountRepoForGemini
setErrorCalls int
setRateLimitedCalls int
setTempCalls int
}
func (r *geminiErrorPolicyRepo) SetError(_ context.Context, _ int64, _ string) error {
r.setErrorCalls++
return nil
}
func (r *geminiErrorPolicyRepo) SetRateLimited(_ context.Context, _ int64, _ time.Time) error {
r.setRateLimitedCalls++
return nil
}
func (r *geminiErrorPolicyRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error {
r.setTempCalls++
return nil
}

View File

@@ -831,38 +831,47 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
tempMatched := false
// 统一错误策略:自定义错误码 + 临时不可调度
if s.rateLimitService != nil {
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
}
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
if tempMatched {
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
upstreamReqID = resp.Header.Get("x-goog-request-id")
}
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
switch s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, respBody) {
case ErrorPolicySkipped:
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
upstreamReqID = resp.Header.Get("x-goog-request-id")
}
upstreamDetail = truncateString(string(respBody), maxBytes)
return nil, s.writeGeminiMappedError(c, account, resp.StatusCode, upstreamReqID, respBody)
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
upstreamReqID = resp.Header.Get("x-goog-request-id")
}
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
// ErrorPolicyNone → 原有逻辑
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
@@ -1249,14 +1258,9 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
tempMatched := false
if s.rateLimitService != nil {
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
}
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
// Best-effort fallback for OAuth tokens missing AI Studio scopes when calling countTokens.
// This avoids Gemini SDKs failing hard during preflight token counting.
// Checked before error policy so it always works regardless of custom error codes.
if action == "countTokens" && isOAuth && isGeminiInsufficientScope(resp.Header, respBody) {
estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
@@ -1270,30 +1274,46 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}, nil
}
if tempMatched {
evBody := unwrapIfNeeded(isOAuth, respBody)
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
// 统一错误策略:自定义错误码 + 临时不可调度
if s.rateLimitService != nil {
switch s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, respBody) {
case ErrorPolicySkipped:
respBody = unwrapIfNeeded(isOAuth, respBody)
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
}
upstreamDetail = truncateString(string(evBody), maxBytes)
c.Data(resp.StatusCode, contentType, respBody)
return nil, fmt.Errorf("gemini upstream error: %d (skipped by error policy)", resp.StatusCode)
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
evBody := unwrapIfNeeded(isOAuth, respBody)
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(evBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
// ErrorPolicyNone → 原有逻辑
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
evBody := unwrapIfNeeded(isOAuth, respBody)
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))

View File

@@ -133,9 +133,6 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx cont
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForGemini) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForGemini) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
return nil
}
@@ -265,29 +262,6 @@ func (m *mockGatewayCacheForGemini) DeleteSessionAccountID(ctx context.Context,
return nil
}
func (m *mockGatewayCacheForGemini) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
return 0, nil
}
func (m *mockGatewayCacheForGemini) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) {
return nil, nil
}
func (m *mockGatewayCacheForGemini) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (m *mockGatewayCacheForGemini) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
func (m *mockGatewayCacheForGemini) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (m *mockGatewayCacheForGemini) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) {

View File

@@ -6,26 +6,11 @@ import (
"encoding/json"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/cespare/xxhash/v2"
)
// Gemini 会话 ID Fallback 相关常量
const (
// geminiSessionTTLSeconds Gemini 会话缓存 TTL5 分钟)
geminiSessionTTLSeconds = 300
// geminiSessionKeyPrefix Gemini 会话 Redis key 前缀
geminiSessionKeyPrefix = "gemini:sess:"
)
// GeminiSessionTTL 返回 Gemini 会话缓存 TTL
func GeminiSessionTTL() time.Duration {
return geminiSessionTTLSeconds * time.Second
}
// shortHash 使用 XXHash64 + Base36 生成短 hash16 字符)
// XXHash64 比 SHA256 快约 10 倍Base36 比 Hex 短约 20%
func shortHash(data []byte) string {
@@ -79,35 +64,6 @@ func GenerateGeminiPrefixHash(userID, apiKeyID int64, ip, userAgent, platform, m
return base64.RawURLEncoding.EncodeToString(hash[:12])
}
// BuildGeminiSessionKey 构建 Gemini 会话 Redis key
// 格式: gemini:sess:{groupID}:{prefixHash}:{digestChain}
func BuildGeminiSessionKey(groupID int64, prefixHash, digestChain string) string {
return geminiSessionKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash + ":" + digestChain
}
// GenerateDigestChainPrefixes 生成摘要链的所有前缀(从长到短)
// 用于 MGET 批量查询最长匹配
func GenerateDigestChainPrefixes(chain string) []string {
if chain == "" {
return nil
}
var prefixes []string
c := chain
for c != "" {
prefixes = append(prefixes, c)
// 找到最后一个 "-" 的位置
if i := strings.LastIndex(c, "-"); i > 0 {
c = c[:i]
} else {
break
}
}
return prefixes
}
// ParseGeminiSessionValue 解析 Gemini 会话缓存值
// 格式: {uuid}:{accountID}
func ParseGeminiSessionValue(value string) (uuid string, accountID int64, ok bool) {
@@ -139,15 +95,6 @@ func FormatGeminiSessionValue(uuid string, accountID int64) string {
// geminiDigestSessionKeyPrefix Gemini 摘要 fallback 会话 key 前缀
const geminiDigestSessionKeyPrefix = "gemini:digest:"
// geminiTrieKeyPrefix Gemini Trie 会话 key 前缀
const geminiTrieKeyPrefix = "gemini:trie:"
// BuildGeminiTrieKey 构建 Gemini Trie Redis key
// 格式: gemini:trie:{groupID}:{prefixHash}
func BuildGeminiTrieKey(groupID int64, prefixHash string) string {
return geminiTrieKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash
}
// GenerateGeminiDigestSessionKey 生成 Gemini 摘要 fallback 的 sessionKey
// 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey
// 用于在 SelectAccountWithLoadAwareness 中保持粘性会话

View File

@@ -1,41 +1,14 @@
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// mockGeminiSessionCache 模拟 Redis 缓存
type mockGeminiSessionCache struct {
sessions map[string]string // key -> value
}
func newMockGeminiSessionCache() *mockGeminiSessionCache {
return &mockGeminiSessionCache{sessions: make(map[string]string)}
}
func (m *mockGeminiSessionCache) Save(groupID int64, prefixHash, digestChain, uuid string, accountID int64) {
key := BuildGeminiSessionKey(groupID, prefixHash, digestChain)
value := FormatGeminiSessionValue(uuid, accountID)
m.sessions[key] = value
}
func (m *mockGeminiSessionCache) Find(groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
prefixes := GenerateDigestChainPrefixes(digestChain)
for _, p := range prefixes {
key := BuildGeminiSessionKey(groupID, prefixHash, p)
if val, ok := m.sessions[key]; ok {
return ParseGeminiSessionValue(val)
}
}
return "", 0, false
}
// TestGeminiSessionContinuousConversation 测试连续会话的摘要链匹配
func TestGeminiSessionContinuousConversation(t *testing.T) {
cache := newMockGeminiSessionCache()
store := NewDigestSessionStore()
groupID := int64(1)
prefixHash := "test_prefix_hash"
sessionUUID := "session-uuid-12345"
@@ -54,13 +27,13 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
t.Logf("Round 1 chain: %s", chain1)
// 第一轮:没有找到会话,创建新会话
_, _, found := cache.Find(groupID, prefixHash, chain1)
_, _, _, found := store.Find(groupID, prefixHash, chain1)
if found {
t.Error("Round 1: should not find existing session")
}
// 保存第一轮会话
cache.Save(groupID, prefixHash, chain1, sessionUUID, accountID)
// 保存第一轮会话(首轮无旧 chain
store.Save(groupID, prefixHash, chain1, sessionUUID, accountID, "")
// 模拟第二轮对话(用户继续对话)
req2 := &antigravity.GeminiRequest{
@@ -77,7 +50,7 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
t.Logf("Round 2 chain: %s", chain2)
// 第二轮:应该能找到会话(通过前缀匹配)
foundUUID, foundAccID, found := cache.Find(groupID, prefixHash, chain2)
foundUUID, foundAccID, matchedChain, found := store.Find(groupID, prefixHash, chain2)
if !found {
t.Error("Round 2: should find session via prefix matching")
}
@@ -88,8 +61,8 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
t.Errorf("Round 2: expected accountID %d, got %d", accountID, foundAccID)
}
// 保存第二轮会话
cache.Save(groupID, prefixHash, chain2, sessionUUID, accountID)
// 保存第二轮会话,传入 Find 返回的 matchedChain 以删旧 key
store.Save(groupID, prefixHash, chain2, sessionUUID, accountID, matchedChain)
// 模拟第三轮对话
req3 := &antigravity.GeminiRequest{
@@ -108,7 +81,7 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
t.Logf("Round 3 chain: %s", chain3)
// 第三轮:应该能找到会话(通过第二轮的前缀匹配)
foundUUID, foundAccID, found = cache.Find(groupID, prefixHash, chain3)
foundUUID, foundAccID, _, found = store.Find(groupID, prefixHash, chain3)
if !found {
t.Error("Round 3: should find session via prefix matching")
}
@@ -118,13 +91,11 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
if foundAccID != accountID {
t.Errorf("Round 3: expected accountID %d, got %d", accountID, foundAccID)
}
t.Log("✓ Continuous conversation session matching works correctly!")
}
// TestGeminiSessionDifferentConversations 测试不同会话不会错误匹配
func TestGeminiSessionDifferentConversations(t *testing.T) {
cache := newMockGeminiSessionCache()
store := NewDigestSessionStore()
groupID := int64(1)
prefixHash := "test_prefix_hash"
@@ -135,7 +106,7 @@ func TestGeminiSessionDifferentConversations(t *testing.T) {
},
}
chain1 := BuildGeminiDigestChain(req1)
cache.Save(groupID, prefixHash, chain1, "session-1", 100)
store.Save(groupID, prefixHash, chain1, "session-1", 100, "")
// 第二个完全不同的会话
req2 := &antigravity.GeminiRequest{
@@ -146,61 +117,29 @@ func TestGeminiSessionDifferentConversations(t *testing.T) {
chain2 := BuildGeminiDigestChain(req2)
// 不同会话不应该匹配
_, _, found := cache.Find(groupID, prefixHash, chain2)
_, _, _, found := store.Find(groupID, prefixHash, chain2)
if found {
t.Error("Different conversations should not match")
}
t.Log("✓ Different conversations are correctly isolated!")
}
// TestGeminiSessionPrefixMatchingOrder 测试前缀匹配的优先级(最长匹配优先)
func TestGeminiSessionPrefixMatchingOrder(t *testing.T) {
cache := newMockGeminiSessionCache()
store := NewDigestSessionStore()
groupID := int64(1)
prefixHash := "test_prefix_hash"
// 创建一个三轮对话
req := &antigravity.GeminiRequest{
SystemInstruction: &antigravity.GeminiContent{
Parts: []antigravity.GeminiPart{{Text: "System prompt"}},
},
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q1"}}},
{Role: "model", Parts: []antigravity.GeminiPart{{Text: "A1"}}},
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q2"}}},
},
}
fullChain := BuildGeminiDigestChain(req)
prefixes := GenerateDigestChainPrefixes(fullChain)
t.Logf("Full chain: %s", fullChain)
t.Logf("Prefixes (longest first): %v", prefixes)
// 验证前缀生成顺序(从长到短)
if len(prefixes) != 4 {
t.Errorf("Expected 4 prefixes, got %d", len(prefixes))
}
// 保存不同轮次的会话到不同账号
// 第一轮(最短前缀)-> 账号 1
cache.Save(groupID, prefixHash, prefixes[3], "session-round1", 1)
// 第二轮 -> 账号 2
cache.Save(groupID, prefixHash, prefixes[2], "session-round2", 2)
// 第三轮(最长前缀,完整链)-> 账号 3
cache.Save(groupID, prefixHash, prefixes[0], "session-round3", 3)
store.Save(groupID, prefixHash, "s:sys-u:q1", "session-round1", 1, "")
store.Save(groupID, prefixHash, "s:sys-u:q1-m:a1", "session-round2", 2, "")
store.Save(groupID, prefixHash, "s:sys-u:q1-m:a1-u:q2", "session-round3", 3, "")
// 查找应该返回最长匹配(账号 3
_, accID, found := cache.Find(groupID, prefixHash, fullChain)
// 查找更长的链,应该返回最长匹配(账号 3
_, accID, _, found := store.Find(groupID, prefixHash, "s:sys-u:q1-m:a1-u:q2-m:a2")
if !found {
t.Error("Should find session")
}
if accID != 3 {
t.Errorf("Should match longest prefix (account 3), got account %d", accID)
}
t.Log("✓ Longest prefix matching works correctly!")
}
// 确保 context 包被使用(避免未使用的导入警告)
var _ = context.Background

View File

@@ -152,61 +152,6 @@ func TestGenerateGeminiPrefixHash(t *testing.T) {
}
}
func TestGenerateDigestChainPrefixes(t *testing.T) {
tests := []struct {
name string
chain string
want []string
wantLen int
}{
{
name: "empty",
chain: "",
wantLen: 0,
},
{
name: "single part",
chain: "u:abc123",
want: []string{"u:abc123"},
wantLen: 1,
},
{
name: "two parts",
chain: "s:xyz-u:abc",
want: []string{"s:xyz-u:abc", "s:xyz"},
wantLen: 2,
},
{
name: "four parts",
chain: "s:a-u:b-m:c-u:d",
want: []string{"s:a-u:b-m:c-u:d", "s:a-u:b-m:c", "s:a-u:b", "s:a"},
wantLen: 4,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GenerateDigestChainPrefixes(tt.chain)
if len(result) != tt.wantLen {
t.Errorf("expected %d prefixes, got %d: %v", tt.wantLen, len(result), result)
}
if tt.want != nil {
for i, want := range tt.want {
if i >= len(result) {
t.Errorf("missing prefix at index %d", i)
continue
}
if result[i] != want {
t.Errorf("prefix[%d]: expected %s, got %s", i, want, result[i])
}
}
}
})
}
}
func TestParseGeminiSessionValue(t *testing.T) {
tests := []struct {
name string
@@ -442,40 +387,3 @@ func TestGenerateGeminiDigestSessionKey(t *testing.T) {
}
})
}
func TestBuildGeminiTrieKey(t *testing.T) {
tests := []struct {
name string
groupID int64
prefixHash string
want string
}{
{
name: "normal",
groupID: 123,
prefixHash: "abcdef12",
want: "gemini:trie:123:abcdef12",
},
{
name: "zero group",
groupID: 0,
prefixHash: "xyz",
want: "gemini:trie:0:xyz",
},
{
name: "empty prefix",
groupID: 1,
prefixHash: "",
want: "gemini:trie:1:",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := BuildGeminiTrieKey(tt.groupID, tt.prefixHash)
if got != tt.want {
t.Errorf("BuildGeminiTrieKey(%d, %q) = %q, want %q", tt.groupID, tt.prefixHash, got, tt.want)
}
})
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -318,110 +318,6 @@ func TestGetModelRateLimitRemainingTime(t *testing.T) {
}
}
func TestGetQuotaScopeRateLimitRemainingTime(t *testing.T) {
now := time.Now()
future10m := now.Add(10 * time.Minute).Format(time.RFC3339)
past := now.Add(-10 * time.Minute).Format(time.RFC3339)
tests := []struct {
name string
account *Account
requestedModel string
minExpected time.Duration
maxExpected time.Duration
}{
{
name: "nil account",
account: nil,
requestedModel: "claude-sonnet-4-5",
minExpected: 0,
maxExpected: 0,
},
{
name: "non-antigravity platform",
account: &Account{
Platform: PlatformAnthropic,
Extra: map[string]any{
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future10m,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 0,
maxExpected: 0,
},
{
name: "claude scope rate limited",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future10m,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 9 * time.Minute,
maxExpected: 11 * time.Minute,
},
{
name: "gemini_text scope rate limited",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
antigravityQuotaScopesKey: map[string]any{
"gemini_text": map[string]any{
"rate_limit_reset_at": future10m,
},
},
},
},
requestedModel: "gemini-3-flash",
minExpected: 9 * time.Minute,
maxExpected: 11 * time.Minute,
},
{
name: "expired scope rate limit",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": past,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 0,
maxExpected: 0,
},
{
name: "unsupported model",
account: &Account{
Platform: PlatformAntigravity,
},
requestedModel: "gpt-4",
minExpected: 0,
maxExpected: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.account.GetQuotaScopeRateLimitRemainingTime(tt.requestedModel)
if result < tt.minExpected || result > tt.maxExpected {
t.Errorf("GetQuotaScopeRateLimitRemainingTime() = %v, want between %v and %v", result, tt.minExpected, tt.maxExpected)
}
})
}
}
func TestGetRateLimitRemainingTime(t *testing.T) {
now := time.Now()
future15m := now.Add(15 * time.Minute).Format(time.RFC3339)
@@ -442,45 +338,19 @@ func TestGetRateLimitRemainingTime(t *testing.T) {
maxExpected: 0,
},
{
name: "model remaining > scope remaining - returns model",
name: "model rate limited - 15 minutes",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limit_reset_at": future15m, // 15 分钟
},
},
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future5m, // 5 分钟
"rate_limit_reset_at": future15m,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 14 * time.Minute, // 应返回较大的 15 分钟
maxExpected: 16 * time.Minute,
},
{
name: "scope remaining > model remaining - returns scope",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limit_reset_at": future5m, // 5 分钟
},
},
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future15m, // 15 分钟
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 14 * time.Minute, // 应返回较大的 15 分钟
minExpected: 14 * time.Minute,
maxExpected: 16 * time.Minute,
},
{
@@ -499,22 +369,6 @@ func TestGetRateLimitRemainingTime(t *testing.T) {
minExpected: 4 * time.Minute,
maxExpected: 6 * time.Minute,
},
{
name: "only scope rate limited",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future5m,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 4 * time.Minute,
maxExpected: 6 * time.Minute,
},
{
name: "neither rate limited",
account: &Account{

View File

@@ -580,10 +580,6 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
}
}
} else {
type accountWithLoad struct {
account *Account
loadInfo *AccountLoadInfo
}
var available []accountWithLoad
for _, acc := range candidates {
loadInfo := loadMap[acc.ID]
@@ -618,6 +614,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
}
})
shuffleWithinSortGroups(available)
for _, item := range available {
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)

View File

@@ -204,30 +204,6 @@ func (c *stubGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID i
return nil
}
func (c *stubGatewayCache) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
return 0, nil
}
func (c *stubGatewayCache) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) {
return nil, nil
}
func (c *stubGatewayCache) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (c *stubGatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
func (c *stubGatewayCache) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (c *stubGatewayCache) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
now := time.Now()
resetAt := now.Add(10 * time.Minute)

View File

@@ -66,7 +66,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
}
isAvailable := acc.Status == StatusActive && acc.Schedulable && !isRateLimited && !isOverloaded && !isTempUnsched
scopeRateLimits := acc.GetAntigravityScopeRateLimits()
if acc.Platform != "" {
if _, ok := platform[acc.Platform]; !ok {
@@ -85,14 +84,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
if hasError {
p.ErrorCount++
}
if len(scopeRateLimits) > 0 {
if p.ScopeRateLimitCount == nil {
p.ScopeRateLimitCount = make(map[string]int64)
}
for scope := range scopeRateLimits {
p.ScopeRateLimitCount[scope]++
}
}
}
for _, grp := range acc.Groups {
@@ -117,14 +108,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
if hasError {
g.ErrorCount++
}
if len(scopeRateLimits) > 0 {
if g.ScopeRateLimitCount == nil {
g.ScopeRateLimitCount = make(map[string]int64)
}
for scope := range scopeRateLimits {
g.ScopeRateLimitCount[scope]++
}
}
}
displayGroupID := int64(0)
@@ -157,9 +140,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
item.RateLimitRemainingSec = &remainingSec
}
}
if len(scopeRateLimits) > 0 {
item.ScopeRateLimits = scopeRateLimits
}
if isOverloaded && acc.OverloadUntil != nil {
item.OverloadUntil = acc.OverloadUntil
remainingSec := int64(time.Until(*acc.OverloadUntil).Seconds())

View File

@@ -50,24 +50,22 @@ type UserConcurrencyInfo struct {
// PlatformAvailability aggregates account availability by platform.
type PlatformAvailability struct {
Platform string `json:"platform"`
TotalAccounts int64 `json:"total_accounts"`
AvailableCount int64 `json:"available_count"`
RateLimitCount int64 `json:"rate_limit_count"`
ScopeRateLimitCount map[string]int64 `json:"scope_rate_limit_count,omitempty"`
ErrorCount int64 `json:"error_count"`
Platform string `json:"platform"`
TotalAccounts int64 `json:"total_accounts"`
AvailableCount int64 `json:"available_count"`
RateLimitCount int64 `json:"rate_limit_count"`
ErrorCount int64 `json:"error_count"`
}
// GroupAvailability aggregates account availability by group.
type GroupAvailability struct {
GroupID int64 `json:"group_id"`
GroupName string `json:"group_name"`
Platform string `json:"platform"`
TotalAccounts int64 `json:"total_accounts"`
AvailableCount int64 `json:"available_count"`
RateLimitCount int64 `json:"rate_limit_count"`
ScopeRateLimitCount map[string]int64 `json:"scope_rate_limit_count,omitempty"`
ErrorCount int64 `json:"error_count"`
GroupID int64 `json:"group_id"`
GroupName string `json:"group_name"`
Platform string `json:"platform"`
TotalAccounts int64 `json:"total_accounts"`
AvailableCount int64 `json:"available_count"`
RateLimitCount int64 `json:"rate_limit_count"`
ErrorCount int64 `json:"error_count"`
}
// AccountAvailability represents current availability for a single account.
@@ -85,11 +83,10 @@ type AccountAvailability struct {
IsOverloaded bool `json:"is_overloaded"`
HasError bool `json:"has_error"`
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
RateLimitRemainingSec *int64 `json:"rate_limit_remaining_sec"`
ScopeRateLimits map[string]int64 `json:"scope_rate_limits,omitempty"`
OverloadUntil *time.Time `json:"overload_until"`
OverloadRemainingSec *int64 `json:"overload_remaining_sec"`
ErrorMessage string `json:"error_message"`
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until,omitempty"`
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
RateLimitRemainingSec *int64 `json:"rate_limit_remaining_sec"`
OverloadUntil *time.Time `json:"overload_until"`
OverloadRemainingSec *int64 `json:"overload_remaining_sec"`
ErrorMessage string `json:"error_message"`
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until,omitempty"`
}

View File

@@ -12,6 +12,7 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/gin-gonic/gin"
@@ -528,7 +529,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry
func extractRetryModelAndStream(reqType opsRetryRequestType, errorLog *OpsErrorLogDetail, body []byte) (model string, stream bool, err error) {
switch reqType {
case opsRetryTypeMessages:
parsed, parseErr := ParseGatewayRequest(body)
parsed, parseErr := ParseGatewayRequest(body, domain.PlatformAnthropic)
if parseErr != nil {
return "", false, fmt.Errorf("failed to parse messages request body: %w", parseErr)
}
@@ -596,7 +597,7 @@ func (s *OpsService) executeWithAccount(ctx context.Context, reqType opsRetryReq
if s.gatewayService == nil {
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "gateway service not available"}
}
parsedReq, parseErr := ParseGatewayRequest(body)
parsedReq, parseErr := ParseGatewayRequest(body, domain.PlatformAnthropic)
if parseErr != nil {
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "failed to parse request body"}
}

View File

@@ -62,6 +62,32 @@ func (s *RateLimitService) SetTokenCacheInvalidator(invalidator TokenCacheInvali
s.tokenCacheInvalidator = invalidator
}
// ErrorPolicyResult 表示错误策略检查的结果
type ErrorPolicyResult int
const (
ErrorPolicyNone ErrorPolicyResult = iota // 未命中任何策略,继续默认逻辑
ErrorPolicySkipped // 自定义错误码开启但未命中,跳过处理
ErrorPolicyMatched // 自定义错误码命中,应停止调度
ErrorPolicyTempUnscheduled // 临时不可调度规则命中
)
// CheckErrorPolicy 检查自定义错误码和临时不可调度规则。
// 自定义错误码开启时覆盖后续所有逻辑(包括临时不可调度)。
func (s *RateLimitService) CheckErrorPolicy(ctx context.Context, account *Account, statusCode int, responseBody []byte) ErrorPolicyResult {
if account.IsCustomErrorCodesEnabled() {
if account.ShouldHandleErrorCode(statusCode) {
return ErrorPolicyMatched
}
slog.Info("account_error_code_skipped", "account_id", account.ID, "status_code", statusCode)
return ErrorPolicySkipped
}
if s.tryTempUnschedulable(ctx, account, statusCode, responseBody) {
return ErrorPolicyTempUnscheduled
}
return ErrorPolicyNone
}
// HandleUpstreamError 处理上游错误响应,标记账号状态
// 返回是否应该停止该账号的调度
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {

View File

@@ -0,0 +1,318 @@
//go:build unit
package service
import (
"testing"
"time"
"github.com/stretchr/testify/require"
)
// ============ shuffleWithinSortGroups 测试 ============
func TestShuffleWithinSortGroups_Empty(t *testing.T) {
shuffleWithinSortGroups(nil)
shuffleWithinSortGroups([]accountWithLoad{})
}
func TestShuffleWithinSortGroups_SingleElement(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1, Priority: 1}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
}
shuffleWithinSortGroups(accounts)
require.Equal(t, int64(1), accounts[0].account.ID)
}
func TestShuffleWithinSortGroups_DifferentGroups_OrderPreserved(t *testing.T) {
now := time.Now()
earlier := now.Add(-1 * time.Hour)
accounts := []accountWithLoad{
{account: &Account{ID: 1, Priority: 1, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
{account: &Account{ID: 2, Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 20}},
{account: &Account{ID: 3, Priority: 2, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
}
// 每个元素都属于不同组Priority 或 LoadRate 或 LastUsedAt 不同),顺序不变
for i := 0; i < 20; i++ {
cpy := make([]accountWithLoad, len(accounts))
copy(cpy, accounts)
shuffleWithinSortGroups(cpy)
require.Equal(t, int64(1), cpy[0].account.ID)
require.Equal(t, int64(2), cpy[1].account.ID)
require.Equal(t, int64(3), cpy[2].account.ID)
}
}
func TestShuffleWithinSortGroups_SameGroup_Shuffled(t *testing.T) {
now := time.Now()
// 同一秒的时间戳视为同一组
sameSecond := time.Unix(now.Unix(), 0)
sameSecond2 := time.Unix(now.Unix(), 500_000_000) // 同一秒但不同纳秒
accounts := []accountWithLoad{
{account: &Account{ID: 1, Priority: 1, LastUsedAt: &sameSecond}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
{account: &Account{ID: 2, Priority: 1, LastUsedAt: &sameSecond2}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
{account: &Account{ID: 3, Priority: 1, LastUsedAt: &sameSecond}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
}
// 多次执行,验证所有 ID 都出现在第一个位置(说明确实被打乱了)
seen := map[int64]bool{}
for i := 0; i < 100; i++ {
cpy := make([]accountWithLoad, len(accounts))
copy(cpy, accounts)
shuffleWithinSortGroups(cpy)
seen[cpy[0].account.ID] = true
// 无论怎么打乱,所有 ID 都应在候选中
ids := map[int64]bool{}
for _, a := range cpy {
ids[a.account.ID] = true
}
require.True(t, ids[1] && ids[2] && ids[3])
}
// 至少 2 个不同的 ID 出现在首位(随机性验证)
require.GreaterOrEqual(t, len(seen), 2, "shuffle should produce different orderings")
}
func TestShuffleWithinSortGroups_NilLastUsedAt_SameGroup(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1, Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}},
{account: &Account{ID: 2, Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}},
{account: &Account{ID: 3, Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}},
}
seen := map[int64]bool{}
for i := 0; i < 100; i++ {
cpy := make([]accountWithLoad, len(accounts))
copy(cpy, accounts)
shuffleWithinSortGroups(cpy)
seen[cpy[0].account.ID] = true
}
require.GreaterOrEqual(t, len(seen), 2, "nil LastUsedAt accounts should be shuffled")
}
func TestShuffleWithinSortGroups_MixedGroups(t *testing.T) {
now := time.Now()
earlier := now.Add(-1 * time.Hour)
sameAsNow := time.Unix(now.Unix(), 0)
// 组1: Priority=1, LoadRate=10, LastUsedAt=earlier (ID 1) — 单元素组
// 组2: Priority=1, LoadRate=20, LastUsedAt=now (ID 2, 3) — 双元素组
// 组3: Priority=2, LoadRate=10, LastUsedAt=earlier (ID 4) — 单元素组
accounts := []accountWithLoad{
{account: &Account{ID: 1, Priority: 1, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
{account: &Account{ID: 2, Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 20}},
{account: &Account{ID: 3, Priority: 1, LastUsedAt: &sameAsNow}, loadInfo: &AccountLoadInfo{LoadRate: 20}},
{account: &Account{ID: 4, Priority: 2, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
}
for i := 0; i < 20; i++ {
cpy := make([]accountWithLoad, len(accounts))
copy(cpy, accounts)
shuffleWithinSortGroups(cpy)
// 组间顺序不变
require.Equal(t, int64(1), cpy[0].account.ID, "group 1 position fixed")
require.Equal(t, int64(4), cpy[3].account.ID, "group 3 position fixed")
// 组2 内部可以打乱,但仍在位置 1 和 2
mid := map[int64]bool{cpy[1].account.ID: true, cpy[2].account.ID: true}
require.True(t, mid[2] && mid[3], "group 2 elements should stay in positions 1-2")
}
}
// ============ shuffleWithinPriorityAndLastUsed 测试 ============
func TestShuffleWithinPriorityAndLastUsed_Empty(t *testing.T) {
shuffleWithinPriorityAndLastUsed(nil)
shuffleWithinPriorityAndLastUsed([]*Account{})
}
func TestShuffleWithinPriorityAndLastUsed_SingleElement(t *testing.T) {
accounts := []*Account{{ID: 1, Priority: 1}}
shuffleWithinPriorityAndLastUsed(accounts)
require.Equal(t, int64(1), accounts[0].ID)
}
func TestShuffleWithinPriorityAndLastUsed_SameGroup_Shuffled(t *testing.T) {
accounts := []*Account{
{ID: 1, Priority: 1, LastUsedAt: nil},
{ID: 2, Priority: 1, LastUsedAt: nil},
{ID: 3, Priority: 1, LastUsedAt: nil},
}
seen := map[int64]bool{}
for i := 0; i < 100; i++ {
cpy := make([]*Account, len(accounts))
copy(cpy, accounts)
shuffleWithinPriorityAndLastUsed(cpy)
seen[cpy[0].ID] = true
}
require.GreaterOrEqual(t, len(seen), 2, "same group should be shuffled")
}
func TestShuffleWithinPriorityAndLastUsed_DifferentPriority_OrderPreserved(t *testing.T) {
accounts := []*Account{
{ID: 1, Priority: 1, LastUsedAt: nil},
{ID: 2, Priority: 2, LastUsedAt: nil},
{ID: 3, Priority: 3, LastUsedAt: nil},
}
for i := 0; i < 20; i++ {
cpy := make([]*Account, len(accounts))
copy(cpy, accounts)
shuffleWithinPriorityAndLastUsed(cpy)
require.Equal(t, int64(1), cpy[0].ID)
require.Equal(t, int64(2), cpy[1].ID)
require.Equal(t, int64(3), cpy[2].ID)
}
}
func TestShuffleWithinPriorityAndLastUsed_DifferentLastUsedAt_OrderPreserved(t *testing.T) {
now := time.Now()
earlier := now.Add(-1 * time.Hour)
accounts := []*Account{
{ID: 1, Priority: 1, LastUsedAt: nil},
{ID: 2, Priority: 1, LastUsedAt: &earlier},
{ID: 3, Priority: 1, LastUsedAt: &now},
}
for i := 0; i < 20; i++ {
cpy := make([]*Account, len(accounts))
copy(cpy, accounts)
shuffleWithinPriorityAndLastUsed(cpy)
require.Equal(t, int64(1), cpy[0].ID)
require.Equal(t, int64(2), cpy[1].ID)
require.Equal(t, int64(3), cpy[2].ID)
}
}
// ============ sameLastUsedAt 测试 ============
func TestSameLastUsedAt(t *testing.T) {
now := time.Now()
sameSecond := time.Unix(now.Unix(), 0)
sameSecondDiffNano := time.Unix(now.Unix(), 999_999_999)
differentSecond := now.Add(1 * time.Second)
t.Run("both nil", func(t *testing.T) {
require.True(t, sameLastUsedAt(nil, nil))
})
t.Run("one nil one not", func(t *testing.T) {
require.False(t, sameLastUsedAt(nil, &now))
require.False(t, sameLastUsedAt(&now, nil))
})
t.Run("same second different nanoseconds", func(t *testing.T) {
require.True(t, sameLastUsedAt(&sameSecond, &sameSecondDiffNano))
})
t.Run("different seconds", func(t *testing.T) {
require.False(t, sameLastUsedAt(&now, &differentSecond))
})
t.Run("exact same time", func(t *testing.T) {
require.True(t, sameLastUsedAt(&now, &now))
})
}
// ============ sameAccountWithLoadGroup 测试 ============
func TestSameAccountWithLoadGroup(t *testing.T) {
now := time.Now()
sameSecond := time.Unix(now.Unix(), 0)
t.Run("same group", func(t *testing.T) {
a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
b := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &sameSecond}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
require.True(t, sameAccountWithLoadGroup(a, b))
})
t.Run("different priority", func(t *testing.T) {
a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
b := accountWithLoad{account: &Account{Priority: 2, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
require.False(t, sameAccountWithLoadGroup(a, b))
})
t.Run("different load rate", func(t *testing.T) {
a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
b := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 20}}
require.False(t, sameAccountWithLoadGroup(a, b))
})
t.Run("different last used at", func(t *testing.T) {
later := now.Add(1 * time.Second)
a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
b := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &later}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
require.False(t, sameAccountWithLoadGroup(a, b))
})
t.Run("both nil LastUsedAt", func(t *testing.T) {
a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}}
b := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}}
require.True(t, sameAccountWithLoadGroup(a, b))
})
}
// ============ sameAccountGroup 测试 ============
func TestSameAccountGroup(t *testing.T) {
now := time.Now()
t.Run("same group", func(t *testing.T) {
a := &Account{Priority: 1, LastUsedAt: nil}
b := &Account{Priority: 1, LastUsedAt: nil}
require.True(t, sameAccountGroup(a, b))
})
t.Run("different priority", func(t *testing.T) {
a := &Account{Priority: 1, LastUsedAt: nil}
b := &Account{Priority: 2, LastUsedAt: nil}
require.False(t, sameAccountGroup(a, b))
})
t.Run("different LastUsedAt", func(t *testing.T) {
later := now.Add(1 * time.Second)
a := &Account{Priority: 1, LastUsedAt: &now}
b := &Account{Priority: 1, LastUsedAt: &later}
require.False(t, sameAccountGroup(a, b))
})
}
// ============ sortAccountsByPriorityAndLastUsed 集成随机化测试 ============
func TestSortAccountsByPriorityAndLastUsed_WithShuffle(t *testing.T) {
t.Run("same priority and nil LastUsedAt are shuffled", func(t *testing.T) {
accounts := []*Account{
{ID: 1, Priority: 1, LastUsedAt: nil},
{ID: 2, Priority: 1, LastUsedAt: nil},
{ID: 3, Priority: 1, LastUsedAt: nil},
}
seen := map[int64]bool{}
for i := 0; i < 100; i++ {
cpy := make([]*Account, len(accounts))
copy(cpy, accounts)
sortAccountsByPriorityAndLastUsed(cpy, false)
seen[cpy[0].ID] = true
}
require.GreaterOrEqual(t, len(seen), 2, "identical sort keys should produce different orderings after shuffle")
})
t.Run("different priorities still sorted correctly", func(t *testing.T) {
now := time.Now()
accounts := []*Account{
{ID: 3, Priority: 3, LastUsedAt: &now},
{ID: 1, Priority: 1, LastUsedAt: &now},
{ID: 2, Priority: 2, LastUsedAt: &now},
}
sortAccountsByPriorityAndLastUsed(accounts, false)
require.Equal(t, int64(1), accounts[0].ID)
require.Equal(t, int64(2), accounts[1].ID)
require.Equal(t, int64(3), accounts[2].ID)
})
}

View File

@@ -275,4 +275,5 @@ var ProviderSet = wire.NewSet(
NewUsageCache,
NewTotpService,
NewErrorPassthroughService,
NewDigestSessionStore,
)

View File

@@ -47,13 +47,15 @@ services:
# =======================================================================
# Database Configuration (PostgreSQL)
# Default: uses local postgres container
# External DB: set DATABASE_HOST and DATABASE_SSLMODE in .env
# =======================================================================
- DATABASE_HOST=postgres
- DATABASE_PORT=5432
- DATABASE_HOST=${DATABASE_HOST:-postgres}
- DATABASE_PORT=${DATABASE_PORT:-5432}
- DATABASE_USER=${POSTGRES_USER:-sub2api}
- DATABASE_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
- DATABASE_DBNAME=${POSTGRES_DB:-sub2api}
- DATABASE_SSLMODE=disable
- DATABASE_SSLMODE=${DATABASE_SSLMODE:-disable}
# =======================================================================
# Redis Configuration
@@ -128,8 +130,6 @@ services:
# Examples: http://host:port, socks5://host:port
- UPDATE_PROXY_URL=${UPDATE_PROXY_URL:-}
depends_on:
postgres:
condition: service_healthy
redis:
condition: service_healthy
networks:
@@ -141,35 +141,6 @@ services:
retries: 3
start_period: 30s
# ===========================================================================
# PostgreSQL Database
# ===========================================================================
postgres:
image: postgres:18-alpine
container_name: sub2api-postgres
restart: unless-stopped
ulimits:
nofile:
soft: 100000
hard: 100000
volumes:
- postgres_data:/var/lib/postgresql/data
environment:
- POSTGRES_USER=${POSTGRES_USER:-sub2api}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
- POSTGRES_DB=${POSTGRES_DB:-sub2api}
- TZ=${TZ:-Asia/Shanghai}
networks:
- sub2api-network
healthcheck:
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-sub2api} -d ${POSTGRES_DB:-sub2api}"]
interval: 10s
timeout: 5s
retries: 5
start_period: 10s
# 注意:不暴露端口到宿主机,应用通过内部网络连接
# 如需调试可临时添加ports: ["127.0.0.1:5433:5432"]
# ===========================================================================
# Redis Cache
# ===========================================================================
@@ -209,8 +180,6 @@ services:
volumes:
sub2api_data:
driver: local
postgres_data:
driver: local
redis_data:
driver: local

Binary file not shown.

After

Width:  |  Height:  |  Size: 148 KiB

View File

@@ -376,7 +376,6 @@ export interface PlatformAvailability {
total_accounts: number
available_count: number
rate_limit_count: number
scope_rate_limit_count?: Record<string, number>
error_count: number
}
@@ -387,7 +386,6 @@ export interface GroupAvailability {
total_accounts: number
available_count: number
rate_limit_count: number
scope_rate_limit_count?: Record<string, number>
error_count: number
}
@@ -402,7 +400,6 @@ export interface AccountAvailability {
is_rate_limited: boolean
rate_limit_reset_at?: string
rate_limit_remaining_sec?: number
scope_rate_limits?: Record<string, number>
is_overloaded: boolean
overload_until?: string
overload_remaining_sec?: number

View File

@@ -76,26 +76,6 @@
</div>
</div>
<!-- Scope Rate Limit Indicators (Antigravity) -->
<template v-if="activeScopeRateLimits.length > 0">
<div v-for="item in activeScopeRateLimits" :key="item.scope" class="group relative">
<span
class="inline-flex items-center gap-1 rounded bg-orange-100 px-1.5 py-0.5 text-xs font-medium text-orange-700 dark:bg-orange-900/30 dark:text-orange-400"
>
<Icon name="exclamationTriangle" size="xs" :stroke-width="2" />
{{ formatScopeName(item.scope) }}
</span>
<!-- Tooltip -->
<div
class="pointer-events-none absolute bottom-full left-1/2 z-50 mb-2 -translate-x-1/2 whitespace-nowrap rounded bg-gray-900 px-2 py-1 text-xs text-white opacity-0 transition-opacity group-hover:opacity-100 dark:bg-gray-700"
>
{{ t('admin.accounts.status.scopeRateLimitedUntil', { scope: formatScopeName(item.scope), time: formatTime(item.reset_at) }) }}
<div
class="absolute left-1/2 top-full -translate-x-1/2 border-4 border-transparent border-t-gray-900 dark:border-t-gray-700" ></div>
</div>
</div>
</template>
<!-- Model Rate Limit Indicators (Antigravity OAuth Smart Retry) -->
<template v-if="activeModelRateLimits.length > 0">
<div v-for="item in activeModelRateLimits" :key="item.model" class="group relative">
@@ -160,16 +140,6 @@ const isRateLimited = computed(() => {
return new Date(props.account.rate_limit_reset_at) > new Date()
})
// Computed: active scope rate limits (Antigravity)
const activeScopeRateLimits = computed(() => {
const scopeLimits = props.account.scope_rate_limits
if (!scopeLimits) return []
const now = new Date()
return Object.entries(scopeLimits)
.filter(([, info]) => new Date(info.reset_at) > now)
.map(([scope, info]) => ({ scope, reset_at: info.reset_at }))
})
// Computed: active model rate limits (Antigravity OAuth Smart Retry)
const activeModelRateLimits = computed(() => {
const modelLimits = (props.account.extra as Record<string, unknown> | undefined)?.model_rate_limits as

View File

@@ -1038,10 +1038,7 @@
</div>
<!-- Custom Error Codes Section -->
<div
v-if="form.platform !== 'gemini'"
class="border-t border-gray-200 pt-4 dark:border-dark-600"
>
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
<div class="mb-3 flex items-center justify-between">
<div>
<label class="input-label mb-0">{{ t('admin.accounts.customErrorCodes') }}</label>

View File

@@ -0,0 +1,104 @@
<template>
<!-- 悬浮按钮 - 使用主题色 -->
<button
@click="showModal = true"
class="fixed bottom-6 right-6 z-50 flex items-center gap-2 rounded-full bg-gradient-to-r from-primary-500 to-primary-600 px-4 py-3 text-white shadow-lg shadow-primary-500/25 transition-all hover:from-primary-600 hover:to-primary-700 hover:shadow-xl hover:shadow-primary-500/30"
>
<svg class="h-5 w-5" viewBox="0 0 24 24" fill="currentColor">
<path d="M8.691 2.188C3.891 2.188 0 5.476 0 9.53c0 2.212 1.17 4.203 3.002 5.55a.59.59 0 01.213.665l-.39 1.48c-.019.07-.048.141-.048.213 0 .163.13.295.29.295a.328.328 0 00.186-.059l2.114-1.225a.87.87 0 01.415-.106.807.807 0 01.213.026 10.07 10.07 0 002.696.37c.262 0 .52-.011.776-.028a5.91 5.91 0 01-.193-1.479c0-3.644 3.374-6.6 7.536-6.6.262 0 .52.011.776.028-.628-3.513-4.27-6.472-8.885-6.472zM5.785 5.97a1.1 1.1 0 110 2.2 1.1 1.1 0 010-2.2zm5.813 0a1.1 1.1 0 110 2.2 1.1 1.1 0 010-2.2zm5.192 2.642c-3.703 0-6.71 2.567-6.71 5.73 0 3.163 3.007 5.73 6.71 5.73a7.9 7.9 0 002.126-.288.644.644 0 01.17-.022.69.69 0 01.329.085l1.672.97a.262.262 0 00.147.046c.128 0 .23-.104.23-.233a.403.403 0 00-.038-.168l-.309-1.17a.468.468 0 01.168-.527c1.449-1.065 2.374-2.643 2.374-4.423 0-3.163-3.007-5.73-6.71-5.73h-.159zm-2.434 3.34a.88.88 0 110 1.76.88.88 0 010-1.76zm4.868 0a.88.88 0 110 1.76.88.88 0 010-1.76z"/>
</svg>
<span class="text-sm font-medium">客服</span>
</button>
<!-- 弹窗 -->
<Teleport to="body">
<Transition name="fade">
<div
v-if="showModal"
class="fixed inset-0 z-[100] flex items-center justify-center bg-black/50 p-4 backdrop-blur-sm"
@click.self="showModal = false"
>
<Transition name="scale">
<div
v-if="showModal"
class="relative w-full max-w-sm rounded-2xl bg-white p-6 shadow-2xl dark:bg-dark-700"
>
<!-- 关闭按钮 -->
<button
@click="showModal = false"
class="absolute right-4 top-4 text-gray-400 transition-colors hover:text-gray-600 dark:text-dark-400 dark:hover:text-dark-200"
>
<svg class="h-5 w-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M6 18L18 6M6 6l12 12" />
</svg>
</button>
<!-- 标题 -->
<div class="mb-4 flex items-center gap-3">
<div class="flex h-10 w-10 items-center justify-center rounded-full bg-gradient-to-br from-primary-500 to-primary-600">
<svg class="h-6 w-6 text-white" viewBox="0 0 24 24" fill="currentColor">
<path d="M8.691 2.188C3.891 2.188 0 5.476 0 9.53c0 2.212 1.17 4.203 3.002 5.55a.59.59 0 01.213.665l-.39 1.48c-.019.07-.048.141-.048.213 0 .163.13.295.29.295a.328.328 0 00.186-.059l2.114-1.225a.87.87 0 01.415-.106.807.807 0 01.213.026 10.07 10.07 0 002.696.37c.262 0 .52-.011.776-.028a5.91 5.91 0 01-.193-1.479c0-3.644 3.374-6.6 7.536-6.6.262 0 .52.011.776.028-.628-3.513-4.27-6.472-8.885-6.472zM5.785 5.97a1.1 1.1 0 110 2.2 1.1 1.1 0 010-2.2zm5.813 0a1.1 1.1 0 110 2.2 1.1 1.1 0 010-2.2zm5.192 2.642c-3.703 0-6.71 2.567-6.71 5.73 0 3.163 3.007 5.73 6.71 5.73a7.9 7.9 0 002.126-.288.644.644 0 01.17-.022.69.69 0 01.329.085l1.672.97a.262.262 0 00.147.046c.128 0 .23-.104.23-.233a.403.403 0 00-.038-.168l-.309-1.17a.468.468 0 01.168-.527c1.449-1.065 2.374-2.643 2.374-4.423 0-3.163-3.007-5.73-6.71-5.73h-.159zm-2.434 3.34a.88.88 0 110 1.76.88.88 0 010-1.76zm4.868 0a.88.88 0 110 1.76.88.88 0 010-1.76z"/>
</svg>
</div>
<div>
<h3 class="text-lg font-semibold text-gray-900 dark:text-white">联系客服</h3>
<p class="text-sm text-gray-500 dark:text-dark-400">扫码添加好友</p>
</div>
</div>
<!-- 二维码卡片 -->
<div class="mb-4 overflow-hidden rounded-xl border border-primary-100 bg-gradient-to-br from-primary-50 to-white p-3 dark:border-primary-800/30 dark:from-primary-900/10 dark:to-dark-800">
<img
src="/wechat-qr.jpg"
alt="微信二维码"
class="w-full rounded-lg"
/>
</div>
<!-- 提示文字 -->
<div class="text-center">
<p class="mb-2 text-sm font-medium text-primary-600 dark:text-primary-400">
微信扫码添加客服
</p>
<p class="flex items-center justify-center gap-1 text-xs text-gray-500 dark:text-dark-400">
<svg class="h-4 w-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 8v4l3 3m6-3a9 9 0 11-18 0 9 9 0 0118 0z" />
</svg>
工作时间周一至周五 9:00-18:00
</p>
</div>
</div>
</Transition>
</div>
</Transition>
</Teleport>
</template>
<script setup lang="ts">
import { ref } from 'vue'
const showModal = ref(false)
</script>
<style scoped>
.fade-enter-active,
.fade-leave-active {
transition: opacity 0.2s ease;
}
.fade-enter-from,
.fade-leave-to {
opacity: 0;
}
.scale-enter-active,
.scale-leave-active {
transition: all 0.2s ease;
}
.scale-enter-from,
.scale-leave-to {
opacity: 0;
transform: scale(0.95);
}
</style>

View File

@@ -121,23 +121,6 @@
<Icon name="key" size="sm" />
{{ t('nav.apiKeys') }}
</router-link>
<a
href="https://github.com/Wei-Shaw/sub2api"
target="_blank"
rel="noopener noreferrer"
@click="closeDropdown"
class="dropdown-item"
>
<svg class="h-4 w-4" fill="currentColor" viewBox="0 0 24 24">
<path
fill-rule="evenodd"
clip-rule="evenodd"
d="M12 2C6.477 2 2 6.477 2 12c0 4.42 2.865 8.17 6.839 9.49.5.092.682-.217.682-.482 0-.237-.008-.866-.013-1.7-2.782.604-3.369-1.34-3.369-1.34-.454-1.156-1.11-1.464-1.11-1.464-.908-.62.069-.608.069-.608 1.003.07 1.531 1.03 1.531 1.03.892 1.529 2.341 1.087 2.91.831.092-.646.35-1.086.636-1.336-2.22-.253-4.555-1.11-4.555-4.943 0-1.091.39-1.984 1.029-2.683-.103-.253-.446-1.27.098-2.647 0 0 .84-.269 2.75 1.025A9.578 9.578 0 0112 6.836c.85.004 1.705.114 2.504.336 1.909-1.294 2.747-1.025 2.747-1.025.546 1.377.203 2.394.1 2.647.64.699 1.028 1.592 1.028 2.683 0 3.842-2.339 4.687-4.566 4.935.359.309.678.919.678 1.852 0 1.336-.012 2.415-.012 2.743 0 .267.18.578.688.48C19.138 20.167 22 16.418 22 12c0-5.523-4.477-10-10-10z"
/>
</svg>
{{ t('nav.github') }}
</a>
</div>
<!-- Contact Support (only show if configured) -->

View File

@@ -1356,7 +1356,6 @@ export default {
overloaded: 'Overloaded',
tempUnschedulable: 'Temp Unschedulable',
rateLimitedUntil: 'Rate limited until {time}',
scopeRateLimitedUntil: '{scope} rate limited until {time}',
modelRateLimitedUntil: '{model} rate limited until {time}',
overloadedUntil: 'Overloaded until {time}',
viewTempUnschedDetails: 'View temp unschedulable details'
@@ -3059,7 +3058,6 @@ export default {
empty: 'No data',
queued: 'Queue {count}',
rateLimited: 'Rate-limited {count}',
scopeRateLimitedTooltip: '{scope} rate-limited ({count} accounts)',
errorAccounts: 'Errors {count}',
loadFailed: 'Failed to load concurrency data'
},

View File

@@ -1492,7 +1492,6 @@ export default {
overloaded: '过载中',
tempUnschedulable: '临时不可调度',
rateLimitedUntil: '限流中,重置时间:{time}',
scopeRateLimitedUntil: '{scope} 限流中,重置时间:{time}',
modelRateLimitedUntil: '{model} 限流至 {time}',
overloadedUntil: '负载过重,重置时间:{time}',
viewTempUnschedDetails: '查看临时不可调度详情'
@@ -3232,7 +3231,6 @@ export default {
empty: '暂无数据',
queued: '队列 {count}',
rateLimited: '限流 {count}',
scopeRateLimitedTooltip: '{scope} 限流中 ({count} 个账号)',
errorAccounts: '异常 {count}',
loadFailed: '加载并发数据失败'
},

View File

@@ -591,9 +591,6 @@ export interface Account {
temp_unschedulable_until: string | null
temp_unschedulable_reason: string | null
// Antigravity scope 级限流状态
scope_rate_limits?: Record<string, { reset_at: string; remaining_sec: number }>
// Session window fields (5-hour window)
session_window_start: string | null
session_window_end: string | null

View File

@@ -122,8 +122,11 @@
>
{{ siteName }}
</h1>
<p class="mb-8 text-lg text-gray-600 dark:text-dark-300 md:text-xl">
{{ siteSubtitle }}
<p class="mb-3 text-xl font-semibold text-primary-600 dark:text-primary-400 md:text-2xl">
{{ t('home.heroSubtitle') }}
</p>
<p class="mb-8 text-base text-gray-600 dark:text-dark-300 md:text-lg">
{{ t('home.heroDescription') }}
</p>
<!-- CTA Button -->
@@ -177,7 +180,7 @@
</div>
<!-- Feature Tags - Centered -->
<div class="mb-12 flex flex-wrap items-center justify-center gap-4 md:gap-6">
<div class="mb-16 flex flex-wrap items-center justify-center gap-4 md:gap-6">
<div
class="inline-flex items-center gap-2.5 rounded-full border border-gray-200/50 bg-white/80 px-5 py-2.5 shadow-sm backdrop-blur-sm dark:border-dark-700/50 dark:bg-dark-800/80"
>
@@ -204,6 +207,63 @@
</div>
</div>
<!-- Pain Points Section -->
<div class="mb-16">
<h2 class="mb-8 text-center text-2xl font-bold text-gray-900 dark:text-white md:text-3xl">
{{ t('home.painPoints.title') }}
</h2>
<div class="grid gap-4 sm:grid-cols-2 lg:grid-cols-4">
<!-- Pain Point 1: Expensive -->
<div class="rounded-xl border border-red-200/50 bg-red-50/50 p-5 dark:border-red-900/30 dark:bg-red-950/20">
<div class="mb-3 flex h-10 w-10 items-center justify-center rounded-lg bg-red-100 dark:bg-red-900/30">
<svg class="h-5 w-5 text-red-500" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M12 8c-1.657 0-3 .895-3 2s1.343 2 3 2 3 .895 3 2-1.343 2-3 2m0-8c1.11 0 2.08.402 2.599 1M12 8V7m0 1v8m0 0v1m0-1c-1.11 0-2.08-.402-2.599-1M21 12a9 9 0 11-18 0 9 9 0 0118 0z" />
</svg>
</div>
<h3 class="mb-1.5 font-semibold text-gray-900 dark:text-white">{{ t('home.painPoints.items.expensive.title') }}</h3>
<p class="text-sm text-gray-600 dark:text-dark-400">{{ t('home.painPoints.items.expensive.desc') }}</p>
</div>
<!-- Pain Point 2: Complex -->
<div class="rounded-xl border border-orange-200/50 bg-orange-50/50 p-5 dark:border-orange-900/30 dark:bg-orange-950/20">
<div class="mb-3 flex h-10 w-10 items-center justify-center rounded-lg bg-orange-100 dark:bg-orange-900/30">
<svg class="h-5 w-5 text-orange-500" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M19 11H5m14 0a2 2 0 012 2v6a2 2 0 01-2 2H5a2 2 0 01-2-2v-6a2 2 0 012-2m14 0V9a2 2 0 00-2-2M5 11V9a2 2 0 012-2m0 0V5a2 2 0 012-2h6a2 2 0 012 2v2M7 7h10" />
</svg>
</div>
<h3 class="mb-1.5 font-semibold text-gray-900 dark:text-white">{{ t('home.painPoints.items.complex.title') }}</h3>
<p class="text-sm text-gray-600 dark:text-dark-400">{{ t('home.painPoints.items.complex.desc') }}</p>
</div>
<!-- Pain Point 3: Unstable -->
<div class="rounded-xl border border-yellow-200/50 bg-yellow-50/50 p-5 dark:border-yellow-900/30 dark:bg-yellow-950/20">
<div class="mb-3 flex h-10 w-10 items-center justify-center rounded-lg bg-yellow-100 dark:bg-yellow-900/30">
<svg class="h-5 w-5 text-yellow-600" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z" />
</svg>
</div>
<h3 class="mb-1.5 font-semibold text-gray-900 dark:text-white">{{ t('home.painPoints.items.unstable.title') }}</h3>
<p class="text-sm text-gray-600 dark:text-dark-400">{{ t('home.painPoints.items.unstable.desc') }}</p>
</div>
<!-- Pain Point 4: No Control -->
<div class="rounded-xl border border-gray-200/50 bg-gray-50/50 p-5 dark:border-dark-700/50 dark:bg-dark-800/50">
<div class="mb-3 flex h-10 w-10 items-center justify-center rounded-lg bg-gray-100 dark:bg-dark-700">
<svg class="h-5 w-5 text-gray-500" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M18.364 18.364A9 9 0 005.636 5.636m12.728 12.728A9 9 0 015.636 5.636m12.728 12.728L5.636 5.636" />
</svg>
</div>
<h3 class="mb-1.5 font-semibold text-gray-900 dark:text-white">{{ t('home.painPoints.items.noControl.title') }}</h3>
<p class="text-sm text-gray-600 dark:text-dark-400">{{ t('home.painPoints.items.noControl.desc') }}</p>
</div>
</div>
</div>
<!-- Solutions Section Title -->
<div class="mb-8 text-center">
<h2 class="mb-2 text-2xl font-bold text-gray-900 dark:text-white md:text-3xl">
{{ t('home.solutions.title') }}
</h2>
<p class="text-gray-600 dark:text-dark-400">{{ t('home.solutions.subtitle') }}</p>
</div>
<!-- Features Grid -->
<div class="mb-12 grid gap-6 md:grid-cols-3">
<!-- Feature 1: Unified Gateway -->
@@ -369,6 +429,77 @@
>
</div>
</div>
<!-- Comparison Table -->
<div class="mb-16">
<h2 class="mb-8 text-center text-2xl font-bold text-gray-900 dark:text-white md:text-3xl">
{{ t('home.comparison.title') }}
</h2>
<div class="overflow-x-auto">
<table class="w-full rounded-xl border border-gray-200/50 bg-white/60 backdrop-blur-sm dark:border-dark-700/50 dark:bg-dark-800/60">
<thead>
<tr class="border-b border-gray-200/50 dark:border-dark-700/50">
<th class="px-6 py-4 text-left text-sm font-semibold text-gray-900 dark:text-white">{{ t('home.comparison.headers.feature') }}</th>
<th class="px-6 py-4 text-center text-sm font-semibold text-gray-500 dark:text-dark-400">{{ t('home.comparison.headers.official') }}</th>
<th class="px-6 py-4 text-center text-sm font-semibold text-primary-600 dark:text-primary-400">{{ t('home.comparison.headers.us') }}</th>
</tr>
</thead>
<tbody class="divide-y divide-gray-200/50 dark:divide-dark-700/50">
<tr>
<td class="px-6 py-4 text-sm font-medium text-gray-900 dark:text-white">{{ t('home.comparison.items.pricing.feature') }}</td>
<td class="px-6 py-4 text-center text-sm text-gray-500 dark:text-dark-400">{{ t('home.comparison.items.pricing.official') }}</td>
<td class="px-6 py-4 text-center text-sm font-medium text-primary-600 dark:text-primary-400">{{ t('home.comparison.items.pricing.us') }}</td>
</tr>
<tr>
<td class="px-6 py-4 text-sm font-medium text-gray-900 dark:text-white">{{ t('home.comparison.items.models.feature') }}</td>
<td class="px-6 py-4 text-center text-sm text-gray-500 dark:text-dark-400">{{ t('home.comparison.items.models.official') }}</td>
<td class="px-6 py-4 text-center text-sm font-medium text-primary-600 dark:text-primary-400">{{ t('home.comparison.items.models.us') }}</td>
</tr>
<tr>
<td class="px-6 py-4 text-sm font-medium text-gray-900 dark:text-white">{{ t('home.comparison.items.management.feature') }}</td>
<td class="px-6 py-4 text-center text-sm text-gray-500 dark:text-dark-400">{{ t('home.comparison.items.management.official') }}</td>
<td class="px-6 py-4 text-center text-sm font-medium text-primary-600 dark:text-primary-400">{{ t('home.comparison.items.management.us') }}</td>
</tr>
<tr>
<td class="px-6 py-4 text-sm font-medium text-gray-900 dark:text-white">{{ t('home.comparison.items.stability.feature') }}</td>
<td class="px-6 py-4 text-center text-sm text-gray-500 dark:text-dark-400">{{ t('home.comparison.items.stability.official') }}</td>
<td class="px-6 py-4 text-center text-sm font-medium text-primary-600 dark:text-primary-400">{{ t('home.comparison.items.stability.us') }}</td>
</tr>
<tr>
<td class="px-6 py-4 text-sm font-medium text-gray-900 dark:text-white">{{ t('home.comparison.items.control.feature') }}</td>
<td class="px-6 py-4 text-center text-sm text-gray-500 dark:text-dark-400">{{ t('home.comparison.items.control.official') }}</td>
<td class="px-6 py-4 text-center text-sm font-medium text-primary-600 dark:text-primary-400">{{ t('home.comparison.items.control.us') }}</td>
</tr>
</tbody>
</table>
</div>
</div>
<!-- CTA Section -->
<div class="mb-8 rounded-2xl bg-gradient-to-r from-primary-500 to-primary-600 p-8 text-center shadow-xl shadow-primary-500/20 md:p-12">
<h2 class="mb-3 text-2xl font-bold text-white md:text-3xl">
{{ t('home.cta.title') }}
</h2>
<p class="mb-6 text-primary-100">
{{ t('home.cta.description') }}
</p>
<router-link
v-if="!isAuthenticated"
to="/register"
class="inline-flex items-center gap-2 rounded-full bg-white px-8 py-3 font-semibold text-primary-600 shadow-lg transition-all hover:bg-gray-50 hover:shadow-xl"
>
{{ t('home.cta.button') }}
<Icon name="arrowRight" size="md" :stroke-width="2" />
</router-link>
<router-link
v-else
:to="dashboardPath"
class="inline-flex items-center gap-2 rounded-full bg-white px-8 py-3 font-semibold text-primary-600 shadow-lg transition-all hover:bg-gray-50 hover:shadow-xl"
>
{{ t('home.goToDashboard') }}
<Icon name="arrowRight" size="md" :stroke-width="2" />
</router-link>
</div>
</div>
</main>
@@ -380,27 +511,20 @@
<p class="text-sm text-gray-500 dark:text-dark-400">
&copy; {{ currentYear }} {{ siteName }}. {{ t('home.footer.allRightsReserved') }}
</p>
<div class="flex items-center gap-4">
<a
v-if="docUrl"
:href="docUrl"
target="_blank"
rel="noopener noreferrer"
class="text-sm text-gray-500 transition-colors hover:text-gray-700 dark:text-dark-400 dark:hover:text-white"
>
{{ t('home.docs') }}
</a>
<a
:href="githubUrl"
target="_blank"
rel="noopener noreferrer"
class="text-sm text-gray-500 transition-colors hover:text-gray-700 dark:text-dark-400 dark:hover:text-white"
>
GitHub
</a>
</div>
<a
v-if="docUrl"
:href="docUrl"
target="_blank"
rel="noopener noreferrer"
class="text-sm text-gray-500 transition-colors hover:text-gray-700 dark:text-dark-400 dark:hover:text-white"
>
{{ t('home.docs') }}
</a>
</div>
</footer>
<!-- 微信客服悬浮按钮 -->
<WechatServiceButton />
</div>
</template>
@@ -410,6 +534,7 @@ import { useI18n } from 'vue-i18n'
import { useAuthStore, useAppStore } from '@/stores'
import LocaleSwitcher from '@/components/common/LocaleSwitcher.vue'
import Icon from '@/components/icons/Icon.vue'
import WechatServiceButton from '@/components/common/WechatServiceButton.vue'
const { t } = useI18n()
@@ -419,7 +544,6 @@ const appStore = useAppStore()
// Site settings - directly from appStore (already initialized from injected config)
const siteName = computed(() => appStore.cachedPublicSettings?.site_name || appStore.siteName || 'Sub2API')
const siteLogo = computed(() => appStore.cachedPublicSettings?.site_logo || appStore.siteLogo || '')
const siteSubtitle = computed(() => appStore.cachedPublicSettings?.site_subtitle || 'AI API Gateway Platform')
const docUrl = computed(() => appStore.cachedPublicSettings?.doc_url || appStore.docUrl || '')
const homeContent = computed(() => appStore.cachedPublicSettings?.home_content || '')
@@ -432,9 +556,6 @@ const isHomeContentUrl = computed(() => {
// Theme
const isDark = ref(document.documentElement.classList.contains('dark'))
// GitHub URL
const githubUrl = 'https://github.com/Wei-Shaw/sub2api'
// Auth state
const isAuthenticated = computed(() => authStore.isAuthenticated)
const isAdmin = computed(() => authStore.isAdmin)

View File

@@ -56,7 +56,6 @@ interface SummaryRow {
total_accounts: number
available_accounts: number
rate_limited_accounts: number
scope_rate_limit_count?: Record<string, number>
error_accounts: number
// 并发统计
total_concurrency: number
@@ -122,7 +121,6 @@ const platformRows = computed((): SummaryRow[] => {
total_accounts: totalAccounts,
available_accounts: availableAccounts,
rate_limited_accounts: safeNumber(avail.rate_limit_count),
scope_rate_limit_count: avail.scope_rate_limit_count,
error_accounts: safeNumber(avail.error_count),
total_concurrency: totalConcurrency,
used_concurrency: usedConcurrency,
@@ -162,7 +160,6 @@ const groupRows = computed((): SummaryRow[] => {
total_accounts: totalAccounts,
available_accounts: availableAccounts,
rate_limited_accounts: safeNumber(avail.rate_limit_count),
scope_rate_limit_count: avail.scope_rate_limit_count,
error_accounts: safeNumber(avail.error_count),
total_concurrency: totalConcurrency,
used_concurrency: usedConcurrency,
@@ -329,15 +326,6 @@ function formatDuration(seconds: number): string {
return `${hours}h`
}
function formatScopeName(scope: string): string {
const names: Record<string, string> = {
claude: 'Claude',
gemini_text: 'Gemini',
gemini_image: 'Image'
}
return names[scope] || scope
}
watch(
() => realtimeEnabled.value,
async (enabled) => {
@@ -505,18 +493,6 @@ watch(
{{ t('admin.ops.concurrency.rateLimited', { count: row.rate_limited_accounts }) }}
</span>
<!-- Scope 限流 ( Antigravity) -->
<template v-if="row.scope_rate_limit_count && Object.keys(row.scope_rate_limit_count).length > 0">
<span
v-for="(count, scope) in row.scope_rate_limit_count"
:key="scope"
class="rounded-full bg-orange-100 px-1.5 py-0.5 font-semibold text-orange-700 dark:bg-orange-900/30 dark:text-orange-400"
:title="t('admin.ops.concurrency.scopeRateLimitedTooltip', { scope, count })"
>
{{ formatScopeName(scope as string) }} {{ count }}
</span>
</template>
<!-- 异常账号 -->
<span
v-if="row.error_accounts > 0"

View File

@@ -0,0 +1,127 @@
#!/bin/bash
# Gemini 粘性会话压力测试脚本
# 测试目标:验证不同会话分配不同账号,同一会话保持同一账号
BASE_URL="http://host.clicodeplus.com:8080"
API_KEY="sk-32ad0a3197e528c840ea84f0dc6b2056dd3fead03526b5c605a60709bd408f7e"
MODEL="gemini-2.5-flash"
# 创建临时目录存放结果
RESULT_DIR="/tmp/gemini_stress_test_$(date +%s)"
mkdir -p "$RESULT_DIR"
echo "=========================================="
echo "Gemini 粘性会话压力测试"
echo "结果目录: $RESULT_DIR"
echo "=========================================="
# 函数:发送请求并记录
send_request() {
local session_id=$1
local round=$2
local system_prompt=$3
local contents=$4
local output_file="$RESULT_DIR/session_${session_id}_round_${round}.json"
local request_body=$(cat <<EOF
{
"systemInstruction": {
"parts": [{"text": "$system_prompt"}]
},
"contents": $contents
}
EOF
)
curl -s -X POST "${BASE_URL}/v1beta/models/${MODEL}:generateContent" \
-H "Content-Type: application/json" \
-H "x-goog-api-key: ${API_KEY}" \
-d "$request_body" > "$output_file" 2>&1
echo "[Session $session_id Round $round] 完成"
}
# 会话1数学计算器累加序列
run_session_1() {
local sys_prompt="你是一个数学计算器,只返回计算结果数字,不要任何解释"
# Round 1: 1+1=?
send_request 1 1 "$sys_prompt" '[{"role":"user","parts":[{"text":"1+1=?"}]}]'
# Round 2: 继续 2+2=?(累加历史)
send_request 1 2 "$sys_prompt" '[{"role":"user","parts":[{"text":"1+1=?"}]},{"role":"model","parts":[{"text":"2"}]},{"role":"user","parts":[{"text":"2+2=?"}]}]'
# Round 3: 继续 3+3=?
send_request 1 3 "$sys_prompt" '[{"role":"user","parts":[{"text":"1+1=?"}]},{"role":"model","parts":[{"text":"2"}]},{"role":"user","parts":[{"text":"2+2=?"}]},{"role":"model","parts":[{"text":"4"}]},{"role":"user","parts":[{"text":"3+3=?"}]}]'
# Round 4: 批量计算 10+10, 20+20, 30+30
send_request 1 4 "$sys_prompt" '[{"role":"user","parts":[{"text":"1+1=?"}]},{"role":"model","parts":[{"text":"2"}]},{"role":"user","parts":[{"text":"2+2=?"}]},{"role":"model","parts":[{"text":"4"}]},{"role":"user","parts":[{"text":"3+3=?"}]},{"role":"model","parts":[{"text":"6"}]},{"role":"user","parts":[{"text":"计算: 10+10=? 20+20=? 30+30=?"}]}]'
}
# 会话2英文翻译器不同系统提示词 = 不同会话)
run_session_2() {
local sys_prompt="你是一个英文翻译器,将中文翻译成英文,只返回翻译结果"
send_request 2 1 "$sys_prompt" '[{"role":"user","parts":[{"text":"你好"}]}]'
send_request 2 2 "$sys_prompt" '[{"role":"user","parts":[{"text":"你好"}]},{"role":"model","parts":[{"text":"Hello"}]},{"role":"user","parts":[{"text":"世界"}]}]'
send_request 2 3 "$sys_prompt" '[{"role":"user","parts":[{"text":"你好"}]},{"role":"model","parts":[{"text":"Hello"}]},{"role":"user","parts":[{"text":"世界"}]},{"role":"model","parts":[{"text":"World"}]},{"role":"user","parts":[{"text":"早上好"}]}]'
}
# 会话3日文翻译器
run_session_3() {
local sys_prompt="你是一个日文翻译器,将中文翻译成日文,只返回翻译结果"
send_request 3 1 "$sys_prompt" '[{"role":"user","parts":[{"text":"你好"}]}]'
send_request 3 2 "$sys_prompt" '[{"role":"user","parts":[{"text":"你好"}]},{"role":"model","parts":[{"text":"こんにちは"}]},{"role":"user","parts":[{"text":"谢谢"}]}]'
send_request 3 3 "$sys_prompt" '[{"role":"user","parts":[{"text":"你好"}]},{"role":"model","parts":[{"text":"こんにちは"}]},{"role":"user","parts":[{"text":"谢谢"}]},{"role":"model","parts":[{"text":"ありがとう"}]},{"role":"user","parts":[{"text":"再见"}]}]'
}
# 会话4乘法计算器另一个数学会话但系统提示词不同
run_session_4() {
local sys_prompt="你是一个乘法专用计算器,只计算乘法,返回数字结果"
send_request 4 1 "$sys_prompt" '[{"role":"user","parts":[{"text":"2*3=?"}]}]'
send_request 4 2 "$sys_prompt" '[{"role":"user","parts":[{"text":"2*3=?"}]},{"role":"model","parts":[{"text":"6"}]},{"role":"user","parts":[{"text":"4*5=?"}]}]'
send_request 4 3 "$sys_prompt" '[{"role":"user","parts":[{"text":"2*3=?"}]},{"role":"model","parts":[{"text":"6"}]},{"role":"user","parts":[{"text":"4*5=?"}]},{"role":"model","parts":[{"text":"20"}]},{"role":"user","parts":[{"text":"计算: 10*10=? 20*20=?"}]}]'
}
# 会话5诗人完全不同的角色
run_session_5() {
local sys_prompt="你是一位诗人,用简短的诗句回应每个话题,每次只写一句诗"
send_request 5 1 "$sys_prompt" '[{"role":"user","parts":[{"text":"春天"}]}]'
send_request 5 2 "$sys_prompt" '[{"role":"user","parts":[{"text":"春天"}]},{"role":"model","parts":[{"text":"春风拂面花满枝"}]},{"role":"user","parts":[{"text":"夏天"}]}]'
send_request 5 3 "$sys_prompt" '[{"role":"user","parts":[{"text":"春天"}]},{"role":"model","parts":[{"text":"春风拂面花满枝"}]},{"role":"user","parts":[{"text":"夏天"}]},{"role":"model","parts":[{"text":"蝉鸣蛙声伴荷香"}]},{"role":"user","parts":[{"text":"秋天"}]}]'
}
echo ""
echo "开始并发测试 5 个独立会话..."
echo ""
# 并发运行所有会话
run_session_1 &
run_session_2 &
run_session_3 &
run_session_4 &
run_session_5 &
# 等待所有后台任务完成
wait
echo ""
echo "=========================================="
echo "所有请求完成,结果保存在: $RESULT_DIR"
echo "=========================================="
# 显示结果摘要
echo ""
echo "响应摘要:"
for f in "$RESULT_DIR"/*.json; do
filename=$(basename "$f")
response=$(cat "$f" | head -c 200)
echo "[$filename]: ${response}..."
done
echo ""
echo "请检查服务器日志确认账号分配情况"