diff --git a/.github/workflows/backend-ci.yml b/.github/workflows/backend-ci.yml
index 2596a18c..84575a96 100644
--- a/.github/workflows/backend-ci.yml
+++ b/.github/workflows/backend-ci.yml
@@ -17,6 +17,7 @@ jobs:
           go-version-file: backend/go.mod
           check-latest: false
           cache: true
+          cache-dependency-path: backend/go.sum
       - name: Verify Go version
         run: |
           go version | grep -q 'go1.25.7'
@@ -36,6 +37,7 @@ jobs:
           go-version-file: backend/go.mod
           check-latest: false
           cache: true
+          cache-dependency-path: backend/go.sum
       - name: Verify Go version
         run: |
           go version | grep -q 'go1.25.7'
diff --git a/.gitignore b/.gitignore
index 48172982..2f2bfbdf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -78,6 +78,7 @@ Desktop.ini
 # ===================
 tmp/
 temp/
+logs/
 *.tmp
 *.temp
 *.log
diff --git a/AGENTS.md b/AGENTS.md
new file mode 100644
index 00000000..a7a3e34a
--- /dev/null
+++ b/AGENTS.md
@@ -0,0 +1,723 @@
+# Sub2API Development Notes
+
+## Version Management Strategy
+
+### Version Numbering
+
+We append our own patch number to the official version:
+
+- Official version: `v0.1.68`
+- Our versions: `v0.1.68.1`, `v0.1.68.2` (incrementing)
+
+### Branch Strategy
+
+| Branch | Description |
+|------|------|
+| `main` | Our main branch, containing all custom features |
+| `release/custom-X.Y.Z` | Release branch based on the official `vX.Y.Z` |
+| `upstream/main` | The upstream official repository |
+
+---
+
+## Release Process (based on a new official version)
+
+When an official release (e.g. `v0.1.69`) is published:
+
+### 1. Sync upstream and create a release branch
+
+```bash
+# Fetch the latest upstream code
+git fetch upstream --tags
+
+# Create a new release branch from the official tag
+git checkout v0.1.69 -b release/custom-0.1.69
+
+# Merge our main branch (contains all custom features)
+git merge main --no-edit
+
+# Resolve any conflicts, then continue
+```
+
+### 2. Bump the version and tag
+
+```bash
+# Update the version file
+echo "0.1.69.1" > backend/cmd/server/VERSION
+git add backend/cmd/server/VERSION
+git commit -m "chore: bump version to 0.1.69.1"
+
+# Create our own tag
+git tag v0.1.69.1
+
+# Push the branch and the tag
+git push origin release/custom-0.1.69
+git push origin v0.1.69.1
+```
+
+### 3. Update the main branch
+
+```bash
+# Merge the release branch back into main so main keeps the latest custom features
+git checkout main
+git merge release/custom-0.1.69
+git push origin main
+```
+
+---
+
+## Hotfix Release (fix on an existing version)
+
+When a fix needs to ship on the current version:
+
+```bash
+# Fix on the current release branch
+git checkout release/custom-0.1.68
+# ... make the fix ...
+git commit -m "fix: description of the fix"
+
+# Increment the patch number
+echo "0.1.68.2" > backend/cmd/server/VERSION
+git add backend/cmd/server/VERSION
+git commit -m "chore: bump version to 0.1.68.2"
+
+# Tag and push
+git tag v0.1.68.2
+git push origin release/custom-0.1.68
+git push origin v0.1.68.2
+
+# Sync the fix to main
+git checkout main
+git cherry-pick 
+git push origin main
+```
+
+---
+
+## Server Deployment Process
+
+### Prerequisites
+
+- The local SSH alias `clicodeplus` is configured to reach the server
+- Server deployment directories: `/root/sub2api` (production), `/root/sub2api-beta` (beta)
+- The server deploys via Docker Compose
+
+### Deployment Environments
+
+| Environment | Directory | Port | Database | Container |
+|------|------|------|--------|--------|
+| Production | `/root/sub2api` | 8080 | `sub2api` | `sub2api` |
+| Beta | `/root/sub2api-beta` | 8084 | `beta` | `sub2api-beta` |
+
+### External Database
+
+Production and beta **share an external PostgreSQL database** (not an in-container database), configured in the `.env` file:
+- `DATABASE_HOST`: external database host
+- `DATABASE_SSLMODE`: SSL mode (usually `require`)
+- `POSTGRES_USER` / `POSTGRES_DB`: user name and database name
+
+#### Database Commands
+
+Run database operations on the server over SSH:
+
+```bash
+# Production - query migration records
+ssh clicodeplus "source /root/sub2api/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c 'SELECT * FROM schema_migrations ORDER BY applied_at DESC LIMIT 5;'"
+
+# Beta - query migration records
+ssh clicodeplus "source /root/sub2api-beta/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c 'SELECT * FROM schema_migrations ORDER BY applied_at DESC LIMIT 5;'"
+
+# Beta - delete a specific migration record (to re-run the migration)
+ssh clicodeplus "source /root/sub2api-beta/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c \"DELETE FROM schema_migrations WHERE filename LIKE '%049%';\""
+
+# Beta - update account data
+ssh clicodeplus "source /root/sub2api-beta/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c \"UPDATE accounts SET credentials = credentials - 'model_mapping' WHERE platform = 'antigravity';\""
+```
+
+> **Note**: load environment variables via `source .env` to avoid exposing the password on the command line.
+
+### Deployment Steps
+
+**Important: bump the version number on every deployment!**
+
+#### 0. Bump the version (locally)
+
+Before each deployment, increment the patch number locally:
+
+```bash
+# Check the current version
+cat backend/cmd/server/VERSION
+# Assume it is 0.1.69.1
+
+# Increment the version
+echo "0.1.69.2" > backend/cmd/server/VERSION
+git add backend/cmd/server/VERSION
+git commit -m "chore: bump version to 0.1.69.2"
+git push origin release/custom-0.1.69
+```
+
+#### 1. Pull code on the server
+
+```bash
+ssh clicodeplus "cd /root/sub2api && git fetch fork && git checkout -B release/custom-0.1.69 fork/release/custom-0.1.69"
+```
+
+#### 2. Build the image on the server
+
+```bash
+ssh clicodeplus "cd /root/sub2api && docker build --no-cache -t sub2api:latest -f Dockerfile ."
+```
+
+#### 3. Retag the image and restart the service
+
+```bash
+ssh clicodeplus "docker tag sub2api:latest weishaw/sub2api:latest"
+ssh clicodeplus "cd /root/sub2api/deploy && docker compose up -d --force-recreate sub2api"
+```
+
+#### 4. Verify the deployment
+
+```bash
+# Check startup logs
+ssh clicodeplus "docker logs sub2api --tail 20"
+
+# Confirm the version (must match the version set in step 0)
+ssh clicodeplus "cat /root/sub2api/backend/cmd/server/VERSION"
+
+# Check container status
+ssh clicodeplus "docker ps | grep sub2api"
+```
+
+---
+
+## Parallel Beta Deployment (without touching production)
+
+Goal: run a beta instance in parallel on the same server (e.g. on port `8084`). **Never modify or restart** the production instance (default directory `/root/sub2api`).
+
+### Design Principles
+
+- **Separate directory**: beta uses its own directory, e.g. `/root/sub2api-beta`.
+- **Secrets live only in `.env`**: beta's database password, JWT_SECRET, etc. go only into `/root/sub2api-beta/deploy/.env`; never commit them to git.
+- **Separate Compose project**: start with `docker compose -p sub2api-beta ...` so networks/volumes are isolated.
+- **Separate port**: map the host port via `SERVER_PORT` in `.env` (e.g. `8084:8080`).
+
+### Pre-flight Checks
+
+```bash
+# 1) Make sure 8084 is free
+ssh clicodeplus "ss -ltnp | grep :8084 || echo '8084 is free'"
+
+# 2) Confirm the production container is still up (read-only check)
+ssh clicodeplus "docker ps --format 'table {{.Names}}\t{{.Image}}\t{{.Ports}}' | sed -n '1,200p'"
+```
+
+### First-time Deployment
+
+```bash
+# 0) Log in to the server
+ssh clicodeplus
+
+# 1) Clone the code into the new directory (example uses your fork)
+cd /root
+git clone https://github.com/touwaeriol/sub2api.git sub2api-beta
+cd /root/sub2api-beta
+git checkout release/custom-0.1.71
+
+# 2) Prepare beta's .env (secrets go only here)
+cd /root/sub2api-beta/deploy
+
+# Recommended: copy from the production .env so everything except the DB name/user/port stays identical
+cp -f /root/sub2api/deploy/.env ./.env
+
+# Change only these three entries (leave the rest unchanged)
+perl -pi -e 's/^SERVER_PORT=.*/SERVER_PORT=8084/' ./.env
+perl -pi -e 's/^POSTGRES_USER=.*/POSTGRES_USER=beta/' ./.env
+perl -pi -e 's/^POSTGRES_DB=.*/POSTGRES_DB=beta/' ./.env
+
+# 3) Write a compose override (avoid container-name clashes with production; use the locally built sub2api:beta image)
+cat > docker-compose.override.yml <<'YAML'
+services:
+  sub2api:
+    image: sub2api:beta
+    container_name: sub2api-beta
+  redis:
+    container_name: sub2api-beta-redis
+YAML
+
+# 4) Build the beta image (from the current code)
+cd /root/sub2api-beta
+docker build -t sub2api:beta -f Dockerfile .
+
+# 5) Start beta (separate project, so production is unaffected)
+cd /root/sub2api-beta/deploy
+docker compose -p sub2api-beta --env-file .env -f docker-compose.yml -f docker-compose.override.yml up -d
+
+# 6) Verify beta
+curl -fsS http://127.0.0.1:8084/health
+docker logs sub2api-beta --tail 50
+```
+
+### Database Conventions (beta)
+
+- Database host/SSL/password: same as production (copy from the production `.env`).
+- Change only:
+  - `POSTGRES_USER=beta`
+  - `POSTGRES_DB=beta`
+
+Note: the `beta` user and `beta` database must already exist on the database side with the right grants; otherwise the container will fail to start and keep restarting.
+
+### Updating Beta (pull code + rebuild only the beta container)
+
+```bash
+ssh clicodeplus "set -e; cd /root/sub2api-beta && git fetch --all --tags && git checkout -f release/custom-0.1.71 && git reset --hard origin/release/custom-0.1.71"
+ssh clicodeplus "cd /root/sub2api-beta && docker build -t sub2api:beta -f Dockerfile ."
+ssh clicodeplus "cd /root/sub2api-beta/deploy && docker compose -p sub2api-beta --env-file .env -f docker-compose.yml -f docker-compose.override.yml up -d --no-deps --force-recreate sub2api"
+ssh clicodeplus "curl -fsS http://127.0.0.1:8084/health"
+```
+
+### Stopping / Rolling Back Beta (affects beta only)
+
+```bash
+ssh clicodeplus "cd /root/sub2api-beta/deploy && docker compose -p sub2api-beta -f docker-compose.yml -f docker-compose.override.yml down"
+```
+
+---
+
+## First-time Server Deployment
+
+### 1. Clone the code and configure remotes
+
+```bash
+ssh clicodeplus
+cd /root
+git clone https://github.com/Wei-Shaw/sub2api.git
+cd sub2api
+
+# Add the fork repository
+git remote add fork https://github.com/touwaeriol/sub2api.git
+```
+
+### 2. Switch to the custom branch and configure the environment
+
+```bash
+git fetch fork
+git checkout -B release/custom-0.1.69 fork/release/custom-0.1.69
+
+cd deploy
+cp .env.example .env
+vim .env  # Configure DATABASE_URL, REDIS_URL, JWT_SECRET, etc.
+```
+
+### 3. Build and start
+
+```bash
+cd /root/sub2api
+docker build -t sub2api:latest -f Dockerfile .
+docker tag sub2api:latest weishaw/sub2api:latest
+cd deploy && docker compose up -d
+```
+
+### 4. Start the services
+
+```bash
+# Enter the deploy directory
+cd deploy
+
+# Start all services (PostgreSQL, Redis, sub2api)
+docker compose up -d
+
+# Check service status
+docker compose ps
+```
+
+### 5. Verify the deployment
+
+```bash
+# Check application logs
+docker logs sub2api --tail 50
+
+# Check health status
+curl http://localhost:8080/health
+
+# Confirm the version
+cat /root/sub2api/backend/cmd/server/VERSION
+```
+
+### 6. Common Operations Commands
+
+```bash
+# Follow logs in real time
+docker logs -f sub2api
+
+# Restart the service
+docker compose restart sub2api
+
+# Stop all services
+docker compose down
+
+# Stop and delete volumes (dangerous! deletes database data)
+docker compose down -v
+
+# Check resource usage
+docker stats sub2api
+```
+
+---
+
+## Custom Features
+
+The current custom branch includes the following features (relative to the official version):
+
+### UI/UX Customization
+
+| Feature | Description |
+|------|------|
+| Home page redesign | User-facing value-proposition design |
+| GitHub link removed | The user menu no longer shows GitHub navigation |
+| WeChat support button | Floating WeChat customer-service entry on the home page |
+| Precise rate-limit display | Account rate-limit times shown with second precision |
+
+### Antigravity Platform Enhancements
+
+| Feature | Description |
+|------|------|
+| Scope-level rate limiting | Independent limits per quota scope (claude/gemini_text/gemini_image), so the whole account is not locked out |
+| Model-level rate limiting | Independent limits per model (e.g. claude-opus-4-5) for finer-grained control |
+| Rate-limit pre-check | Check account/model rate-limit state during scheduling to avoid selecting already-limited accounts |
+| Second-level cooldowns | Second-precision cooldowns for 429 responses |
+| Identity injection hardening | Model identity injection + silent boundary to prevent identity leaks |
+| thoughtSignature fix | Fix for Gemini 3 function-call 400 errors |
+| max_tokens auto-correction | Automatically fix 400 errors caused by max_tokens <= budget_tokens |
+
+### Scheduling Optimizations
+
+| Feature | Description |
+|------|------|
+| Tiered filtering selection | Scheduler switched from full sorting to tiered filtering for better performance |
+| Randomized LRU selection | Pick randomly among equal LRU timestamps to avoid piling onto one account |
+| Configurable rate-limit wait threshold | The rate-limit wait threshold is configurable |
+
+### Operations Enhancements
+
+| Feature | Description |
+|------|------|
+| Scope rate-limit statistics | Ops UI shows scope-level rate-limit statistics for Antigravity accounts |
+| Account rate-limit status | Account list shows scope- and model-level rate-limit status |
+| Clear-rate-limit button | The clear button is also shown when scope/model rate limits are active |
+
+### Other Fixes
+
+| Feature | Description |
+|------|------|
+| .gitattributes | Ensure migration files use LF line endings (fixes inconsistent SQL digests on Windows) |
+| Deployment config | DATABASE_HOST and DATABASE_SSLMODE are configurable via .env |
+
+---
+
+## Notes
+
+1. **The frontend must be baked into the image**: build on the server with `docker build`; the Dockerfile compiles the frontend and embeds it into the backend binary
+
+2. **Image tag**: docker-compose.yml uses `weishaw/sub2api:latest`; after a local build, override it with `docker tag`
+
+3. **Windows line endings**: handled via `.gitattributes`, which keeps `*.sql` files on LF
+
+4. **Version management**: every release must update `backend/cmd/server/VERSION` and be tagged
+
+5. **Merge conflicts**: when merging a new upstream version, watch these files for conflicts:
+   - `backend/internal/service/antigravity_gateway_service.go`
+   - `backend/internal/service/gateway_service.go`
+   - `backend/internal/pkg/antigravity/request_transformer.go`
+
+---
+
+## Go Coding Standards
+
+### 1. Function Design
+
+#### Single Responsibility
+- **Function length**: a function should normally stay under **30 lines**; split longer ones into helpers. If a piece of logic truly cannot be split (a complex state machine, protocol parsing, etc.), it may be an exception, but add a comment explaining why
+- **Nesting depth**: avoid more than 3 levels of nesting; use early returns to flatten
+
+```go
+// ❌ Avoid: deep nesting
+func process(data []Item) {
+	for _, item := range data {
+		if item.Valid {
+			if item.Type == "A" {
+				if item.Status == "active" {
+					// business logic...
+				}
+			}
+		}
+	}
+}
+
+// ✅ Prefer: early return
+func process(data []Item) {
+	for _, item := range data {
+		if !item.Valid {
+			continue
+		}
+		if item.Type != "A" {
+			continue
+		}
+		if item.Status != "active" {
+			continue
+		}
+		// business logic...
+	}
+}
+```
+
+#### Extract Complex Logic
+Extract complex conditionals or handling logic into standalone functions:
+
+```go
+// ❌ Avoid: inlined complex logic
+if resp.StatusCode == 429 || resp.StatusCode == 503 {
+	// 80+ lines of handling...
+}
+
+// ✅ Prefer: extract into a function
+result := handleRateLimitResponse(resp, params)
+switch result.action {
+case actionRetry:
+	continue
+case actionBreak:
+	return result.resp, nil
+}
+```
+
+### 2. Eliminating Duplication
+
+#### Config Access Pattern
+Extract repeated config-fetching logic into a method:
+
+```go
+// ❌ Avoid: duplicated code
+logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
+maxBytes := 2048
+if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
+	maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
+}
+
+// ✅ Prefer: extract into a method
+func (s *Service) getLogConfig() (logBody bool, maxBytes int) {
+	maxBytes = 2048
+	if s.settingService == nil || s.settingService.cfg == nil {
+		return false, maxBytes
+	}
+	cfg := s.settingService.cfg.Gateway
+	if cfg.LogUpstreamErrorBodyMaxBytes > 0 {
+		maxBytes = cfg.LogUpstreamErrorBodyMaxBytes
+	}
+	return cfg.LogUpstreamErrorBody, maxBytes
+}
+```
+
+### 3. Constants
+
+#### No Magic Numbers
+Define every hard-coded number as a constant:
+
+```go
+// ❌ Avoid
+if retryDelay >= 10*time.Second {
+	resetAt := time.Now().Add(30 * time.Second)
+}
+
+// ✅ Prefer
+const (
+	rateLimitThreshold       = 10 * time.Second
+	defaultRateLimitDuration = 30 * time.Second
+)
+
+if retryDelay >= rateLimitThreshold {
+	resetAt := time.Now().Add(defaultRateLimitDuration)
+}
+```
+
+#### Reference Constant Names in Comments
+Refer to the constant name rather than the hard-coded value:
+
+```go
+// ❌ Avoid
+// < 10s: wait and retry
+
+// ✅ Prefer
+// < rateLimitThreshold: wait and retry
+```
+
+### 4. Error Handling
+
+#### Structured Logging
+Prefer `slog` for structured log records:
+
+```go
+// ❌ Avoid
+log.Printf("%s status=%d model_rate_limit_failed model=%s error=%v", prefix, statusCode, modelName, err)
+
+// ✅ Prefer
+slog.Error("failed to set model rate limit",
+	"prefix", prefix,
+	"status_code", statusCode,
+	"model", modelName,
+	"error", err,
+)
+```
+
+### 5. Testing
+
+#### Keep Mock Signatures in Sync
+When changing a function signature, update every mock in the tests as well:
+
+```go
+// If handleError's signature changed
+handleError func(..., groupID int64, sessionHash string) *Result
+
+// Update the mock in tests to match
+handleError: func(..., groupID int64, sessionHash string) *Result {
+	return nil
+},
+```
+
+#### Test Build Tags
+Use a consistent test build tag:
+
+```go
+//go:build unit
+
+package service
+```
+
+### 6. Parsing Durations
+
+#### Use the Standard Library
+Prefer `time.ParseDuration`, which supports every Go duration format:
+
+```go
+// ❌ Avoid: restricting the format by hand
+if !strings.HasSuffix(delay, "s") || strings.Contains(delay, "m") {
+	continue
+}
+
+// ✅ Prefer: standard library
+dur, err := time.ParseDuration(delay) // supports "0.5s", "4m50s", "1h30m", etc.
+```
+
+### 7. Interface Design
+
+#### Interface Segregation
+Define minimal interfaces containing only the required methods:
+
+```go
+// ❌ Avoid: an overly broad interface
+type AccountRepository interface {
+	// 20+ methods...
+}
+
+// ✅ Prefer: a minimal interface
+type ModelRateLimiter interface {
+	SetModelRateLimit(ctx context.Context, id int64, modelKey string, resetAt time.Time) error
+}
+```
+
+### 8. Concurrency Safety
+
+#### Protect Shared Data
+When reading data that may be modified concurrently, make the access thread-safe:
+
+```go
+// If Account.Extra may be modified concurrently,
+// protect reads with a mutex or atomic operations
+func (a *Account) GetRateLimitRemainingTime(model string) time.Duration {
+	a.mu.RLock()
+	defer a.mu.RUnlock()
+	// read the Extra field...
+}
+```
+
+### 9. Naming
+
+#### Consistent Style
+- Constants use camelCase: `rateLimitThreshold`
+- Types use PascalCase: `AntigravityQuotaScope`
+- Use one name per concept: `Threshold` or `Limit`, not both
+
+```go
+// ❌ Avoid: inconsistent naming
+antigravitySmartRetryMinWait  // uses Min
+antigravityRateLimitThreshold // uses Threshold
+
+// ✅ Prefer: one style
+antigravityMinRetryWait
+antigravityRateLimitThreshold
+```
+
+### 10. Code Review Checklist
+
+Check the following before committing:
+
+- [ ] Any function over 30 lines? (logic that cannot be split is exempt, with a comment)
+- [ ] Nesting deeper than 3 levels?
+- [ ] Duplicated code that could be extracted?
+- [ ] Magic numbers?
+- [ ] Do mock signatures match the real functions?
+- [ ] Do tests cover the new logic?
+- [ ] Do logs carry enough context?
+- [ ] Is concurrency safety handled?
+
+---
+
+## CI Checks and Release Gates
+
+### GitHub Actions Checks
+
+This project has 4 CI jobs; **all must pass before any push or release**:
+
+| Workflow | Job | Description | Local verification |
+|----------|-----|------|-------------|
+| CI | `test` | Unit + integration tests | `cd backend && make test-unit && make test-integration` |
+| CI | `golangci-lint` | Go static analysis (golangci-lint v2.7) | `cd backend && golangci-lint run --timeout=5m` |
+| Security Scan | `backend-security` | govulncheck + gosec security scan | `cd backend && govulncheck ./... && gosec -severity high -confidence high ./...` |
+| Security Scan | `frontend-security` | pnpm audit of frontend dependencies | `cd frontend && pnpm audit --prod --audit-level=high` |
+
+### Submitting PRs Upstream
+
+PRs target the upstream official repository and **contain only general-purpose changes** (bug fixes, new features, performance work, etc.).
+
+**The following files must never appear in a PR** (they are customizations in our fork):
+- `CLAUDE.md`, `AGENTS.md` — our development docs
+- `backend/cmd/server/VERSION` — our version file
+- UI customizations (GitHub link removal, WeChat support button, home-page changes, etc.)
+- Deployment config (custom changes under `deploy/`)
+
+**PR workflow**:
+1. Create a feature branch from `develop` containing only the changes intended for upstream
+2. After pushing, **wait for all 4 CI jobs to pass**
+3. Only then open the PR
+4. Check status with `gh run list --repo touwaeriol/sub2api --branch `
+
+### Pushing Our Own Branches (develop / main)
+
+Pushes to our own `develop` or `main` branches include everything (customizations + general changes).
+
+**Push workflow**:
+1. Run `cd backend && make test-unit` locally to make sure unit tests pass
+2. Run `cd backend && gofmt -l .` locally to make sure formatting is clean
+3. After pushing, confirm all 4 jobs across the two workflows (CI and Security Scan) are green ✅
+4. Fix any failing job immediately; **never continue while CI is red**
+
+### Releasing
+
+1. Make sure all 4 CI jobs pass on the latest `main` commit
+2. Bump `backend/cmd/server/VERSION`, commit, and push
+3. After pushing the tag, confirm the 3 tag-triggered workflows (CI, Security Scan, Release) all pass
+4. **Never deploy if the Release workflow fails** — fix the issue, delete the old tag, and re-tag
+5. Check status with `gh run list --repo touwaeriol/sub2api --limit 10`
+
+### Common CI Failures and Fixes
+- **gofmt**: inconsistent struct field alignment → run `gofmt -w ` to fix
+- **golangci-lint**: unused variables/imports → remove them or ignore with `_`
+- **test failures**: mock signatures out of sync → update the mocks
+- **gosec**: security findings → fix per the hints or add an exception
diff --git a/CLAUDE.md b/CLAUDE.md
new file mode 100644
index 00000000..a7a3e34a
--- /dev/null
+++ b/CLAUDE.md
@@ -0,0 +1,723 @@
+# Sub2API Development Notes
+
+## Version Management Strategy
+
+### Version Numbering
+
+We append our own patch number to the official version:
+
+- Official version: `v0.1.68`
+- Our versions: `v0.1.68.1`, `v0.1.68.2` (incrementing)
+
+### Branch Strategy
+
+| Branch | Description |
+|------|------|
+| `main` | Our main branch, containing all custom features |
+| `release/custom-X.Y.Z` | Release branch based on the official `vX.Y.Z` |
+| `upstream/main` | The upstream official repository |
+
+---
+
+## Release Process (based on a new official version)
+
+When an official release (e.g. `v0.1.69`) is published:
+
+### 1. Sync upstream and create a release branch
+
+```bash
+# Fetch the latest upstream code
+git fetch upstream --tags
+
+# Create a new release branch from the official tag
+git checkout v0.1.69 -b release/custom-0.1.69
+
+# Merge our main branch (contains all custom features)
+git merge main --no-edit
+
+# Resolve any conflicts, then continue
+```
+
+### 2. Bump the version and tag
+
+```bash
+# Update the version file
+echo "0.1.69.1" > backend/cmd/server/VERSION
+git add backend/cmd/server/VERSION
+git commit -m "chore: bump version to 0.1.69.1"
+
+# Create our own tag
+git tag v0.1.69.1
+
+# Push the branch and the tag
+git push origin release/custom-0.1.69
+git push origin v0.1.69.1
+```
+
+### 3. Update the main branch
+
+```bash
+# Merge the release branch back into main so main keeps the latest custom features
+git checkout main
+git merge release/custom-0.1.69
+git push origin main
+```
+
+---
+
+## Hotfix Release (fix on an existing version)
+
+When a fix needs to ship on the current version:
+
+```bash
+# Fix on the current release branch
+git checkout release/custom-0.1.68
+# ... make the fix ...
+git commit -m "fix: description of the fix"
+
+# Increment the patch number
+echo "0.1.68.2" > backend/cmd/server/VERSION
+git add backend/cmd/server/VERSION
+git commit -m "chore: bump version to 0.1.68.2"
+
+# Tag and push
+git tag v0.1.68.2
+git push origin release/custom-0.1.68
+git push origin v0.1.68.2
+
+# Sync the fix to main
+git checkout main
+git cherry-pick 
+git push origin main
+```
+
+---
+
+## Server Deployment Process
+
+### Prerequisites
+
+- The local SSH alias `clicodeplus` is configured to reach the server
+- Server deployment directories: `/root/sub2api` (production), `/root/sub2api-beta` (beta)
+- The server deploys via Docker Compose
+
+### Deployment Environments
+
+| Environment | Directory | Port | Database | Container |
+|------|------|------|--------|--------|
+| Production | `/root/sub2api` | 8080 | `sub2api` | `sub2api` |
+| Beta | `/root/sub2api-beta` | 8084 | `beta` | `sub2api-beta` |
+
+### External Database
+
+Production and beta **share an external PostgreSQL database** (not an in-container database), configured in the `.env` file:
+- `DATABASE_HOST`: external database host
+- `DATABASE_SSLMODE`: SSL mode (usually `require`)
+- `POSTGRES_USER` / `POSTGRES_DB`: user name and database name
+
+#### Database Commands
+
+Run database operations on the server over SSH:
+
+```bash
+# Production - query migration records
+ssh clicodeplus "source /root/sub2api/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c 'SELECT * FROM schema_migrations ORDER BY applied_at DESC LIMIT 5;'"
+
+# Beta - query migration records
+ssh clicodeplus "source /root/sub2api-beta/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c 'SELECT * FROM schema_migrations ORDER BY applied_at DESC LIMIT 5;'"
+
+# Beta - delete a specific migration record (to re-run the migration)
+ssh clicodeplus "source /root/sub2api-beta/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c \"DELETE FROM schema_migrations WHERE filename LIKE '%049%';\""
+
+# Beta - update account data
+ssh clicodeplus "source /root/sub2api-beta/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c \"UPDATE accounts SET credentials = credentials - 'model_mapping' WHERE platform = 'antigravity';\""
+```
+
+> **Note**: load environment variables via `source .env` to avoid exposing the password on the command line.
+
+### Deployment Steps
+
+**Important: bump the version number on every deployment!**
+
+#### 0. Bump the version (locally)
+
+Before each deployment, increment the patch number locally:
+
+```bash
+# Check the current version
+cat backend/cmd/server/VERSION
+# Assume it is 0.1.69.1
+
+# Increment the version
+echo "0.1.69.2" > backend/cmd/server/VERSION
+git add backend/cmd/server/VERSION
+git commit -m "chore: bump version to 0.1.69.2"
+git push origin release/custom-0.1.69
+```
+
+#### 1. Pull code on the server
+
+```bash
+ssh clicodeplus "cd /root/sub2api && git fetch fork && git checkout -B release/custom-0.1.69 fork/release/custom-0.1.69"
+```
+
+#### 2. Build the image on the server
+
+```bash
+ssh clicodeplus "cd /root/sub2api && docker build --no-cache -t sub2api:latest -f Dockerfile ."
+```
+
+#### 3. Retag the image and restart the service
+
+```bash
+ssh clicodeplus "docker tag sub2api:latest weishaw/sub2api:latest"
+ssh clicodeplus "cd /root/sub2api/deploy && docker compose up -d --force-recreate sub2api"
+```
+
+#### 4. Verify the deployment
+
+```bash
+# Check startup logs
+ssh clicodeplus "docker logs sub2api --tail 20"
+
+# Confirm the version (must match the version set in step 0)
+ssh clicodeplus "cat /root/sub2api/backend/cmd/server/VERSION"
+
+# Check container status
+ssh clicodeplus "docker ps | grep sub2api"
+```
+
+---
+
+## Parallel Beta Deployment (without touching production)
+
+Goal: run a beta instance in parallel on the same server (e.g. on port `8084`). **Never modify or restart** the production instance (default directory `/root/sub2api`).
+
+### Design Principles
+
+- **Separate directory**: beta uses its own directory, e.g. `/root/sub2api-beta`.
+- **Secrets live only in `.env`**: beta's database password, JWT_SECRET, etc. go only into `/root/sub2api-beta/deploy/.env`; never commit them to git.
+- **Separate Compose project**: start with `docker compose -p sub2api-beta ...` so networks/volumes are isolated.
+- **Separate port**: map the host port via `SERVER_PORT` in `.env` (e.g. `8084:8080`).
+
+### Pre-flight Checks
+
+```bash
+# 1) Make sure 8084 is free
+ssh clicodeplus "ss -ltnp | grep :8084 || echo '8084 is free'"
+
+# 2) Confirm the production container is still up (read-only check)
+ssh clicodeplus "docker ps --format 'table {{.Names}}\t{{.Image}}\t{{.Ports}}' | sed -n '1,200p'"
+```
+
+### First-time Deployment
+
+```bash
+# 0) Log in to the server
+ssh clicodeplus
+
+# 1) Clone the code into the new directory (example uses your fork)
+cd /root
+git clone https://github.com/touwaeriol/sub2api.git sub2api-beta
+cd /root/sub2api-beta
+git checkout release/custom-0.1.71
+
+# 2) Prepare beta's .env (secrets go only here)
+cd /root/sub2api-beta/deploy
+
+# Recommended: copy from the production .env so everything except the DB name/user/port stays identical
+cp -f /root/sub2api/deploy/.env ./.env
+
+# Change only these three entries (leave the rest unchanged)
+perl -pi -e 's/^SERVER_PORT=.*/SERVER_PORT=8084/' ./.env
+perl -pi -e 's/^POSTGRES_USER=.*/POSTGRES_USER=beta/' ./.env
+perl -pi -e 's/^POSTGRES_DB=.*/POSTGRES_DB=beta/' ./.env
+
+# 3) Write a compose override (avoid container-name clashes with production; use the locally built sub2api:beta image)
+cat > docker-compose.override.yml <<'YAML'
+services:
+  sub2api:
+    image: sub2api:beta
+    container_name: sub2api-beta
+  redis:
+    container_name: sub2api-beta-redis
+YAML
+
+# 4) Build the beta image (from the current code)
+cd /root/sub2api-beta
+docker build -t sub2api:beta -f Dockerfile .
+ +# 5) 启动 beta(独立 project,确保不影响现网) +cd /root/sub2api-beta/deploy +docker compose -p sub2api-beta --env-file .env -f docker-compose.yml -f docker-compose.override.yml up -d + +# 6) 验证 beta +curl -fsS http://127.0.0.1:8084/health +docker logs sub2api-beta --tail 50 +``` + +### 数据库配置约定(beta) + +- 数据库地址/SSL/密码:与现网一致(从现网 `.env` 复制即可)。 +- 仅修改: + - `POSTGRES_USER=beta` + - `POSTGRES_DB=beta` + +注意:需要数据库侧已存在 `beta` 用户与 `beta` 数据库,并授予权限;否则容器会启动失败并不断重启。 + +### 更新 beta(拉代码 + 仅重建 beta 容器) + +```bash +ssh clicodeplus "set -e; cd /root/sub2api-beta && git fetch --all --tags && git checkout -f release/custom-0.1.71 && git reset --hard origin/release/custom-0.1.71" +ssh clicodeplus "cd /root/sub2api-beta && docker build -t sub2api:beta -f Dockerfile ." +ssh clicodeplus "cd /root/sub2api-beta/deploy && docker compose -p sub2api-beta --env-file .env -f docker-compose.yml -f docker-compose.override.yml up -d --no-deps --force-recreate sub2api" +ssh clicodeplus "curl -fsS http://127.0.0.1:8084/health" +``` + +### 停止/回滚 beta(只影响 beta) + +```bash +ssh clicodeplus "cd /root/sub2api-beta/deploy && docker compose -p sub2api-beta -f docker-compose.yml -f docker-compose.override.yml down" +``` + +--- + +## 服务器首次部署 + +### 1. 克隆代码并配置远程仓库 + +```bash +ssh clicodeplus +cd /root +git clone https://github.com/Wei-Shaw/sub2api.git +cd sub2api + +# 添加 fork 仓库 +git remote add fork https://github.com/touwaeriol/sub2api.git +``` + +### 2. 切换到定制分支并配置环境 + +```bash +git fetch fork +git checkout -B release/custom-0.1.69 fork/release/custom-0.1.69 + +cd deploy +cp .env.example .env +vim .env # 配置 DATABASE_URL, REDIS_URL, JWT_SECRET 等 +``` + +### 3. 构建并启动 + +```bash +cd /root/sub2api +docker build -t sub2api:latest -f Dockerfile . +docker tag sub2api:latest weishaw/sub2api:latest +cd deploy && docker compose up -d +``` + +### 6. 启动服务 + +```bash +# 进入 deploy 目录 +cd deploy + +# 启动所有服务(PostgreSQL、Redis、sub2api) +docker compose up -d + +# 查看服务状态 +docker compose ps +``` + +### 7. 
验证部署 + +```bash +# 查看应用日志 +docker logs sub2api --tail 50 + +# 检查健康状态 +curl http://localhost:8080/health + +# 确认版本号 +cat /root/sub2api/backend/cmd/server/VERSION +``` + +### 8. 常用运维命令 + +```bash +# 查看实时日志 +docker logs -f sub2api + +# 重启服务 +docker compose restart sub2api + +# 停止所有服务 +docker compose down + +# 停止并删除数据卷(慎用!会删除数据库数据) +docker compose down -v + +# 查看资源使用情况 +docker stats sub2api +``` + +--- + +## 定制功能说明 + +当前定制分支包含以下功能(相对于官方版本): + +### UI/UX 定制 + +| 功能 | 说明 | +|------|------| +| 首页优化 | 面向用户的价值主张设计 | +| 移除 GitHub 链接 | 用户菜单中不显示 GitHub 导航 | +| 微信客服按钮 | 首页悬浮微信客服入口 | +| 限流时间精确显示 | 账号限流时间显示精确到秒 | + +### Antigravity 平台增强 + +| 功能 | 说明 | +|------|------| +| Scope 级别限流 | 按配额域(claude/gemini_text/gemini_image)独立限流,避免整个账号被锁定 | +| 模型级别限流 | 按具体模型(如 claude-opus-4-5)独立限流,更精细的限流控制 | +| 限流预检查 | 调度时预检查账号/模型限流状态,避免选中已限流账号 | +| 秒级冷却时间 | 支持 429 响应的秒级精确冷却时间 | +| 身份注入优化 | 模型身份信息注入 + 静默边界防止身份泄露 | +| thoughtSignature 修复 | Gemini 3 函数调用 400 错误修复 | +| max_tokens 自动修正 | 自动修正 max_tokens <= budget_tokens 导致的 400 错误 | + +### 调度算法优化 + +| 功能 | 说明 | +|------|------| +| 分层过滤选择 | 调度算法从全排序改为分层过滤,提升性能 | +| LRU 随机选择 | 相同 LRU 时间时随机选择,避免账号集中 | +| 限流等待阈值配置化 | 可配置的限流等待阈值 | + +### 运维增强 + +| 功能 | 说明 | +|------|------| +| Scope 限流统计 | 运维界面展示 Antigravity 账号 scope 级别限流统计 | +| 账号限流状态显示 | 账号列表显示 scope 和模型级别限流状态 | +| 清除限流按钮增强 | 有 scope/模型限流时也显示清除限流按钮 | + +### 其他修复 + +| 功能 | 说明 | +|------|------| +| .gitattributes | 确保迁移文件使用 LF 换行符(解决 Windows 下 SQL 摘要不一致) | +| 部署配置优化 | DATABASE_HOST 和 DATABASE_SSLMODE 可通过 .env 配置 | + +--- + +## 注意事项 + +1. **前端必须打包进镜像**:使用 `docker build` 在服务器上构建,Dockerfile 会自动编译前端并 embed 到后端二进制中 + +2. **镜像标签**:docker-compose.yml 使用 `weishaw/sub2api:latest`,本地构建后需要 `docker tag` 覆盖 + +3. **Windows 换行符问题**:已通过 `.gitattributes` 解决,确保 `*.sql` 文件始终使用 LF + +4. **版本号管理**:每次发布必须更新 `backend/cmd/server/VERSION` 并打标签 + +5. 
**合并冲突**:合并上游新版本时,重点关注以下文件可能的冲突: + - `backend/internal/service/antigravity_gateway_service.go` + - `backend/internal/service/gateway_service.go` + - `backend/internal/pkg/antigravity/request_transformer.go` + +--- + +## Go 代码规范 + +### 1. 函数设计 + +#### 单一职责原则 +- **函数行数**:单个函数常规不应超过 **30 行**,超过时应拆分为子函数。若某段逻辑确实不可拆分(如复杂的状态机、协议解析等),可以例外,但需添加注释说明原因 +- **嵌套层级**:避免超过 3 层嵌套,使用 early return 减少嵌套 + +```go +// ❌ 不推荐:深层嵌套 +func process(data []Item) { + for _, item := range data { + if item.Valid { + if item.Type == "A" { + if item.Status == "active" { + // 业务逻辑... + } + } + } + } +} + +// ✅ 推荐:early return +func process(data []Item) { + for _, item := range data { + if !item.Valid { + continue + } + if item.Type != "A" { + continue + } + if item.Status != "active" { + continue + } + // 业务逻辑... + } +} +``` + +#### 复杂逻辑提取 +将复杂的条件判断或处理逻辑提取为独立函数: + +```go +// ❌ 不推荐:内联复杂逻辑 +if resp.StatusCode == 429 || resp.StatusCode == 503 { + // 80+ 行处理逻辑... +} + +// ✅ 推荐:提取为独立函数 +result := handleRateLimitResponse(resp, params) +switch result.action { +case actionRetry: + continue +case actionBreak: + return result.resp, nil +} +``` + +### 2. 重复代码消除 + +#### 配置获取模式 +将重复的配置获取逻辑提取为方法: + +```go +// ❌ 不推荐:重复代码 +logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody +maxBytes := 2048 +if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 { + maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes +} + +// ✅ 推荐:提取为方法 +func (s *Service) getLogConfig() (logBody bool, maxBytes int) { + maxBytes = 2048 + if s.settingService == nil || s.settingService.cfg == nil { + return false, maxBytes + } + cfg := s.settingService.cfg.Gateway + if cfg.LogUpstreamErrorBodyMaxBytes > 0 { + maxBytes = cfg.LogUpstreamErrorBodyMaxBytes + } + return cfg.LogUpstreamErrorBody, maxBytes +} +``` + +### 3. 
常量管理 + +#### 避免魔法数字 +所有硬编码的数值都应定义为常量: + +```go +// ❌ 不推荐 +if retryDelay >= 10*time.Second { + resetAt := time.Now().Add(30 * time.Second) +} + +// ✅ 推荐 +const ( + rateLimitThreshold = 10 * time.Second + defaultRateLimitDuration = 30 * time.Second +) + +if retryDelay >= rateLimitThreshold { + resetAt := time.Now().Add(defaultRateLimitDuration) +} +``` + +#### 注释引用常量名 +在注释中引用常量名而非硬编码值: + +```go +// ❌ 不推荐 +// < 10s: 等待后重试 + +// ✅ 推荐 +// < rateLimitThreshold: 等待后重试 +``` + +### 4. 错误处理 + +#### 使用结构化日志 +优先使用 `slog` 进行结构化日志记录: + +```go +// ❌ 不推荐 +log.Printf("%s status=%d model_rate_limit_failed model=%s error=%v", prefix, statusCode, modelName, err) + +// ✅ 推荐 +slog.Error("failed to set model rate limit", + "prefix", prefix, + "status_code", statusCode, + "model", modelName, + "error", err, +) +``` + +### 5. 测试规范 + +#### Mock 函数签名同步 +修改函数签名时,必须同步更新所有测试中的 mock 函数: + +```go +// 如果修改了 handleError 签名 +handleError func(..., groupID int64, sessionHash string) *Result + +// 必须同步更新测试中的 mock +handleError: func(..., groupID int64, sessionHash string) *Result { + return nil +}, +``` + +#### 测试构建标签 +统一使用测试构建标签: + +```go +//go:build unit + +package service +``` + +### 6. 时间格式解析 + +#### 使用标准库 +优先使用 `time.ParseDuration`,支持所有 Go duration 格式: + +```go +// ❌ 不推荐:手动限制格式 +if !strings.HasSuffix(delay, "s") || strings.Contains(delay, "m") { + continue +} + +// ✅ 推荐:使用标准库 +dur, err := time.ParseDuration(delay) // 支持 "0.5s", "4m50s", "1h30m" 等 +``` + +### 7. 接口设计 + +#### 接口隔离原则 +定义最小化接口,只包含必需的方法: + +```go +// ❌ 不推荐:使用过于宽泛的接口 +type AccountRepository interface { + // 20+ 个方法... +} + +// ✅ 推荐:定义最小化接口 +type ModelRateLimiter interface { + SetModelRateLimit(ctx context.Context, id int64, modelKey string, resetAt time.Time) error +} +``` + +### 8. 并发安全 + +#### 共享数据保护 +访问可能被并发修改的数据时,确保线程安全: + +```go +// 如果 Account.Extra 可能被并发修改 +// 需要使用互斥锁或原子操作保护读取 +func (a *Account) GetRateLimitRemainingTime(model string) time.Duration { + a.mu.RLock() + defer a.mu.RUnlock() + // 读取 Extra 字段... +} +``` + +### 9. 
命名规范 + +#### 一致的命名风格 +- 常量使用 camelCase:`rateLimitThreshold` +- 类型使用 PascalCase:`AntigravityQuotaScope` +- 同一概念使用统一命名:`Threshold` 或 `Limit`,不要混用 + +```go +// ❌ 不推荐:命名不一致 +antigravitySmartRetryMinWait // 使用 Min +antigravityRateLimitThreshold // 使用 Threshold + +// ✅ 推荐:统一风格 +antigravityMinRetryWait +antigravityRateLimitThreshold +``` + +### 10. 代码审查清单 + +在提交代码前,检查以下项目: + +- [ ] 函数是否超过 30 行?(不可拆分的逻辑除外,需注释说明) +- [ ] 嵌套是否超过 3 层? +- [ ] 是否有重复代码可以提取? +- [ ] 是否使用了魔法数字? +- [ ] Mock 函数签名是否与实际函数一致? +- [ ] 测试是否覆盖了新增逻辑? +- [ ] 日志是否包含足够的上下文信息? +- [ ] 是否考虑了并发安全? + +--- + +## CI 检查与发布门禁 + +### GitHub Actions 检查项 + +本项目有 4 个 CI 任务,**任何代码推送或发布前都必须全部通过**: + +| Workflow | Job | 说明 | 本地验证命令 | +|----------|-----|------|-------------| +| CI | `test` | 单元测试 + 集成测试 | `cd backend && make test-unit && make test-integration` | +| CI | `golangci-lint` | Go 代码静态检查(golangci-lint v2.7) | `cd backend && golangci-lint run --timeout=5m` | +| Security Scan | `backend-security` | govulncheck + gosec 安全扫描 | `cd backend && govulncheck ./... && gosec -severity high -confidence high ./...` | +| Security Scan | `frontend-security` | pnpm audit 前端依赖安全检查 | `cd frontend && pnpm audit --prod --audit-level=high` | + +### 向上游提交 PR + +PR 目标是上游官方仓库,**只包含通用功能改动**(bug fix、新功能、性能优化等)。 + +**以下文件禁止出现在 PR 中**(属于我们 fork 的定制化内容): +- `CLAUDE.md`、`AGENTS.md` — 我们的开发文档 +- `backend/cmd/server/VERSION` — 我们的版本号文件 +- UI 定制改动(GitHub 链接移除、微信客服按钮、首页定制等) +- 部署配置(`deploy/` 目录下的定制修改) + +**PR 流程**: +1. 从 `develop` 创建功能分支,只包含要提交给上游的改动 +2. 推送分支后,**等待 4 个 CI job 全部通过** +3. 确认通过后再创建 PR +4. 使用 `gh run list --repo touwaeriol/sub2api --branch ` 检查状态 + +### 自有分支推送(develop / main) + +推送到我们自己的 `develop` 或 `main` 分支时,包含所有改动(定制化 + 通用功能)。 + +**推送流程**: +1. 本地运行 `cd backend && make test-unit` 确保单元测试通过 +2. 本地运行 `cd backend && gofmt -l ./...` 确保格式正确 +3. 推送后确认 CI 和 Security Scan 两个 workflow 的 4 个 job 全部绿色 ✅ +4. 任何 job 失败必须立即修复,**禁止在 CI 未通过的状态下继续后续操作** + +### 发布版本 + +1. 确保 `main` 分支最新提交的 4 个 CI job 全部通过 +2. 递增 `backend/cmd/server/VERSION`,提交并推送 +3. 
Tag and push, then confirm that all 3 workflows triggered by the tag (CI, Security Scan, Release) pass
4. **Never deploy when the Release workflow fails**: fix the problem first, delete the old tag, and tag again
5. Confirm status with `gh run list --repo touwaeriol/sub2api --limit 10`

### Common CI failures and how to fix them

- **gofmt**: inconsistent struct field alignment → run `gofmt -w <file>`
- **golangci-lint**: unused variables/imports → remove them, or discard with `_`
- **test failures**: mock signature out of sync → update the mock to match
- **gosec**: security finding → fix as suggested, or add an exception

diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION
index bc88be6e..b4fec824 100644
--- a/backend/cmd/server/VERSION
+++ b/backend/cmd/server/VERSION
@@ -1 +1 @@
-0.1.74.7
+0.1.75.7
\ No newline at end of file
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index ef205dc8..5ccd797e 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -154,7 +154,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
- gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache)
+ digestSessionStore := service.NewDigestSessionStore()
+ gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, digestSessionStore)
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) diff --git a/backend/go.mod b/backend/go.mod index 6916057f..08d54b91 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -103,6 +103,7 @@ require ( github.com/ncruces/go-strftime v1.0.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.1 // indirect + github.com/patrickmn/go-cache v2.1.0+incompatible // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect diff --git a/backend/go.sum b/backend/go.sum index 171995c7..e9525a10 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -207,6 +207,8 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8 github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M= +github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= +github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pelletier/go-toml/v2 v2.2.2 
h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 71bb1ed4..c1b70878 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -2,11 +2,6 @@ package dto import "time" -type ScopeRateLimitInfo struct { - ResetAt time.Time `json:"reset_at"` - RemainingSec int64 `json:"remaining_sec"` -} - type User struct { ID int64 `json:"id"` Email string `json:"email"` @@ -126,9 +121,6 @@ type Account struct { RateLimitResetAt *time.Time `json:"rate_limit_reset_at"` OverloadUntil *time.Time `json:"overload_until"` - // Antigravity scope 级限流状态(从 extra 提取) - ScopeRateLimits map[string]ScopeRateLimitInfo `json:"scope_rate_limits,omitempty"` - TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until"` TempUnschedulableReason string `json:"temp_unschedulable_reason"` diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 255d3fab..6900fa55 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -13,6 +13,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" @@ -114,7 +115,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { setOpsRequestContext(c, "", false, body) - parsedReq, err := service.ParseGatewayRequest(body) + parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic) if err != nil { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") return @@ -203,6 +204,11 @@ func (h 
*GatewayHandler) Messages(c *gin.Context) { } // 计算粘性会话hash + parsedReq.SessionContext = &service.SessionContext{ + ClientIP: ip.GetClientIP(c), + UserAgent: c.GetHeader("User-Agent"), + APIKeyID: apiKey.ID, + } sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) // 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则使用分组平台 @@ -335,7 +341,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if errors.As(err, &failoverErr) { failedAccountIDs[account.ID] = struct{}{} lastFailoverErr = failoverErr - if failoverErr.ForceCacheBilling { + if needForceCacheBilling(hasBoundSession, failoverErr) { forceCacheBilling = true } if switchCount >= maxAccountSwitches { @@ -344,6 +350,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } switchCount++ log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) + if account.Platform == service.PlatformAntigravity { + if !sleepFailoverDelay(c.Request.Context(), switchCount) { + return + } + } continue } // 错误响应已在Forward中处理,这里只记录日志 @@ -530,7 +541,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if errors.As(err, &failoverErr) { failedAccountIDs[account.ID] = struct{}{} lastFailoverErr = failoverErr - if failoverErr.ForceCacheBilling { + if needForceCacheBilling(hasBoundSession, failoverErr) { forceCacheBilling = true } if switchCount >= maxAccountSwitches { @@ -539,6 +550,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } switchCount++ log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) + if account.Platform == service.PlatformAntigravity { + if !sleepFailoverDelay(c.Request.Context(), switchCount) { + return + } + } continue } // 错误响应已在Forward中处理,这里只记录日志 @@ -801,6 +817,27 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", 
slotType), streamStarted) } +// needForceCacheBilling 判断 failover 时是否需要强制缓存计费 +// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费 +func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool { + return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling) +} + +// sleepFailoverDelay 账号切换线性递增延时:第1次0s、第2次1s、第3次2s… +// 返回 false 表示 context 已取消。 +func sleepFailoverDelay(ctx context.Context, switchCount int) bool { + delay := time.Duration(switchCount-1) * time.Second + if delay <= 0 { + return true + } + select { + case <-ctx.Done(): + return false + case <-time.After(delay): + return true + } +} + func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) { statusCode := failoverErr.StatusCode responseBody := failoverErr.ResponseBody @@ -934,7 +971,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { setOpsRequestContext(c, "", false, body) - parsedReq, err := service.ParseGatewayRequest(body) + parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic) if err != nil { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") return @@ -962,6 +999,11 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { } // 计算粘性会话 hash + parsedReq.SessionContext = &service.SessionContext{ + ClientIP: ip.GetClientIP(c), + UserAgent: c.GetHeader("User-Agent"), + APIKeyID: apiKey.ID, + } sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) // 选择支持该模型的账号 diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 2b69be2e..d5149f22 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -14,6 +14,7 @@ import ( "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" 
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/gemini" @@ -30,13 +31,6 @@ import ( // 匹配格式: /Users/xxx/.gemini/tmp/[64位十六进制哈希] var geminiCLITmpDirRegex = regexp.MustCompile(`/\.gemini/tmp/([A-Fa-f0-9]{64})`) -func isGeminiCLIRequest(c *gin.Context, body []byte) bool { - if strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id")) != "" { - return true - } - return geminiCLITmpDirRegex.Match(body) -} - // GeminiV1BetaListModels proxies: // GET /v1beta/models func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) { @@ -239,7 +233,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { sessionHash := extractGeminiCLISessionHash(c, body) if sessionHash == "" { // Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端) - parsedReq, _ := service.ParseGatewayRequest(body) + parsedReq, _ := service.ParseGatewayRequest(body, domain.PlatformGemini) + if parsedReq != nil { + parsedReq.SessionContext = &service.SessionContext{ + ClientIP: ip.GetClientIP(c), + UserAgent: c.GetHeader("User-Agent"), + APIKeyID: apiKey.ID, + } + } sessionHash = h.gatewayService.GenerateSessionHash(parsedReq) } sessionKey := sessionHash @@ -258,6 +259,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { var geminiDigestChain string var geminiPrefixHash string var geminiSessionUUID string + var matchedDigestChain string useDigestFallback := sessionBoundAccountID == 0 if useDigestFallback { @@ -284,13 +286,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ) // 查找会话 - foundUUID, foundAccountID, found := h.gatewayService.FindGeminiSession( + foundUUID, foundAccountID, foundMatchedChain, found := h.gatewayService.FindGeminiSession( c.Request.Context(), derefGroupID(apiKey.GroupID), geminiPrefixHash, geminiDigestChain, ) if found { + matchedDigestChain = foundMatchedChain sessionBoundAccountID = foundAccountID geminiSessionUUID = foundUUID log.Printf("[Gemini] Digest fallback matched: uuid=%s, accountID=%d, chain=%s", @@ 
-316,7 +319,6 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号 hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0 - isCLI := isGeminiCLIRequest(c, body) cleanedForUnknownBinding := false maxAccountSwitches := h.maxAccountSwitchesGemini @@ -344,10 +346,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID) body = service.CleanGeminiNativeThoughtSignatures(body) sessionBoundAccountID = account.ID - } else if sessionKey != "" && sessionBoundAccountID == 0 && isCLI && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) { - // 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,CLI 继续携带旧签名。 + } else if sessionKey != "" && sessionBoundAccountID == 0 && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) { + // 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,客户端继续携带旧签名。 // 为避免第一次转发就 400,这里做一次确定性清理,让新账号重新生成签名链路。 - log.Printf("[Gemini] Sticky session binding missing for CLI request, cleaning thoughtSignature proactively") + log.Printf("[Gemini] Sticky session binding missing, cleaning thoughtSignature proactively") body = service.CleanGeminiNativeThoughtSignatures(body) cleanedForUnknownBinding = true sessionBoundAccountID = account.ID @@ -422,7 +424,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { failedAccountIDs[account.ID] = struct{}{} - if failoverErr.ForceCacheBilling { + if needForceCacheBilling(hasBoundSession, failoverErr) { forceCacheBilling = true } if switchCount >= maxAccountSwitches { @@ -433,6 +435,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { lastFailoverErr = failoverErr switchCount++ log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, 
failoverErr.StatusCode, switchCount, maxAccountSwitches) + if account.Platform == service.PlatformAntigravity { + if !sleepFailoverDelay(c.Request.Context(), switchCount) { + return + } + } continue } // ForwardNative already wrote the response @@ -453,6 +460,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { geminiDigestChain, geminiSessionUUID, account.ID, + matchedDigestChain, ); err != nil { log.Printf("[Gemini] Failed to save digest session: %v", err) } diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 11c206d8..7fb7d4ed 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -798,53 +798,6 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA return nil } -func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error { - now := time.Now().UTC() - payload := map[string]string{ - "rate_limited_at": now.Format(time.RFC3339), - "rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339), - } - raw, err := json.Marshal(payload) - if err != nil { - return err - } - - scopeKey := string(scope) - client := clientFromContext(ctx, r.client) - result, err := client.ExecContext( - ctx, - `UPDATE accounts SET - extra = jsonb_set( - jsonb_set(COALESCE(extra, '{}'::jsonb), '{antigravity_quota_scopes}'::text[], COALESCE(extra->'antigravity_quota_scopes', '{}'::jsonb), true), - ARRAY['antigravity_quota_scopes', $1]::text[], - $2::jsonb, - true - ), - updated_at = NOW(), - last_used_at = NOW() - WHERE id = $3 AND deleted_at IS NULL`, - scopeKey, - raw, - id, - ) - if err != nil { - return err - } - - affected, err := result.RowsAffected() - if err != nil { - return err - } - if affected == 0 { - return service.ErrAccountNotFound - } - - if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, 
nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue quota scope failed: account=%d err=%v", id, err) - } - return nil -} - func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error { if scope == "" { return nil diff --git a/backend/internal/repository/gateway_cache.go b/backend/internal/repository/gateway_cache.go index 46ae0c16..58291b66 100644 --- a/backend/internal/repository/gateway_cache.go +++ b/backend/internal/repository/gateway_cache.go @@ -11,63 +11,6 @@ import ( const stickySessionPrefix = "sticky_session:" -// Gemini Trie Lua 脚本 -const ( - // geminiTrieFindScript 查找最长前缀匹配的 Lua 脚本 - // KEYS[1] = trie key - // ARGV[1] = digestChain (如 "u:a-m:b-u:c-m:d") - // ARGV[2] = TTL seconds (用于刷新) - // 返回: 最长匹配的 value (uuid:accountID) 或 nil - // 查找成功时自动刷新 TTL,防止活跃会话意外过期 - geminiTrieFindScript = ` -local chain = ARGV[1] -local ttl = tonumber(ARGV[2]) -local lastMatch = nil -local path = "" - -for part in string.gmatch(chain, "[^-]+") do - path = path == "" and part or path .. "-" .. part - local val = redis.call('HGET', KEYS[1], path) - if val and val ~= "" then - lastMatch = val - end -end - -if lastMatch then - redis.call('EXPIRE', KEYS[1], ttl) -end - -return lastMatch -` - - // geminiTrieSaveScript 保存会话到 Trie 的 Lua 脚本 - // KEYS[1] = trie key - // ARGV[1] = digestChain - // ARGV[2] = value (uuid:accountID) - // ARGV[3] = TTL seconds - geminiTrieSaveScript = ` -local chain = ARGV[1] -local value = ARGV[2] -local ttl = tonumber(ARGV[3]) -local path = "" - -for part in string.gmatch(chain, "[^-]+") do - path = path == "" and part or path .. "-" .. 
part -end -redis.call('HSET', KEYS[1], path, value) -redis.call('EXPIRE', KEYS[1], ttl) -return "OK" -` -) - -// 模型负载统计相关常量 -const ( - modelLoadKeyPrefix = "ag:model_load:" // 模型调用次数 key 前缀 - modelLastUsedKeyPrefix = "ag:model_last_used:" // 模型最后调度时间 key 前缀 - modelLoadTTL = 24 * time.Hour // 调用次数 TTL(24 小时无调用后清零) - modelLastUsedTTL = 24 * time.Hour // 最后调度时间 TTL -) - type gatewayCache struct { rdb *redis.Client } @@ -108,171 +51,3 @@ func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64 key := buildSessionKey(groupID, sessionHash) return c.rdb.Del(ctx, key).Err() } - -// ============ Antigravity 模型负载统计方法 ============ - -// modelLoadKey 构建模型调用次数 key -// 格式: ag:model_load:{accountID}:{model} -func modelLoadKey(accountID int64, model string) string { - return fmt.Sprintf("%s%d:%s", modelLoadKeyPrefix, accountID, model) -} - -// modelLastUsedKey 构建模型最后调度时间 key -// 格式: ag:model_last_used:{accountID}:{model} -func modelLastUsedKey(accountID int64, model string) string { - return fmt.Sprintf("%s%d:%s", modelLastUsedKeyPrefix, accountID, model) -} - -// IncrModelCallCount 增加模型调用次数并更新最后调度时间 -// 返回更新后的调用次数 -func (c *gatewayCache) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) { - loadKey := modelLoadKey(accountID, model) - lastUsedKey := modelLastUsedKey(accountID, model) - - pipe := c.rdb.Pipeline() - incrCmd := pipe.Incr(ctx, loadKey) - pipe.Expire(ctx, loadKey, modelLoadTTL) // 每次调用刷新 TTL - pipe.Set(ctx, lastUsedKey, time.Now().Unix(), modelLastUsedTTL) - if _, err := pipe.Exec(ctx); err != nil { - return 0, err - } - return incrCmd.Val(), nil -} - -// GetModelLoadBatch 批量获取账号的模型负载信息 -func (c *gatewayCache) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*service.ModelLoadInfo, error) { - if len(accountIDs) == 0 { - return make(map[int64]*service.ModelLoadInfo), nil - } - - loadCmds, lastUsedCmds := c.pipelineModelLoadGet(ctx, accountIDs, model) - return 
c.parseModelLoadResults(accountIDs, loadCmds, lastUsedCmds), nil -} - -// pipelineModelLoadGet 批量获取模型负载的 Pipeline 操作 -func (c *gatewayCache) pipelineModelLoadGet( - ctx context.Context, - accountIDs []int64, - model string, -) (map[int64]*redis.StringCmd, map[int64]*redis.StringCmd) { - pipe := c.rdb.Pipeline() - loadCmds := make(map[int64]*redis.StringCmd, len(accountIDs)) - lastUsedCmds := make(map[int64]*redis.StringCmd, len(accountIDs)) - - for _, id := range accountIDs { - loadCmds[id] = pipe.Get(ctx, modelLoadKey(id, model)) - lastUsedCmds[id] = pipe.Get(ctx, modelLastUsedKey(id, model)) - } - _, _ = pipe.Exec(ctx) // 忽略错误,key 不存在是正常的 - return loadCmds, lastUsedCmds -} - -// parseModelLoadResults 解析 Pipeline 结果 -func (c *gatewayCache) parseModelLoadResults( - accountIDs []int64, - loadCmds map[int64]*redis.StringCmd, - lastUsedCmds map[int64]*redis.StringCmd, -) map[int64]*service.ModelLoadInfo { - result := make(map[int64]*service.ModelLoadInfo, len(accountIDs)) - for _, id := range accountIDs { - result[id] = &service.ModelLoadInfo{ - CallCount: getInt64OrZero(loadCmds[id]), - LastUsedAt: getTimeOrZero(lastUsedCmds[id]), - } - } - return result -} - -// getInt64OrZero 从 StringCmd 获取 int64 值,失败返回 0 -func getInt64OrZero(cmd *redis.StringCmd) int64 { - val, _ := cmd.Int64() - return val -} - -// getTimeOrZero 从 StringCmd 获取 time.Time,失败返回零值 -func getTimeOrZero(cmd *redis.StringCmd) time.Time { - val, err := cmd.Int64() - if err != nil { - return time.Time{} - } - return time.Unix(val, 0) -} - -// ============ Gemini 会话 Fallback 方法 (Trie 实现) ============ - -// FindGeminiSession 查找 Gemini 会话(使用 Trie + Lua 脚本实现 O(L) 查询) -// 返回最长匹配的会话信息,匹配成功时自动刷新 TTL -func (c *gatewayCache) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - if digestChain == "" { - return "", 0, false - } - - trieKey := service.BuildGeminiTrieKey(groupID, prefixHash) - ttlSeconds := 
int(service.GeminiSessionTTL().Seconds()) - - // 使用 Lua 脚本在 Redis 端执行 Trie 查找,O(L) 次 HGET,1 次网络往返 - // 查找成功时自动刷新 TTL,防止活跃会话意外过期 - result, err := c.rdb.Eval(ctx, geminiTrieFindScript, []string{trieKey}, digestChain, ttlSeconds).Result() - if err != nil || result == nil { - return "", 0, false - } - - value, ok := result.(string) - if !ok || value == "" { - return "", 0, false - } - - uuid, accountID, ok = service.ParseGeminiSessionValue(value) - return uuid, accountID, ok -} - -// SaveGeminiSession 保存 Gemini 会话(使用 Trie + Lua 脚本) -func (c *gatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { - if digestChain == "" { - return nil - } - - trieKey := service.BuildGeminiTrieKey(groupID, prefixHash) - value := service.FormatGeminiSessionValue(uuid, accountID) - ttlSeconds := int(service.GeminiSessionTTL().Seconds()) - - return c.rdb.Eval(ctx, geminiTrieSaveScript, []string{trieKey}, digestChain, value, ttlSeconds).Err() -} - -// ============ Anthropic 会话 Fallback 方法 (复用 Trie 实现) ============ - -// FindAnthropicSession 查找 Anthropic 会话(复用 Gemini Trie Lua 脚本) -func (c *gatewayCache) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - if digestChain == "" { - return "", 0, false - } - - trieKey := service.BuildAnthropicTrieKey(groupID, prefixHash) - ttlSeconds := int(service.AnthropicSessionTTL().Seconds()) - - result, err := c.rdb.Eval(ctx, geminiTrieFindScript, []string{trieKey}, digestChain, ttlSeconds).Result() - if err != nil || result == nil { - return "", 0, false - } - - value, ok := result.(string) - if !ok || value == "" { - return "", 0, false - } - - uuid, accountID, ok = service.ParseGeminiSessionValue(value) - return uuid, accountID, ok -} - -// SaveAnthropicSession 保存 Anthropic 会话(复用 Gemini Trie Lua 脚本) -func (c *gatewayCache) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, 
digestChain, uuid string, accountID int64) error { - if digestChain == "" { - return nil - } - - trieKey := service.BuildAnthropicTrieKey(groupID, prefixHash) - value := service.FormatGeminiSessionValue(uuid, accountID) - ttlSeconds := int(service.AnthropicSessionTTL().Seconds()) - - return c.rdb.Eval(ctx, geminiTrieSaveScript, []string{trieKey}, digestChain, value, ttlSeconds).Err() -} diff --git a/backend/internal/repository/gateway_cache_integration_test.go b/backend/internal/repository/gateway_cache_integration_test.go index fc8e7372..2fdaa3d1 100644 --- a/backend/internal/repository/gateway_cache_integration_test.go +++ b/backend/internal/repository/gateway_cache_integration_test.go @@ -104,157 +104,6 @@ func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() { require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil") } -// ============ Gemini Trie 会话测试 ============ - -func (s *GatewayCacheSuite) TestGeminiSessionTrie_SaveAndFind() { - groupID := int64(1) - prefixHash := "testprefix" - digestChain := "u:hash1-m:hash2-u:hash3" - uuid := "test-uuid-123" - accountID := int64(42) - - // 保存会话 - err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, digestChain, uuid, accountID) - require.NoError(s.T(), err, "SaveGeminiSession") - - // 精确匹配查找 - foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, digestChain) - require.True(s.T(), found, "should find exact match") - require.Equal(s.T(), uuid, foundUUID) - require.Equal(s.T(), accountID, foundAccountID) -} - -func (s *GatewayCacheSuite) TestGeminiSessionTrie_PrefixMatch() { - groupID := int64(1) - prefixHash := "prefixmatch" - shortChain := "u:a-m:b" - longChain := "u:a-m:b-u:c-m:d" - uuid := "uuid-prefix" - accountID := int64(100) - - // 保存短链 - err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, shortChain, uuid, accountID) - require.NoError(s.T(), err) - - // 用长链查找,应该匹配到短链(前缀匹配) - foundUUID, foundAccountID, found := 
s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, longChain) - require.True(s.T(), found, "should find prefix match") - require.Equal(s.T(), uuid, foundUUID) - require.Equal(s.T(), accountID, foundAccountID) -} - -func (s *GatewayCacheSuite) TestGeminiSessionTrie_LongestPrefixMatch() { - groupID := int64(1) - prefixHash := "longestmatch" - - // 保存多个不同长度的链 - err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a", "uuid-short", 1) - require.NoError(s.T(), err) - err = s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b", "uuid-medium", 2) - require.NoError(s.T(), err) - err = s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:c", "uuid-long", 3) - require.NoError(s.T(), err) - - // 查找更长的链,应该匹配到最长的前缀 - foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:c-m:d-u:e") - require.True(s.T(), found, "should find longest prefix match") - require.Equal(s.T(), "uuid-long", foundUUID) - require.Equal(s.T(), int64(3), foundAccountID) - - // 查找中等长度的链 - foundUUID, foundAccountID, found = s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:x") - require.True(s.T(), found) - require.Equal(s.T(), "uuid-medium", foundUUID) - require.Equal(s.T(), int64(2), foundAccountID) -} - -func (s *GatewayCacheSuite) TestGeminiSessionTrie_NoMatch() { - groupID := int64(1) - prefixHash := "nomatch" - digestChain := "u:a-m:b" - - // 保存一个会话 - err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, digestChain, "uuid", 1) - require.NoError(s.T(), err) - - // 用不同的链查找,应该找不到 - _, _, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:x-m:y") - require.False(s.T(), found, "should not find non-matching chain") -} - -func (s *GatewayCacheSuite) TestGeminiSessionTrie_DifferentPrefixHash() { - groupID := int64(1) - digestChain := "u:a-m:b" - - // 保存到 prefixHash1 - err := s.cache.SaveGeminiSession(s.ctx, groupID, "prefix1", digestChain, "uuid1", 1) - require.NoError(s.T(), err) - - // 用 
prefixHash2 查找,应该找不到(不同用户/客户端隔离) - _, _, found := s.cache.FindGeminiSession(s.ctx, groupID, "prefix2", digestChain) - require.False(s.T(), found, "different prefixHash should be isolated") -} - -func (s *GatewayCacheSuite) TestGeminiSessionTrie_DifferentGroupID() { - prefixHash := "sameprefix" - digestChain := "u:a-m:b" - - // 保存到 groupID 1 - err := s.cache.SaveGeminiSession(s.ctx, 1, prefixHash, digestChain, "uuid1", 1) - require.NoError(s.T(), err) - - // 用 groupID 2 查找,应该找不到(分组隔离) - _, _, found := s.cache.FindGeminiSession(s.ctx, 2, prefixHash, digestChain) - require.False(s.T(), found, "different groupID should be isolated") -} - -func (s *GatewayCacheSuite) TestGeminiSessionTrie_EmptyDigestChain() { - groupID := int64(1) - prefixHash := "emptytest" - - // 空链不应该保存 - err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "", "uuid", 1) - require.NoError(s.T(), err, "empty chain should not error") - - // 空链查找应该返回 false - _, _, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "") - require.False(s.T(), found, "empty chain should not match") -} - -func (s *GatewayCacheSuite) TestGeminiSessionTrie_MultipleSessions() { - groupID := int64(1) - prefixHash := "multisession" - - // 保存多个不同会话(模拟 1000 个并发会话的场景) - sessions := []struct { - chain string - uuid string - accountID int64 - }{ - {"u:session1", "uuid-1", 1}, - {"u:session2-m:reply2", "uuid-2", 2}, - {"u:session3-m:reply3-u:msg3", "uuid-3", 3}, - } - - for _, sess := range sessions { - err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, sess.chain, sess.uuid, sess.accountID) - require.NoError(s.T(), err) - } - - // 验证每个会话都能正确查找 - for _, sess := range sessions { - foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, sess.chain) - require.True(s.T(), found, "should find session: %s", sess.chain) - require.Equal(s.T(), sess.uuid, foundUUID) - require.Equal(s.T(), sess.accountID, foundAccountID) - } - - // 验证继续对话的场景 - foundUUID, foundAccountID, found := 
s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:session2-m:reply2-u:newmsg") - require.True(s.T(), found) - require.Equal(s.T(), "uuid-2", foundUUID) - require.Equal(s.T(), int64(2), foundAccountID) -} func TestGatewayCacheSuite(t *testing.T) { suite.Run(t, new(GatewayCacheSuite)) diff --git a/backend/internal/repository/gateway_cache_model_load_integration_test.go b/backend/internal/repository/gateway_cache_model_load_integration_test.go deleted file mode 100644 index de6fa5ae..00000000 --- a/backend/internal/repository/gateway_cache_model_load_integration_test.go +++ /dev/null @@ -1,234 +0,0 @@ -//go:build integration - -package repository - -import ( - "context" - "testing" - "time" - - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" -) - -// ============ Gateway Cache 模型负载统计集成测试 ============ - -type GatewayCacheModelLoadSuite struct { - suite.Suite -} - -func TestGatewayCacheModelLoadSuite(t *testing.T) { - suite.Run(t, new(GatewayCacheModelLoadSuite)) -} - -func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_Basic() { - t := s.T() - rdb := testRedis(t) - cache := &gatewayCache{rdb: rdb} - ctx := context.Background() - - accountID := int64(123) - model := "claude-sonnet-4-20250514" - - // 首次调用应返回 1 - count1, err := cache.IncrModelCallCount(ctx, accountID, model) - require.NoError(t, err) - require.Equal(t, int64(1), count1) - - // 第二次调用应返回 2 - count2, err := cache.IncrModelCallCount(ctx, accountID, model) - require.NoError(t, err) - require.Equal(t, int64(2), count2) - - // 第三次调用应返回 3 - count3, err := cache.IncrModelCallCount(ctx, accountID, model) - require.NoError(t, err) - require.Equal(t, int64(3), count3) -} - -func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_DifferentModels() { - t := s.T() - rdb := testRedis(t) - cache := &gatewayCache{rdb: rdb} - ctx := context.Background() - - accountID := int64(456) - model1 := "claude-sonnet-4-20250514" - model2 := "claude-opus-4-5-20251101" - - // 不同模型应该独立计数 - 
count1, err := cache.IncrModelCallCount(ctx, accountID, model1) - require.NoError(t, err) - require.Equal(t, int64(1), count1) - - count2, err := cache.IncrModelCallCount(ctx, accountID, model2) - require.NoError(t, err) - require.Equal(t, int64(1), count2) - - count1Again, err := cache.IncrModelCallCount(ctx, accountID, model1) - require.NoError(t, err) - require.Equal(t, int64(2), count1Again) -} - -func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_DifferentAccounts() { - t := s.T() - rdb := testRedis(t) - cache := &gatewayCache{rdb: rdb} - ctx := context.Background() - - account1 := int64(111) - account2 := int64(222) - model := "gemini-2.5-pro" - - // 不同账号应该独立计数 - count1, err := cache.IncrModelCallCount(ctx, account1, model) - require.NoError(t, err) - require.Equal(t, int64(1), count1) - - count2, err := cache.IncrModelCallCount(ctx, account2, model) - require.NoError(t, err) - require.Equal(t, int64(1), count2) -} - -func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_Empty() { - t := s.T() - rdb := testRedis(t) - cache := &gatewayCache{rdb: rdb} - ctx := context.Background() - - result, err := cache.GetModelLoadBatch(ctx, []int64{}, "any-model") - require.NoError(t, err) - require.NotNil(t, result) - require.Empty(t, result) -} - -func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_NonExistent() { - t := s.T() - rdb := testRedis(t) - cache := &gatewayCache{rdb: rdb} - ctx := context.Background() - - // 查询不存在的账号应返回零值 - result, err := cache.GetModelLoadBatch(ctx, []int64{9999, 9998}, "claude-sonnet-4-20250514") - require.NoError(t, err) - require.Len(t, result, 2) - - require.Equal(t, int64(0), result[9999].CallCount) - require.True(t, result[9999].LastUsedAt.IsZero()) - require.Equal(t, int64(0), result[9998].CallCount) - require.True(t, result[9998].LastUsedAt.IsZero()) -} - -func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_AfterIncrement() { - t := s.T() - rdb := testRedis(t) - cache := &gatewayCache{rdb: rdb} - ctx := 
context.Background() - - accountID := int64(789) - model := "claude-sonnet-4-20250514" - - // 先增加调用次数 - beforeIncr := time.Now() - _, err := cache.IncrModelCallCount(ctx, accountID, model) - require.NoError(t, err) - _, err = cache.IncrModelCallCount(ctx, accountID, model) - require.NoError(t, err) - _, err = cache.IncrModelCallCount(ctx, accountID, model) - require.NoError(t, err) - afterIncr := time.Now() - - // 获取负载信息 - result, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model) - require.NoError(t, err) - require.Len(t, result, 1) - - loadInfo := result[accountID] - require.NotNil(t, loadInfo) - require.Equal(t, int64(3), loadInfo.CallCount) - require.False(t, loadInfo.LastUsedAt.IsZero()) - // LastUsedAt 应该在 beforeIncr 和 afterIncr 之间 - require.True(t, loadInfo.LastUsedAt.After(beforeIncr.Add(-time.Second)) || loadInfo.LastUsedAt.Equal(beforeIncr)) - require.True(t, loadInfo.LastUsedAt.Before(afterIncr.Add(time.Second)) || loadInfo.LastUsedAt.Equal(afterIncr)) -} - -func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_MultipleAccounts() { - t := s.T() - rdb := testRedis(t) - cache := &gatewayCache{rdb: rdb} - ctx := context.Background() - - model := "claude-opus-4-5-20251101" - account1 := int64(1001) - account2 := int64(1002) - account3 := int64(1003) // 不调用 - - // account1 调用 2 次 - _, err := cache.IncrModelCallCount(ctx, account1, model) - require.NoError(t, err) - _, err = cache.IncrModelCallCount(ctx, account1, model) - require.NoError(t, err) - - // account2 调用 5 次 - for i := 0; i < 5; i++ { - _, err = cache.IncrModelCallCount(ctx, account2, model) - require.NoError(t, err) - } - - // 批量获取 - result, err := cache.GetModelLoadBatch(ctx, []int64{account1, account2, account3}, model) - require.NoError(t, err) - require.Len(t, result, 3) - - require.Equal(t, int64(2), result[account1].CallCount) - require.False(t, result[account1].LastUsedAt.IsZero()) - - require.Equal(t, int64(5), result[account2].CallCount) - require.False(t, 
result[account2].LastUsedAt.IsZero()) - - require.Equal(t, int64(0), result[account3].CallCount) - require.True(t, result[account3].LastUsedAt.IsZero()) -} - -func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_ModelIsolation() { - t := s.T() - rdb := testRedis(t) - cache := &gatewayCache{rdb: rdb} - ctx := context.Background() - - accountID := int64(2001) - model1 := "claude-sonnet-4-20250514" - model2 := "gemini-2.5-pro" - - // 对 model1 调用 3 次 - for i := 0; i < 3; i++ { - _, err := cache.IncrModelCallCount(ctx, accountID, model1) - require.NoError(t, err) - } - - // 获取 model1 的负载 - result1, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model1) - require.NoError(t, err) - require.Equal(t, int64(3), result1[accountID].CallCount) - - // 获取 model2 的负载(应该为 0) - result2, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model2) - require.NoError(t, err) - require.Equal(t, int64(0), result2[accountID].CallCount) -} - -// ============ 辅助函数测试 ============ - -func (s *GatewayCacheModelLoadSuite) TestModelLoadKey_Format() { - t := s.T() - - key := modelLoadKey(123, "claude-sonnet-4") - require.Equal(t, "ag:model_load:123:claude-sonnet-4", key) -} - -func (s *GatewayCacheModelLoadSuite) TestModelLastUsedKey_Format() { - t := s.T() - - key := modelLastUsedKey(456, "gemini-2.5-pro") - require.Equal(t, "ag:model_last_used:456:gemini-2.5-pro", key) -} diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index efef0452..9b571123 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -1004,10 +1004,6 @@ func (s *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt return errors.New("not implemented") } -func (s *stubAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error { - return errors.New("not implemented") -} - func (s *stubAccountRepo) 
SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error { return errors.New("not implemented") } diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index 90365d2f..9bf58988 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -50,7 +50,6 @@ type AccountRepository interface { ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error - SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error SetOverloaded(ctx context.Context, id int64, until time.Time) error SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error diff --git a/backend/internal/service/account_service_delete_test.go b/backend/internal/service/account_service_delete_test.go index e5eabfc6..af3a3784 100644 --- a/backend/internal/service/account_service_delete_test.go +++ b/backend/internal/service/account_service_delete_test.go @@ -143,10 +143,6 @@ func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt panic("unexpected SetRateLimited call") } -func (s *accountRepoStub) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error { - panic("unexpected SetAntigravityQuotaScopeLimit call") -} - func (s *accountRepoStub) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error { panic("unexpected SetModelRateLimit call") } diff --git a/backend/internal/service/anthropic_session.go b/backend/internal/service/anthropic_session.go index 2d86ed35..26544c68 100644 --- a/backend/internal/service/anthropic_session.go +++ b/backend/internal/service/anthropic_session.go 
@@ -2,7 +2,6 @@ package service import ( "encoding/json" - "strconv" "strings" "time" ) @@ -12,9 +11,6 @@ const ( // anthropicSessionTTLSeconds Anthropic 会话缓存 TTL(5 分钟) anthropicSessionTTLSeconds = 300 - // anthropicTrieKeyPrefix Anthropic Trie 会话 key 前缀 - anthropicTrieKeyPrefix = "anthropic:trie:" - // anthropicDigestSessionKeyPrefix Anthropic 摘要 fallback 会话 key 前缀 anthropicDigestSessionKeyPrefix = "anthropic:digest:" ) @@ -68,12 +64,6 @@ func rolePrefix(role string) string { } } -// BuildAnthropicTrieKey 构建 Anthropic Trie Redis key -// 格式: anthropic:trie:{groupID}:{prefixHash} -func BuildAnthropicTrieKey(groupID int64, prefixHash string) string { - return anthropicTrieKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash -} - // GenerateAnthropicDigestSessionKey 生成 Anthropic 摘要 fallback 的 sessionKey // 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey func GenerateAnthropicDigestSessionKey(prefixHash, uuid string) string { diff --git a/backend/internal/service/anthropic_session_test.go b/backend/internal/service/anthropic_session_test.go index e2f873e7..10406643 100644 --- a/backend/internal/service/anthropic_session_test.go +++ b/backend/internal/service/anthropic_session_test.go @@ -236,43 +236,6 @@ func TestBuildAnthropicDigestChain_Deterministic(t *testing.T) { } } -func TestBuildAnthropicTrieKey(t *testing.T) { - tests := []struct { - name string - groupID int64 - prefixHash string - want string - }{ - { - name: "normal", - groupID: 123, - prefixHash: "abcdef12", - want: "anthropic:trie:123:abcdef12", - }, - { - name: "zero group", - groupID: 0, - prefixHash: "xyz", - want: "anthropic:trie:0:xyz", - }, - { - name: "empty prefix", - groupID: 1, - prefixHash: "", - want: "anthropic:trie:1:", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := BuildAnthropicTrieKey(tt.groupID, tt.prefixHash) - if got != tt.want { - t.Errorf("BuildAnthropicTrieKey(%d, %q) = %q, want %q", tt.groupID, tt.prefixHash, got, tt.want) - 
} - }) - } -} - func TestGenerateAnthropicDigestSessionKey(t *testing.T) { tests := []struct { name string diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 3caf9a93..014b3c86 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "log" + "log/slog" mathrand "math/rand" "net" "net/http" @@ -100,12 +101,11 @@ type antigravityRetryLoopParams struct { accessToken string action string body []byte - quotaScope AntigravityQuotaScope c *gin.Context httpUpstream HTTPUpstream settingService *SettingService accountRepo AccountRepository // 用于智能重试的模型级别限流 - handleError func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult + handleError func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult requestedModel string // 用于限流检查的原始请求模型 isStickySession bool // 是否为粘性会话(用于账号切换时的缓存计费判断) groupID int64 // 用于模型级限流时清除粘性会话 @@ -148,13 +148,17 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam // 情况1: retryDelay >= 阈值,限流模型并切换账号 if shouldRateLimitModel { - log.Printf("%s status=%d oauth_long_delay model=%s account=%d (model rate limit, switch account)", - p.prefix, resp.StatusCode, modelName, p.account.ID) + rateLimitDuration := waitDuration + if rateLimitDuration <= 0 { + rateLimitDuration = antigravityDefaultRateLimitDuration + } + log.Printf("%s status=%d oauth_long_delay model=%s account=%d upstream_retry_delay=%v body=%s (model rate limit, switch account)", + p.prefix, resp.StatusCode, modelName, p.account.ID, rateLimitDuration, 
truncateForLog(respBody, 200)) - resetAt := time.Now().Add(antigravityDefaultRateLimitDuration) + resetAt := time.Now().Add(rateLimitDuration) if !setModelRateLimitByModelName(p.ctx, p.accountRepo, p.account.ID, modelName, p.prefix, resp.StatusCode, resetAt, false) { - p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.quotaScope, p.groupID, p.sessionHash, p.isStickySession) - log.Printf("%s status=%d rate_limited account=%d (no scope mapping)", p.prefix, resp.StatusCode, p.account.ID) + p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.requestedModel, p.groupID, p.sessionHash, p.isStickySession) + log.Printf("%s status=%d rate_limited account=%d (no model mapping)", p.prefix, resp.StatusCode, p.account.ID) } else { s.updateAccountModelRateLimitInCache(p.ctx, p.account, modelName, resetAt) } @@ -190,7 +194,7 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam retryReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, p.body) if err != nil { log.Printf("%s status=smart_retry_request_build_failed error=%v", p.prefix, err) - p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.quotaScope, p.groupID, p.sessionHash, p.isStickySession) + p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.requestedModel, p.groupID, p.sessionHash, p.isStickySession) return &smartRetryResult{ action: smartRetryActionBreakWithResp, resp: &http.Response{ @@ -233,16 +237,24 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam } // 所有重试都失败,限流当前模型并切换账号 - log.Printf("%s status=%d smart_retry_exhausted attempts=%d model=%s account=%d (switch account)", - p.prefix, resp.StatusCode, antigravitySmartRetryMaxAttempts, modelName, p.account.ID) + rateLimitDuration := waitDuration + if rateLimitDuration <= 0 { + rateLimitDuration = antigravityDefaultRateLimitDuration + } + retryBody 
:= lastRetryBody + if retryBody == nil { + retryBody = respBody + } + log.Printf("%s status=%d smart_retry_exhausted attempts=%d model=%s account=%d upstream_retry_delay=%v body=%s (switch account)", + p.prefix, resp.StatusCode, antigravitySmartRetryMaxAttempts, modelName, p.account.ID, rateLimitDuration, truncateForLog(retryBody, 200)) - resetAt := time.Now().Add(antigravityDefaultRateLimitDuration) + resetAt := time.Now().Add(rateLimitDuration) if p.accountRepo != nil && modelName != "" { if err := p.accountRepo.SetModelRateLimit(p.ctx, p.account.ID, modelName, resetAt); err != nil { log.Printf("%s status=%d model_rate_limit_failed model=%s error=%v", p.prefix, resp.StatusCode, modelName, err) } else { log.Printf("%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v", - p.prefix, resp.StatusCode, modelName, p.account.ID, antigravityDefaultRateLimitDuration) + p.prefix, resp.StatusCode, modelName, p.account.ID, rateLimitDuration) s.updateAccountModelRateLimitInCache(p.ctx, p.account, modelName, resetAt) } } @@ -353,87 +365,102 @@ urlFallbackLoop: return nil, fmt.Errorf("upstream request failed after retries: %w", err) } - // 429/503 限流处理:区分 URL 级别限流、智能重试和账户配额限流 - if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable { + // 统一处理错误响应 + if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) _ = resp.Body.Close() - // 尝试智能重试处理(OAuth 账号专用) - smartResult := s.handleSmartRetry(p, resp, respBody, baseURL, urlIdx, availableURLs) - switch smartResult.action { - case smartRetryActionContinueURL: - continue urlFallbackLoop - case smartRetryActionBreakWithResp: - if smartResult.err != nil { - return nil, smartResult.err + // ★ 统一入口:自定义错误码 + 临时不可调度 + if handled, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled { + if policyErr != nil { + return nil, policyErr } - // 模型限流时返回切换账号信号 - if smartResult.switchError != nil { - return nil, 
smartResult.switchError + resp = &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), } - resp = smartResult.resp break urlFallbackLoop } - // smartRetryActionContinue: 继续默认重试逻辑 - // 账户/模型配额限流,重试 3 次(指数退避)- 默认逻辑(非 OAuth 账号或解析失败) - if attempt < antigravityMaxRetries { - upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{ - Platform: p.account.Platform, - AccountID: p.account.ID, - AccountName: p.account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: resp.Header.Get("x-request-id"), - Kind: "retry", - Message: upstreamMsg, - Detail: getUpstreamDetail(respBody), - }) - log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200)) - if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", p.prefix) - return nil, p.ctx.Err() + // 429/503 限流处理:区分 URL 级别限流、智能重试和账户配额限流 + if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable { + // 尝试智能重试处理(OAuth 账号专用) + smartResult := s.handleSmartRetry(p, resp, respBody, baseURL, urlIdx, availableURLs) + switch smartResult.action { + case smartRetryActionContinueURL: + continue urlFallbackLoop + case smartRetryActionBreakWithResp: + if smartResult.err != nil { + return nil, smartResult.err + } + // 模型限流时返回切换账号信号 + if smartResult.switchError != nil { + return nil, smartResult.switchError + } + resp = smartResult.resp + break urlFallbackLoop } - continue + // smartRetryActionContinue: 继续默认重试逻辑 + + // 账户/模型配额限流,重试 3 次(指数退避)- 默认逻辑(非 OAuth 账号或解析失败) + if attempt < antigravityMaxRetries { + upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + 
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{ + Platform: p.account.Platform, + AccountID: p.account.ID, + AccountName: p.account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "retry", + Message: upstreamMsg, + Detail: getUpstreamDetail(respBody), + }) + log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200)) + if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { + log.Printf("%s status=context_canceled_during_backoff", p.prefix) + return nil, p.ctx.Err() + } + continue + } + + // 重试用尽,标记账户限流 + p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.requestedModel, p.groupID, p.sessionHash, p.isStickySession) + log.Printf("%s status=%d rate_limited base_url=%s body=%s", p.prefix, resp.StatusCode, baseURL, truncateForLog(respBody, 200)) + resp = &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), + } + break urlFallbackLoop } - // 重试用尽,标记账户限流 - p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.quotaScope, p.groupID, p.sessionHash, p.isStickySession) - log.Printf("%s status=%d rate_limited base_url=%s body=%s", p.prefix, resp.StatusCode, baseURL, truncateForLog(respBody, 200)) - resp = &http.Response{ - StatusCode: resp.StatusCode, - Header: resp.Header.Clone(), - Body: io.NopCloser(bytes.NewReader(respBody)), - } - break urlFallbackLoop - } - - // 其他可重试错误(不包括 429 和 503,因为上面已处理) - if resp.StatusCode >= 400 && shouldRetryAntigravityError(resp.StatusCode) { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - - if attempt < antigravityMaxRetries { - upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{ - Platform: 
p.account.Platform, - AccountID: p.account.ID, - AccountName: p.account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: resp.Header.Get("x-request-id"), - Kind: "retry", - Message: upstreamMsg, - Detail: getUpstreamDetail(respBody), - }) - log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500)) - if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", p.prefix) - return nil, p.ctx.Err() - } - continue - } + // 其他可重试错误(500/502/504/529,不包括 429 和 503) + if shouldRetryAntigravityError(resp.StatusCode) { + if attempt < antigravityMaxRetries { + upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{ + Platform: p.account.Platform, + AccountID: p.account.ID, + AccountName: p.account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "retry", + Message: upstreamMsg, + Detail: getUpstreamDetail(respBody), + }) + log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500)) + if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { + log.Printf("%s status=context_canceled_during_backoff", p.prefix) + return nil, p.ctx.Err() + } + continue + } + } + + // 其他 4xx 错误或重试用尽,直接返回 resp = &http.Response{ StatusCode: resp.StatusCode, Header: resp.Header.Clone(), @@ -442,6 +469,7 @@ urlFallbackLoop: break urlFallbackLoop } + // 成功响应(< 400) break urlFallbackLoop } } @@ -574,6 +602,31 @@ func (s *AntigravityGatewayService) getUpstreamErrorDetail(body []byte) string { return truncateString(string(body), maxBytes) } +// checkErrorPolicy nil 安全的包装 +func (s *AntigravityGatewayService) checkErrorPolicy(ctx context.Context, account *Account, statusCode int, body []byte) 
ErrorPolicyResult { + if s.rateLimitService == nil { + return ErrorPolicyNone + } + return s.rateLimitService.CheckErrorPolicy(ctx, account, statusCode, body) +} + +// applyErrorPolicy 应用错误策略结果,返回是否应终止当前循环 +func (s *AntigravityGatewayService) applyErrorPolicy(p antigravityRetryLoopParams, statusCode int, headers http.Header, respBody []byte) (handled bool, retErr error) { + switch s.checkErrorPolicy(p.ctx, p.account, statusCode, respBody) { + case ErrorPolicySkipped: + return true, nil + case ErrorPolicyMatched: + _ = p.handleError(p.ctx, p.prefix, p.account, statusCode, headers, respBody, + p.requestedModel, p.groupID, p.sessionHash, p.isStickySession) + return true, nil + case ErrorPolicyTempUnscheduled: + slog.Info("temp_unschedulable_matched", + "prefix", p.prefix, "status_code", statusCode, "account_id", p.account.ID) + return true, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, IsStickySession: p.isStickySession} + } + return false, nil +} + // mapAntigravityModel 获取映射后的模型名 // 完全依赖映射配置:账户映射(通配符)→ 默认映射兜底(DefaultAntigravityModelMapping) // 注意:返回空字符串表示模型不被支持,调度时会过滤掉该账号 @@ -969,6 +1022,11 @@ func isModelNotFoundError(statusCode int, body []byte) bool { // ├─ 成功 → 正常返回 // └─ 失败 → 设置模型限流 + 清除粘性绑定 → 切换账号 func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, isStickySession bool) (*ForwardResult, error) { + // 上游透传账号直接转发,不走 OAuth token 刷新 + if account.Type == AccountTypeUpstream { + return s.ForwardUpstream(ctx, c, account, body) + } + startTime := time.Now() sessionID := getSessionID(c) @@ -988,11 +1046,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, if mappedModel == "" { return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model)) } - loadModel := mappedModel // 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5,自动改为 thinking 版本 thinkingEnabled := claudeReq.Thinking != nil && 
claudeReq.Thinking.Type == "enabled" mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled) - quotaScope, _ := resolveAntigravityQuotaScope(originalModel) // 获取 access_token if s.tokenProvider == nil { @@ -1027,11 +1083,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, // 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回 action := "streamGenerateContent" - // 统计模型调用次数(包括粘性会话,用于负载均衡调度) - if s.cache != nil { - _, _ = s.cache.IncrModelCallCount(ctx, account.ID, loadModel) - } - // 执行带重试的请求 result, err := s.antigravityRetryLoop(antigravityRetryLoopParams{ ctx: ctx, @@ -1041,7 +1092,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, accessToken: accessToken, action: action, body: geminiBody, - quotaScope: quotaScope, c: c, httpUpstream: s.httpUpstream, settingService: s.settingService, @@ -1122,7 +1172,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, accessToken: accessToken, action: action, body: retryGeminiBody, - quotaScope: quotaScope, c: c, httpUpstream: s.httpUpstream, settingService: s.settingService, @@ -1233,7 +1282,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, } } - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession) + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, originalModel, 0, "", isStickySession) if s.shouldFailoverUpstreamError(resp.StatusCode) { upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) @@ -1263,6 +1312,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, var usage *ClaudeUsage var firstTokenMs *int + var clientDisconnect bool if claudeReq.Stream { // 客户端要求流式,直接透传转换 streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel) @@ -1272,6 +1322,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c 
*gin.Context, } usage = streamRes.usage firstTokenMs = streamRes.firstTokenMs + clientDisconnect = streamRes.clientDisconnect } else { // 客户端要求非流式,收集流式响应后转换返回 streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel) @@ -1284,12 +1335,13 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, } return &ForwardResult{ - RequestID: requestID, - Usage: *usage, - Model: originalModel, // 使用原始模型用于计费和日志 - Stream: claudeReq.Stream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, + RequestID: requestID, + Usage: *usage, + Model: originalModel, // 使用原始模型用于计费和日志 + Stream: claudeReq.Stream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ClientDisconnect: clientDisconnect, }, nil } @@ -1613,7 +1665,6 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co if len(body) == 0 { return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty") } - quotaScope, _ := resolveAntigravityQuotaScope(originalModel) // 解析请求以获取 image_size(用于图片计费) imageSize := s.extractImageSize(body) @@ -1683,11 +1734,6 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co // 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后返回 upstreamAction := "streamGenerateContent" - // 统计模型调用次数(包括粘性会话,用于负载均衡调度) - if s.cache != nil { - _, _ = s.cache.IncrModelCallCount(ctx, account.ID, mappedModel) - } - // 执行带重试的请求 result, err := s.antigravityRetryLoop(antigravityRetryLoopParams{ ctx: ctx, @@ -1697,7 +1743,6 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co accessToken: accessToken, action: upstreamAction, body: wrappedBody, - quotaScope: quotaScope, c: c, httpUpstream: s.httpUpstream, settingService: s.settingService, @@ -1771,7 +1816,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co if unwrapErr != nil || len(unwrappedForOps) == 0 { unwrappedForOps = respBody } - s.handleUpstreamError(ctx, prefix, 
account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession) + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, originalModel, 0, "", isStickySession) upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(unwrappedForOps)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) upstreamDetail := s.getUpstreamErrorDetail(unwrappedForOps) @@ -1818,6 +1863,7 @@ handleSuccess: var usage *ClaudeUsage var firstTokenMs *int + var clientDisconnect bool if stream { // 客户端要求流式,直接透传 @@ -1828,6 +1874,7 @@ handleSuccess: } usage = streamRes.usage firstTokenMs = streamRes.firstTokenMs + clientDisconnect = streamRes.clientDisconnect } else { // 客户端要求非流式,收集流式响应后返回 streamRes, err := s.handleGeminiStreamToNonStreaming(c, resp, startTime) @@ -1851,14 +1898,15 @@ handleSuccess: } return &ForwardResult{ - RequestID: requestID, - Usage: *usage, - Model: originalModel, - Stream: stream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - ImageCount: imageCount, - ImageSize: imageSize, + RequestID: requestID, + Usage: *usage, + Model: originalModel, + Stream: stream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ClientDisconnect: clientDisconnect, + ImageCount: imageCount, + ImageSize: imageSize, }, nil } @@ -2067,9 +2115,9 @@ func shouldTriggerAntigravitySmartRetry(account *Account, respBody []byte) (shou } // retryDelay >= 阈值:直接限流模型,不重试 - // 注意:如果上游未提供 retryDelay,parseAntigravitySmartRetryInfo 已设置为默认 5 分钟 + // 注意:如果上游未提供 retryDelay,parseAntigravitySmartRetryInfo 已设置为默认 30s if info.RetryDelay >= antigravityRateLimitThreshold { - return false, true, 0, info.ModelName + return false, true, info.RetryDelay, info.ModelName } // retryDelay < 阈值:智能重试 @@ -2191,10 +2239,10 @@ func (s *AntigravityGatewayService) updateAccountModelRateLimitInCache(ctx conte func (s *AntigravityGatewayService) handleUpstreamError( ctx context.Context, prefix string, account *Account, statusCode int, headers 
http.Header, body []byte, - quotaScope AntigravityQuotaScope, + requestedModel string, groupID int64, sessionHash string, isStickySession bool, ) *handleModelRateLimitResult { - // ✨ 模型级限流处理(在原有逻辑之前) + // 模型级限流处理(优先) result := s.handleModelRateLimit(&handleModelRateLimitParams{ ctx: ctx, prefix: prefix, @@ -2216,52 +2264,35 @@ func (s *AntigravityGatewayService) handleUpstreamError( return nil } - // ========== 原有逻辑,保持不变 ========== - // 429 使用 Gemini 格式解析(从 body 解析重置时间) + // 429:尝试解析模型级限流,解析失败时兜底为账号级限流 if statusCode == 429 { - // 调试日志遵循统一日志开关与长度限制,避免无条件记录完整上游响应体。 if logBody, maxBytes := s.getLogConfig(); logBody { log.Printf("[Antigravity-Debug] 429 response body: %s", truncateString(string(body), maxBytes)) } - useScopeLimit := quotaScope != "" resetAt := ParseGeminiRateLimitResetTime(body) - if resetAt == nil { - // 解析失败:使用默认限流时间(与临时限流保持一致) - // 可通过配置或环境变量覆盖 - defaultDur := antigravityDefaultRateLimitDuration - if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes > 0 { - defaultDur = time.Duration(s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes) * time.Minute - } - // 秒级环境变量优先级最高 - if override, ok := antigravityFallbackCooldownSeconds(); ok { - defaultDur = override - } - ra := time.Now().Add(defaultDur) - if useScopeLimit { - log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur) - if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, ra); err != nil { - log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err) - } + defaultDur := s.getDefaultRateLimitDuration() + + // 尝试解析模型 key 并设置模型级限流 + modelKey := resolveAntigravityModelKey(requestedModel) + if modelKey != "" { + ra := s.resolveResetTime(resetAt, defaultDur) + if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelKey, ra); err != nil { + log.Printf("%s status=429 model_rate_limit_set_failed model=%s 
error=%v", prefix, modelKey, err) } else { - log.Printf("%s status=429 rate_limited account=%d reset_in=%v (fallback)", prefix, account.ID, defaultDur) - if err := s.accountRepo.SetRateLimited(ctx, account.ID, ra); err != nil { - log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err) - } + log.Printf("%s status=429 model_rate_limited model=%s account=%d reset_at=%v reset_in=%v", + prefix, modelKey, account.ID, ra.Format("15:04:05"), time.Until(ra).Truncate(time.Second)) + s.updateAccountModelRateLimitInCache(ctx, account, modelKey, ra) } return nil } - resetTime := time.Unix(*resetAt, 0) - if useScopeLimit { - log.Printf("%s status=429 rate_limited scope=%s reset_at=%v reset_in=%v", prefix, quotaScope, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second)) - if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, resetTime); err != nil { - log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err) - } - } else { - log.Printf("%s status=429 rate_limited account=%d reset_at=%v reset_in=%v", prefix, account.ID, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second)) - if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil { - log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err) - } + + // 无法解析模型 key,兜底为账号级限流 + ra := s.resolveResetTime(resetAt, defaultDur) + log.Printf("%s status=429 rate_limited account=%d reset_at=%v reset_in=%v (fallback)", + prefix, account.ID, ra.Format("15:04:05"), time.Until(ra).Truncate(time.Second)) + if err := s.accountRepo.SetRateLimited(ctx, account.ID, ra); err != nil { + log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err) } return nil } @@ -2276,9 +2307,90 @@ func (s *AntigravityGatewayService) handleUpstreamError( return nil } +// getDefaultRateLimitDuration 获取默认限流时间 +func (s 
*AntigravityGatewayService) getDefaultRateLimitDuration() time.Duration { + defaultDur := antigravityDefaultRateLimitDuration + if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes > 0 { + defaultDur = time.Duration(s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes) * time.Minute + } + if override, ok := antigravityFallbackCooldownSeconds(); ok { + defaultDur = override + } + return defaultDur +} + +// resolveResetTime 根据解析的重置时间或默认时长计算重置时间点 +func (s *AntigravityGatewayService) resolveResetTime(resetAt *int64, defaultDur time.Duration) time.Time { + if resetAt != nil { + return time.Unix(*resetAt, 0) + } + return time.Now().Add(defaultDur) +} + type antigravityStreamResult struct { - usage *ClaudeUsage - firstTokenMs *int + usage *ClaudeUsage + firstTokenMs *int + clientDisconnect bool // 客户端是否在流式传输过程中断开 +} + +// antigravityClientWriter 封装流式响应的客户端写入,自动检测断开并标记。 +// 断开后所有写入操作变为 no-op,调用方通过 Disconnected() 判断是否继续 drain 上游。 +type antigravityClientWriter struct { + w gin.ResponseWriter + flusher http.Flusher + disconnected bool + prefix string // 日志前缀,标识来源方法 +} + +func newAntigravityClientWriter(w gin.ResponseWriter, flusher http.Flusher, prefix string) *antigravityClientWriter { + return &antigravityClientWriter{w: w, flusher: flusher, prefix: prefix} +} + +// Write 写入数据到客户端,写入失败时标记断开并返回 false +func (cw *antigravityClientWriter) Write(p []byte) bool { + if cw.disconnected { + return false + } + if _, err := cw.w.Write(p); err != nil { + cw.markDisconnected() + return false + } + cw.flusher.Flush() + return true +} + +// Fprintf 格式化写入数据到客户端,写入失败时标记断开并返回 false +func (cw *antigravityClientWriter) Fprintf(format string, args ...any) bool { + if cw.disconnected { + return false + } + if _, err := fmt.Fprintf(cw.w, format, args...); err != nil { + cw.markDisconnected() + return false + } + cw.flusher.Flush() + return true +} + +func (cw *antigravityClientWriter) Disconnected() bool { return 
cw.disconnected } + +func (cw *antigravityClientWriter) markDisconnected() { + cw.disconnected = true + log.Printf("Client disconnected during streaming (%s), continuing to drain upstream for billing", cw.prefix) +} + +// handleStreamReadError centralizes handling of upstream read errors. +// It returns (clientDisconnect, handled); handled=true means the error was consumed and the caller should return the usage collected so far. +func handleStreamReadError(err error, clientDisconnected bool, prefix string) (disconnect bool, handled bool) { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + log.Printf("Context canceled during streaming (%s), returning collected usage", prefix) + return true, true + } + if clientDisconnected { + log.Printf("Upstream read error after client disconnect (%s): %v, returning collected usage", prefix, err) + return true, true + } + return false, false } func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time) (*antigravityStreamResult, error) { @@ -2354,10 +2466,12 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context intervalCh = intervalTicker.C } + cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity gemini") + // Send the error event only once, to avoid corrupting the protocol with repeated writes errorEventSent := false sendErrorEvent := func(reason string) { - if errorEventSent { + if errorEventSent || cw.Disconnected() { return } errorEventSent = true @@ -2369,9 +2483,12 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context select { case ev, ok := <-events: if !ok { - return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: cw.Disconnected()}, nil } if ev.err != nil { + if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity gemini"); handled { + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: disconnect}, nil + } if
errors.Is(ev.err, bufio.ErrTooLong) { log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err) sendErrorEvent("response_too_large") @@ -2386,11 +2503,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context if strings.HasPrefix(trimmed, "data:") { payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:")) if payload == "" || payload == "[DONE]" { - if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil { - sendErrorEvent("write_failed") - return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err - } - flusher.Flush() + cw.Fprintf("%s\n", line) continue } @@ -2426,27 +2539,22 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context firstTokenMs = &ms } - if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", payload); err != nil { - sendErrorEvent("write_failed") - return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err - } - flusher.Flush() + cw.Fprintf("data: %s\n\n", payload) continue } - if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil { - sendErrorEvent("write_failed") - return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err - } - flusher.Flush() + cw.Fprintf("%s\n", line) case <-intervalCh: lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) if time.Since(lastRead) < streamInterval { continue } + if cw.Disconnected() { + log.Printf("Upstream timeout after client disconnect (antigravity gemini), returning collected usage") + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } log.Printf("Stream data interval timeout (antigravity)") - // 注意:此函数没有 account 上下文,无法调用 HandleStreamTimeout sendErrorEvent("stream_timeout") return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") } @@ -3144,10 +3252,12 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c 
*gin.Context intervalCh = intervalTicker.C } + cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity claude") + // Send the error event only once, to avoid corrupting the protocol with repeated writes errorEventSent := false sendErrorEvent := func(reason string) { - if errorEventSent { + if errorEventSent || cw.Disconnected() { return } errorEventSent = true @@ -3155,19 +3265,27 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context flusher.Flush() } + + // finishUsage is a helper that fetches the processor's final usage + finishUsage := func() *ClaudeUsage { + _, agUsage := processor.Finish() + return convertUsage(agUsage) + } + for { select { case ev, ok := <-events: if !ok { - // 发送结束事件 + // Upstream finished; send the closing events finalEvents, agUsage := processor.Finish() if len(finalEvents) > 0 { - _, _ = c.Writer.Write(finalEvents) - flusher.Flush() + cw.Write(finalEvents) } - return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, nil + return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs, clientDisconnect: cw.Disconnected()}, nil } if ev.err != nil { + if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity claude"); handled { + return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: disconnect}, nil + } if errors.Is(ev.err, bufio.ErrTooLong) { log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err) sendErrorEvent("response_too_large") @@ -3177,25 +3295,14 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context return nil, fmt.Errorf("stream read error: %w", ev.err) } - line := ev.line // Process the SSE line and convert it to Claude format - claudeEvents := processor.ProcessLine(strings.TrimRight(line, "\r\n")) - + claudeEvents := processor.ProcessLine(strings.TrimRight(ev.line, "\r\n")) if len(claudeEvents) > 0 { if firstTokenMs == nil { ms := int(time.Since(startTime).Milliseconds()) firstTokenMs = &ms } - - if _, writeErr := c.Writer.Write(claudeEvents);
writeErr != nil { - finalEvents, agUsage := processor.Finish() - if len(finalEvents) > 0 { - _, _ = c.Writer.Write(finalEvents) - } - sendErrorEvent("write_failed") - return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, writeErr - } - flusher.Flush() + cw.Write(claudeEvents) } case <-intervalCh: @@ -3203,13 +3310,15 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context if time.Since(lastRead) < streamInterval { continue } + if cw.Disconnected() { + log.Printf("Upstream timeout after client disconnect (antigravity claude), returning collected usage") + return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } log.Printf("Stream data interval timeout (antigravity)") - // 注意:此函数没有 account 上下文,无法调用 HandleStreamTimeout sendErrorEvent("stream_timeout") return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") } } - } // extractImageSize extracts the image_size parameter from a Gemini request @@ -3348,3 +3457,288 @@ func filterEmptyPartsFromGeminiRequest(body []byte) ([]byte, error) { payload["contents"] = filtered return json.Marshal(payload) } + +// ForwardUpstream proxies upstream Claude requests via base_url + /v1/messages with dual-header authentication +func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) { + startTime := time.Now() + sessionID := getSessionID(c) + prefix := logPrefix(sessionID, account.Name) + + // Load upstream configuration + baseURL := strings.TrimSpace(account.GetCredential("base_url")) + apiKey := strings.TrimSpace(account.GetCredential("api_key")) + if baseURL == "" || apiKey == "" { + return nil, fmt.Errorf("upstream account missing base_url or api_key") + } + baseURL = strings.TrimSuffix(baseURL, "/") + + // Parse the request for model info + var claudeReq antigravity.ClaudeRequest + if err := json.Unmarshal(body, &claudeReq); err != nil { + return nil,
fmt.Errorf("parse claude request: %w", err) + } + if strings.TrimSpace(claudeReq.Model) == "" { + return nil, fmt.Errorf("missing model") + } + originalModel := claudeReq.Model + billingModel := originalModel + + // Build the upstream request URL + upstreamURL := baseURL + "/v1/messages" + + // Create the request + req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("create upstream request: %w", err) + } + + // Set request headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + req.Header.Set("x-api-key", apiKey) // Claude API compatibility + + // Pass through Claude-related headers + if v := c.GetHeader("anthropic-version"); v != "" { + req.Header.Set("anthropic-version", v) + } + if v := c.GetHeader("anthropic-beta"); v != "" { + req.Header.Set("anthropic-beta", v) + } + + // Proxy URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + // Send the request + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + log.Printf("%s upstream request failed: %v", prefix, err) + return nil, fmt.Errorf("upstream request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + // Handle error responses + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + + // On 429, mark the account as rate limited + if resp.StatusCode == http.StatusTooManyRequests { + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, originalModel, 0, "", false) + } + + // Pass the upstream error through + c.Header("Content-Type", resp.Header.Get("Content-Type")) + c.Status(resp.StatusCode) + _, _ = c.Writer.Write(respBody) + + return &ForwardResult{ + Model: billingModel, + }, nil + } + + // Handle success responses (streaming / non-streaming) + var usage *ClaudeUsage + var firstTokenMs *int + var clientDisconnect bool + + if claudeReq.Stream { + // Streaming response: pass through + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") +
c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + c.Status(http.StatusOK) + + streamRes := s.streamUpstreamResponse(c, resp, startTime) + usage = streamRes.usage + firstTokenMs = streamRes.firstTokenMs + clientDisconnect = streamRes.clientDisconnect + } else { + // Non-streaming response: pass through directly + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read upstream response: %w", err) + } + + // Extract usage + usage = s.extractClaudeUsage(respBody) + + c.Header("Content-Type", resp.Header.Get("Content-Type")) + c.Status(http.StatusOK) + _, _ = c.Writer.Write(respBody) + } + + // Build the billing result + duration := time.Since(startTime) + log.Printf("%s status=success duration_ms=%d", prefix, duration.Milliseconds()) + + return &ForwardResult{ + Model: billingModel, + Stream: claudeReq.Stream, + Duration: duration, + FirstTokenMs: firstTokenMs, + ClientDisconnect: clientDisconnect, + Usage: ClaudeUsage{ + InputTokens: usage.InputTokens, + OutputTokens: usage.OutputTokens, + CacheReadInputTokens: usage.CacheReadInputTokens, + CacheCreationInputTokens: usage.CacheCreationInputTokens, + }, + }, nil +} + +// streamUpstreamResponse passes the upstream SSE stream through and extracts Claude usage +func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp *http.Response, startTime time.Time) *antigravityStreamResult { + usage := &ClaudeUsage{} + var firstTokenMs *int + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.settingService.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 64*1024), maxLineSize) + + type scanEvent struct { + line string + err error + } + events := make(chan scanEvent, 16) + done := make(chan struct{}) + sendEvent := func(ev scanEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false + } + } + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + go func()
{ + defer close(events) + for scanner.Scan() { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + if !sendEvent(scanEvent{line: scanner.Text()}) { + return + } + } + if err := scanner.Err(); err != nil { + _ = sendEvent(scanEvent{err: err}) + } + }() + defer close(done) + + streamInterval := time.Duration(0) + if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + var intervalTicker *time.Ticker + if streamInterval > 0 { + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() + } + var intervalCh <-chan time.Time + if intervalTicker != nil { + intervalCh = intervalTicker.C + } + + flusher, _ := c.Writer.(http.Flusher) + cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity upstream") + + for { + select { + case ev, ok := <-events: + if !ok { + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: cw.Disconnected()} + } + if ev.err != nil { + if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity upstream"); handled { + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: disconnect} + } + log.Printf("Stream read error (antigravity upstream): %v", ev.err) + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs} + } + + line := ev.line + + // Record time to first token + if firstTokenMs == nil && len(line) > 0 { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + + // Try to extract usage from message_delta or message_stop events + s.extractSSEUsage(line, usage) + + // Pass the line through + cw.Fprintf("%s\n", line) + + case <-intervalCh: + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) < streamInterval { + continue + } + if cw.Disconnected() { + log.Printf("Upstream timeout after client disconnect (antigravity upstream), returning collected
usage") + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true} + } + log.Printf("Stream data interval timeout (antigravity upstream)") + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs} + } + } +} + +// extractSSEUsage extracts Claude usage from an SSE data line (for the streaming pass-through path) +func (s *AntigravityGatewayService) extractSSEUsage(line string, usage *ClaudeUsage) { + if !strings.HasPrefix(line, "data: ") { + return + } + dataStr := strings.TrimPrefix(line, "data: ") + var event map[string]any + if json.Unmarshal([]byte(dataStr), &event) != nil { + return + } + u, ok := event["usage"].(map[string]any) + if !ok { + return + } + if v, ok := u["input_tokens"].(float64); ok && int(v) > 0 { + usage.InputTokens = int(v) + } + if v, ok := u["output_tokens"].(float64); ok && int(v) > 0 { + usage.OutputTokens = int(v) + } + if v, ok := u["cache_read_input_tokens"].(float64); ok && int(v) > 0 { + usage.CacheReadInputTokens = int(v) + } + if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 { + usage.CacheCreationInputTokens = int(v) + } +} + +// extractClaudeUsage extracts usage from a non-streaming Claude response +func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage { + usage := &ClaudeUsage{} + var resp map[string]any + if json.Unmarshal(body, &resp) != nil { + return usage + } + if u, ok := resp["usage"].(map[string]any); ok { + if v, ok := u["input_tokens"].(float64); ok { + usage.InputTokens = int(v) + } + if v, ok := u["output_tokens"].(float64); ok { + usage.OutputTokens = int(v) + } + if v, ok := u["cache_read_input_tokens"].(float64); ok { + usage.CacheReadInputTokens = int(v) + } + if v, ok := u["cache_creation_input_tokens"].(float64); ok { + usage.CacheCreationInputTokens = int(v) + } + } + return usage +} diff --git a/backend/internal/service/antigravity_gateway_service_test.go b/backend/internal/service/antigravity_gateway_service_test.go index 91cefc28..a6a349c1 100644 ---
a/backend/internal/service/antigravity_gateway_service_test.go +++ b/backend/internal/service/antigravity_gateway_service_test.go @@ -4,17 +4,42 @@ import ( "bytes" "context" "encoding/json" + "errors" + "fmt" "io" "net/http" "net/http/httptest" "testing" "time" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) + +// antigravityFailingWriter is a gin.ResponseWriter that simulates a client disconnect +type antigravityFailingWriter struct { + gin.ResponseWriter + failAfter int // number of writes allowed to succeed; all later writes return an error + writes int +} + +func (w *antigravityFailingWriter) Write(p []byte) (int, error) { + if w.writes >= w.failAfter { + return 0, errors.New("write failed: client disconnected") + } + w.writes++ + return w.ResponseWriter.Write(p) +} + +// newAntigravityTestService builds an AntigravityGatewayService for streaming tests +func newAntigravityTestService(cfg *config.Config) *AntigravityGatewayService { + return &AntigravityGatewayService{ + settingService: &SettingService{cfg: cfg}, + } +} + func TestStripSignatureSensitiveBlocksFromClaudeRequest(t *testing.T) { req := &antigravity.ClaudeRequest{ Model: "claude-sonnet-4-5", @@ -337,8 +362,8 @@ func TestAntigravityGatewayService_Forward_StickySessionForceCacheBilling(t *tes require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch") } -// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling -// 验证:ForwardGemini 粘性会话切换时,UpstreamFailoverError.ForceCacheBilling 应为 true +// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling verifies +// that ForwardGemini sets ForceCacheBilling=true for sticky session switch.
func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(t *testing.T) { gin.SetMode(gin.TestMode) writer := httptest.NewRecorder() @@ -391,3 +416,438 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling( require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode) require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch") } + +// --- Streaming happy-path tests --- + +// TestStreamUpstreamResponse_NormalComplete +// Verifies: on normal streaming completion, data is passed through, usage is collected, and clientDisconnect=false +func TestStreamUpstreamResponse_NormalComplete(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + defer func() { _ = pw.Close() }() + fmt.Fprintln(pw, `event: message_start`) + fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":10}}}`) + fmt.Fprintln(pw, "") + fmt.Fprintln(pw, `event: content_block_delta`) + fmt.Fprintln(pw, `data: {"type":"content_block_delta","delta":{"text":"hello"}}`) + fmt.Fprintln(pw, "") + fmt.Fprintln(pw, `event: message_delta`) + fmt.Fprintln(pw, `data: {"type":"message_delta","usage":{"output_tokens":5}}`) + fmt.Fprintln(pw, "") + }() + + result := svc.streamUpstreamResponse(c, resp, time.Now()) + _ = pr.Close() + + require.NotNil(t, result) + require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect") + require.NotNil(t, result.usage) + require.Equal(t, 5, result.usage.OutputTokens, "should collect output_tokens from message_delta") + require.NotNil(t, result.firstTokenMs, "should record first token time") + + // Verify the data was passed through to the client + body :=
rec.Body.String() + require.Contains(t, body, "event: message_start") + require.Contains(t, body, "content_block_delta") + require.Contains(t, body, "message_delta") +} + +// TestHandleGeminiStreamingResponse_NormalComplete +// Verifies: normal Gemini streaming pass-through delivers data and collects usage correctly +func TestHandleGeminiStreamingResponse_NormalComplete(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + defer func() { _ = pw.Close() }() + // First chunk (partial content) + fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":3}}`) + fmt.Fprintln(pw, "") + // Second chunk (final content plus complete usage) + fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":" world"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":8,"cachedContentTokenCount":2}}`) + fmt.Fprintln(pw, "") + }() + + result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now()) + _ = pr.Close() + + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect") + require.NotNil(t, result.usage) + // Gemini usage: promptTokenCount=10, candidatesTokenCount=8, cachedContentTokenCount=2 + // → InputTokens=10-2=8, OutputTokens=8, CacheReadInputTokens=2 + require.Equal(t, 8, result.usage.InputTokens) + require.Equal(t, 8, result.usage.OutputTokens) + require.Equal(t, 2, result.usage.CacheReadInputTokens) + require.NotNil(t, result.firstTokenMs, "should record first token time") + + // Verify the data was passed through to the client + body := rec.Body.String() + require.Contains(t, body, "Hello") +
require.Contains(t, body, "world") + // Should not contain an error event + require.NotContains(t, body, "event: error") +} + +// TestHandleClaudeStreamingResponse_NormalComplete +// Verifies: normal Claude streaming (Gemini→Claude conversion) converts and emits data correctly +func TestHandleClaudeStreamingResponse_NormalComplete(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + defer func() { _ = pw.Close() }() + // v1internal wrapper format: Gemini data is nested under the "response" field + // ProcessLine first tries to unmarshal as V1InternalResponse; the bare format would leave Response.UsageMetadata empty + fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"Hi there"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":3}}}`) + fmt.Fprintln(pw, "") + }() + + result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5") + _ = pr.Close() + + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect") + require.NotNil(t, result.usage) + // Usage after the Gemini→Claude conversion: promptTokenCount=5→InputTokens=5, candidatesTokenCount=3→OutputTokens=3 + require.Equal(t, 5, result.usage.InputTokens) + require.Equal(t, 3, result.usage.OutputTokens) + require.NotNil(t, result.firstTokenMs, "should record first token time") + + // Verify the output is Claude SSE format (the processor converts it) + body := rec.Body.String() + require.Contains(t, body, "event: message_start", "should contain Claude message_start event") + require.Contains(t, body, "event: message_stop", "should contain Claude message_stop event") + // Should not contain an error event + require.NotContains(t, body, "event: error") +} + +// --- Streaming client-disconnect detection tests --- + +//
TestStreamUpstreamResponse_ClientDisconnectDrainsUsage +// Verifies: after a client write fails, streamUpstreamResponse keeps reading upstream to collect usage +func TestStreamUpstreamResponse_ClientDisconnectDrainsUsage(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + defer func() { _ = pw.Close() }() + fmt.Fprintln(pw, `event: message_start`) + fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":10}}}`) + fmt.Fprintln(pw, "") + fmt.Fprintln(pw, `event: message_delta`) + fmt.Fprintln(pw, `data: {"type":"message_delta","usage":{"output_tokens":20}}`) + fmt.Fprintln(pw, "") + }() + + result := svc.streamUpstreamResponse(c, resp, time.Now()) + _ = pr.Close() + + require.NotNil(t, result) + require.True(t, result.clientDisconnect) + require.NotNil(t, result.usage) + require.Equal(t, 20, result.usage.OutputTokens) +} + +// TestStreamUpstreamResponse_ContextCanceled +// Verifies: on context cancellation, usage is returned and clientDisconnect is set +func TestStreamUpstreamResponse_ContextCanceled(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx) + + resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}} + + result := svc.streamUpstreamResponse(c, resp, time.Now()) + + require.NotNil(t, result) +
require.True(t, result.clientDisconnect) + require.NotContains(t, rec.Body.String(), "event: error") +} + +// TestStreamUpstreamResponse_Timeout +// Verifies: on upstream timeout, the usage collected so far is returned +func TestStreamUpstreamResponse_Timeout(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{StreamDataIntervalTimeout: 1, MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + result := svc.streamUpstreamResponse(c, resp, time.Now()) + _ = pw.Close() + _ = pr.Close() + + require.NotNil(t, result) + require.False(t, result.clientDisconnect) +} + +// TestStreamUpstreamResponse_TimeoutAfterClientDisconnect +// Verifies: when upstream times out after a client disconnect, usage is returned and clientDisconnect is set +func TestStreamUpstreamResponse_TimeoutAfterClientDisconnect(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{StreamDataIntervalTimeout: 1, MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":5}}}`) + fmt.Fprintln(pw, "") + // Leave pw open → wait for the timeout + }() + + result := svc.streamUpstreamResponse(c, resp, time.Now()) + _ = pw.Close() + _ = pr.Close() + + require.NotNil(t, result) + require.True(t, result.clientDisconnect) +} + +// TestHandleGeminiStreamingResponse_ClientDisconnect +// Verifies: Gemini streaming keeps draining upstream after a client disconnect +func TestHandleGeminiStreamingResponse_ClientDisconnect(t
*testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + defer func() { _ = pw.Close() }() + fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"hi"}]}}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":10}}`) + fmt.Fprintln(pw, "") + }() + + result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now()) + _ = pr.Close() + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.clientDisconnect) + require.NotContains(t, rec.Body.String(), "write_failed") +} + +// TestHandleGeminiStreamingResponse_ContextCanceled +// Verifies: no error event is injected when the context is canceled +func TestHandleGeminiStreamingResponse_ContextCanceled(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx) + + resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}} + + result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now()) + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.clientDisconnect) + require.NotContains(t, rec.Body.String(), "event: error") +} + +// TestHandleClaudeStreamingResponse_ClientDisconnect +// Verifies: Claude streaming keeps draining upstream after a client disconnect +func TestHandleClaudeStreamingResponse_ClientDisconnect(t *testing.T) { +
gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + defer func() { _ = pw.Close() }() + // v1internal wrapper format + fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"hello"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":8,"candidatesTokenCount":15}}}`) + fmt.Fprintln(pw, "") + }() + + result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5") + _ = pr.Close() + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.clientDisconnect) +} + +// TestHandleClaudeStreamingResponse_ContextCanceled +// Verifies: no error event is injected when the context is canceled +func TestHandleClaudeStreamingResponse_ContextCanceled(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx) + + resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}} + + result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5") + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.clientDisconnect) + require.NotContains(t, rec.Body.String(), "event: error") +} + +// TestExtractSSEUsage verifies that extractSSEUsage correctly extracts usage from SSE data lines +func TestExtractSSEUsage(t *testing.T) { + svc := &AntigravityGatewayService{}
+ tests := []struct { + name string + line string + expected ClaudeUsage + }{ + { + name: "message_delta with output_tokens", + line: `data: {"type":"message_delta","usage":{"output_tokens":42}}`, + expected: ClaudeUsage{OutputTokens: 42}, + }, + { + name: "non-data line ignored", + line: `event: message_start`, + expected: ClaudeUsage{}, + }, + { + name: "top-level usage with all fields", + line: `data: {"usage":{"input_tokens":10,"output_tokens":20,"cache_read_input_tokens":5,"cache_creation_input_tokens":3}}`, + expected: ClaudeUsage{InputTokens: 10, OutputTokens: 20, CacheReadInputTokens: 5, CacheCreationInputTokens: 3}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + usage := &ClaudeUsage{} + svc.extractSSEUsage(tt.line, usage) + require.Equal(t, tt.expected, *usage) + }) + } +} + +// TestAntigravityClientWriter verifies antigravityClientWriter's disconnect detection +func TestAntigravityClientWriter(t *testing.T) { + t.Run("normal write succeeds", func(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + flusher, _ := c.Writer.(http.Flusher) + cw := newAntigravityClientWriter(c.Writer, flusher, "test") + + ok := cw.Write([]byte("hello")) + require.True(t, ok) + require.False(t, cw.Disconnected()) + require.Contains(t, rec.Body.String(), "hello") + }) + + t.Run("write failure marks disconnected", func(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + fw := &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + flusher, _ := c.Writer.(http.Flusher) + cw := newAntigravityClientWriter(fw, flusher, "test") + + ok := cw.Write([]byte("hello")) + require.False(t, ok) + require.True(t, cw.Disconnected()) + }) + + t.Run("subsequent writes are no-op", func(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + fw := &antigravityFailingWriter{ResponseWriter: c.Writer,
failAfter: 0} + flusher, _ := c.Writer.(http.Flusher) + cw := newAntigravityClientWriter(fw, flusher, "test") + + cw.Write([]byte("first")) + ok := cw.Fprintf("second %d", 2) + require.False(t, ok) + require.True(t, cw.Disconnected()) + }) +} diff --git a/backend/internal/service/antigravity_quota_scope.go b/backend/internal/service/antigravity_quota_scope.go index 43ac6c2f..e181e7f8 100644 --- a/backend/internal/service/antigravity_quota_scope.go +++ b/backend/internal/service/antigravity_quota_scope.go @@ -2,63 +2,23 @@ package service import ( "context" - "slices" "strings" "time" ) -const antigravityQuotaScopesKey = "antigravity_quota_scopes" - -// AntigravityQuotaScope 表示 Antigravity 的配额域 -type AntigravityQuotaScope string - -const ( - AntigravityQuotaScopeClaude AntigravityQuotaScope = "claude" - AntigravityQuotaScopeGeminiText AntigravityQuotaScope = "gemini_text" - AntigravityQuotaScopeGeminiImage AntigravityQuotaScope = "gemini_image" -) - -// IsScopeSupported 检查给定的 scope 是否在分组支持的 scope 列表中 -func IsScopeSupported(supportedScopes []string, scope AntigravityQuotaScope) bool { - if len(supportedScopes) == 0 { - // 未配置时默认全部支持 - return true - } - supported := slices.Contains(supportedScopes, string(scope)) - return supported -} - -// ResolveAntigravityQuotaScope 根据模型名称解析配额域(导出版本) -func ResolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) { - return resolveAntigravityQuotaScope(requestedModel) -} - -// resolveAntigravityQuotaScope 根据模型名称解析配额域 -func resolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) { - model := normalizeAntigravityModelName(requestedModel) - if model == "" { - return "", false - } - switch { - case strings.HasPrefix(model, "claude-"): - return AntigravityQuotaScopeClaude, true - case strings.HasPrefix(model, "gemini-"): - if isImageGenerationModel(model) { - return AntigravityQuotaScopeGeminiImage, true - } - return AntigravityQuotaScopeGeminiText, true - default: - return "", false - } 
-} - func normalizeAntigravityModelName(model string) string { normalized := strings.ToLower(strings.TrimSpace(model)) normalized = strings.TrimPrefix(normalized, "models/") return normalized } -// IsSchedulableForModel 结合 Antigravity 配额域限流判断是否可调度。 +// resolveAntigravityModelKey 根据请求的模型名解析限流 key +// 返回空字符串表示无法解析 +func resolveAntigravityModelKey(requestedModel string) string { + return normalizeAntigravityModelName(requestedModel) +} + +// IsSchedulableForModel 结合模型级限流判断是否可调度。 // 保持旧签名以兼容既有调用方;默认使用 context.Background()。 func (a *Account) IsSchedulableForModel(requestedModel string) bool { return a.IsSchedulableForModelWithContext(context.Background(), requestedModel) @@ -74,107 +34,20 @@ func (a *Account) IsSchedulableForModelWithContext(ctx context.Context, requeste if a.isModelRateLimitedWithContext(ctx, requestedModel) { return false } - if a.Platform != PlatformAntigravity { - return true - } - scope, ok := resolveAntigravityQuotaScope(requestedModel) - if !ok { - return true - } - resetAt := a.antigravityQuotaScopeResetAt(scope) - if resetAt == nil { - return true - } - now := time.Now() - return !now.Before(*resetAt) + return true } -func (a *Account) antigravityQuotaScopeResetAt(scope AntigravityQuotaScope) *time.Time { - if a == nil || a.Extra == nil || scope == "" { - return nil - } - rawScopes, ok := a.Extra[antigravityQuotaScopesKey].(map[string]any) - if !ok { - return nil - } - rawScope, ok := rawScopes[string(scope)].(map[string]any) - if !ok { - return nil - } - resetAtRaw, ok := rawScope["rate_limit_reset_at"].(string) - if !ok || strings.TrimSpace(resetAtRaw) == "" { - return nil - } - resetAt, err := time.Parse(time.RFC3339, resetAtRaw) - if err != nil { - return nil - } - return &resetAt -} - -var antigravityAllScopes = []AntigravityQuotaScope{ - AntigravityQuotaScopeClaude, - AntigravityQuotaScopeGeminiText, - AntigravityQuotaScopeGeminiImage, -} - -func (a *Account) GetAntigravityScopeRateLimits() map[string]int64 { - if a == nil || a.Platform 
!= PlatformAntigravity { - return nil - } - now := time.Now() - result := make(map[string]int64) - for _, scope := range antigravityAllScopes { - resetAt := a.antigravityQuotaScopeResetAt(scope) - if resetAt != nil && now.Before(*resetAt) { - remainingSec := int64(time.Until(*resetAt).Seconds()) - if remainingSec > 0 { - result[string(scope)] = remainingSec - } - } - } - if len(result) == 0 { - return nil - } - return result -} - -// GetQuotaScopeRateLimitRemainingTime 获取模型域限流剩余时间 -// 返回 0 表示未限流或已过期 -func (a *Account) GetQuotaScopeRateLimitRemainingTime(requestedModel string) time.Duration { - if a == nil || a.Platform != PlatformAntigravity { - return 0 - } - scope, ok := resolveAntigravityQuotaScope(requestedModel) - if !ok { - return 0 - } - resetAt := a.antigravityQuotaScopeResetAt(scope) - if resetAt == nil { - return 0 - } - if remaining := time.Until(*resetAt); remaining > 0 { - return remaining - } - return 0 -} - -// GetRateLimitRemainingTime 获取限流剩余时间(模型限流和模型域限流取最大值) +// GetRateLimitRemainingTime 获取限流剩余时间(模型级限流) // 返回 0 表示未限流或已过期 func (a *Account) GetRateLimitRemainingTime(requestedModel string) time.Duration { return a.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel) } -// GetRateLimitRemainingTimeWithContext 获取限流剩余时间(模型限流和模型域限流取最大值) +// GetRateLimitRemainingTimeWithContext 获取限流剩余时间(模型级限流) // 返回 0 表示未限流或已过期 func (a *Account) GetRateLimitRemainingTimeWithContext(ctx context.Context, requestedModel string) time.Duration { if a == nil { return 0 } - modelRemaining := a.GetModelRateLimitRemainingTimeWithContext(ctx, requestedModel) - scopeRemaining := a.GetQuotaScopeRateLimitRemainingTime(requestedModel) - if modelRemaining > scopeRemaining { - return modelRemaining - } - return scopeRemaining + return a.GetModelRateLimitRemainingTimeWithContext(ctx, requestedModel) } diff --git a/backend/internal/service/antigravity_rate_limit_test.go b/backend/internal/service/antigravity_rate_limit_test.go index cd2a7a4a..59cc9331 100644 --- 
a/backend/internal/service/antigravity_rate_limit_test.go +++ b/backend/internal/service/antigravity_rate_limit_test.go @@ -59,12 +59,6 @@ func (s *stubAntigravityUpstream) DoWithTLS(req *http.Request, proxyURL string, return s.Do(req, proxyURL, accountID, accountConcurrency) } -type scopeLimitCall struct { - accountID int64 - scope AntigravityQuotaScope - resetAt time.Time -} - type rateLimitCall struct { accountID int64 resetAt time.Time @@ -78,16 +72,10 @@ type modelRateLimitCall struct { type stubAntigravityAccountRepo struct { AccountRepository - scopeCalls []scopeLimitCall rateCalls []rateLimitCall modelRateLimitCalls []modelRateLimitCall } -func (s *stubAntigravityAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error { - s.scopeCalls = append(s.scopeCalls, scopeLimitCall{accountID: id, scope: scope, resetAt: resetAt}) - return nil -} - func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { s.rateCalls = append(s.rateCalls, rateLimitCall{accountID: id, resetAt: resetAt}) return nil @@ -131,10 +119,9 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) { accessToken: "token", action: "generateContent", body: []byte(`{"input":"test"}`), - quotaScope: AntigravityQuotaScopeClaude, httpUpstream: upstream, requestedModel: "claude-sonnet-4-5", - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { handleErrorCalled = true return nil }, @@ -155,23 +142,6 @@ func 
TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) { require.Equal(t, base2, available[0]) } -func TestAntigravityHandleUpstreamError_UsesScopeLimit(t *testing.T) { - // 分区限流始终开启,不再支持通过环境变量关闭 - repo := &stubAntigravityAccountRepo{} - svc := &AntigravityGatewayService{accountRepo: repo} - account := &Account{ID: 9, Name: "acc-9", Platform: PlatformAntigravity} - - body := buildGeminiRateLimitBody("3s") - svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false) - - require.Len(t, repo.scopeCalls, 1) - require.Empty(t, repo.rateCalls) - call := repo.scopeCalls[0] - require.Equal(t, account.ID, call.accountID) - require.Equal(t, AntigravityQuotaScopeClaude, call.scope) - require.WithinDuration(t, time.Now().Add(3*time.Second), call.resetAt, 2*time.Second) -} - // TestHandleUpstreamError_429_ModelRateLimit 测试 429 模型限流场景 func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) { repo := &stubAntigravityAccountRepo{} @@ -189,7 +159,7 @@ func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) { } }`) - result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false) + result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, "claude-sonnet-4-5", 0, "", false) // 应该触发模型限流 require.NotNil(t, result) @@ -200,22 +170,22 @@ func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) { require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey) } -// TestHandleUpstreamError_429_NonModelRateLimit 测试 429 非模型限流场景(走 scope 限流) +// TestHandleUpstreamError_429_NonModelRateLimit 测试 429 非模型限流场景(走模型级限流兜底) func TestHandleUpstreamError_429_NonModelRateLimit(t *testing.T) { repo := &stubAntigravityAccountRepo{} svc := &AntigravityGatewayService{accountRepo: repo} account := 
&Account{ID: 2, Name: "acc-2", Platform: PlatformAntigravity} - // 429 + 普通限流响应(无 RATE_LIMIT_EXCEEDED reason)→ scope 限流 + // 429 + 普通限流响应(无 RATE_LIMIT_EXCEEDED reason)→ 走模型级限流兜底 body := buildGeminiRateLimitBody("5s") - result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false) + result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, "claude-sonnet-4-5", 0, "", false) - // 不应该触发模型限流,应该走 scope 限流 + // handleModelRateLimit 不会处理(因为没有 RATE_LIMIT_EXCEEDED), + // 但 429 兜底逻辑会使用 requestedModel 设置模型级限流 require.Nil(t, result) - require.Empty(t, repo.modelRateLimitCalls) - require.Len(t, repo.scopeCalls, 1) - require.Equal(t, AntigravityQuotaScopeClaude, repo.scopeCalls[0].scope) + require.Len(t, repo.modelRateLimitCalls, 1) + require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey) } // TestHandleUpstreamError_503_ModelRateLimit 测试 503 模型限流场景 @@ -235,7 +205,7 @@ func TestHandleUpstreamError_503_ModelRateLimit(t *testing.T) { } }`) - result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false) + result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false) // 应该触发模型限流 require.NotNil(t, result) @@ -263,12 +233,11 @@ func TestHandleUpstreamError_503_NonModelRateLimit(t *testing.T) { } }`) - result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false) + result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false) // 503 非模型限流不应该做任何处理 require.Nil(t, result) 
require.Empty(t, repo.modelRateLimitCalls, "503 non-model rate limit should not trigger model rate limit") - require.Empty(t, repo.scopeCalls, "503 non-model rate limit should not trigger scope rate limit") require.Empty(t, repo.rateCalls, "503 non-model rate limit should not trigger account rate limit") } @@ -281,12 +250,11 @@ func TestHandleUpstreamError_503_EmptyBody(t *testing.T) { // 503 + 空响应体 → 不做任何处理 body := []byte(`{}`) - result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false) + result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false) // 503 空响应不应该做任何处理 require.Nil(t, result) require.Empty(t, repo.modelRateLimitCalls) - require.Empty(t, repo.scopeCalls) require.Empty(t, repo.rateCalls) } @@ -307,15 +275,7 @@ func TestAccountIsSchedulableForModel_AntigravityRateLimits(t *testing.T) { require.False(t, account.IsSchedulableForModel("gemini-3-flash")) account.RateLimitResetAt = nil - account.Extra = map[string]any{ - antigravityQuotaScopesKey: map[string]any{ - "claude": map[string]any{ - "rate_limit_reset_at": future.Format(time.RFC3339), - }, - }, - } - - require.False(t, account.IsSchedulableForModel("claude-sonnet-4-5")) + require.True(t, account.IsSchedulableForModel("claude-sonnet-4-5")) require.True(t, account.IsSchedulableForModel("gemini-3-flash")) } @@ -635,6 +595,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) { }`, expectedShouldRetry: false, expectedShouldRateLimit: true, + minWait: 7 * time.Second, modelName: "gemini-pro", }, { @@ -652,6 +613,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) { }`, expectedShouldRetry: false, expectedShouldRateLimit: true, + minWait: 39 * time.Second, modelName: "gemini-3-pro-high", }, { @@ -669,6 +631,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) { 
}`, expectedShouldRetry: false, expectedShouldRateLimit: true, + minWait: 30 * time.Second, modelName: "gemini-2.5-flash", }, { @@ -686,6 +649,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) { }`, expectedShouldRetry: false, expectedShouldRateLimit: true, + minWait: 30 * time.Second, modelName: "claude-sonnet-4-5", }, } @@ -704,6 +668,11 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) { t.Errorf("wait = %v, want >= %v", wait, tt.minWait) } } + if shouldRateLimit && tt.minWait > 0 { + if wait < tt.minWait { + t.Errorf("rate limit wait = %v, want >= %v", wait, tt.minWait) + } + } if (shouldRetry || shouldRateLimit) && model != tt.modelName { t.Errorf("modelName = %q, want %q", model, tt.modelName) } @@ -832,7 +801,7 @@ func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRateLimited(t *testing.T) { requestedModel: "claude-sonnet-4-5", httpUpstream: upstream, isStickySession: true, - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, }) @@ -875,7 +844,7 @@ func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemainingLong(t *testing.T) { requestedModel: "claude-sonnet-4-5", httpUpstream: upstream, isStickySession: true, - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID 
int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, }) diff --git a/backend/internal/service/antigravity_smart_retry_test.go b/backend/internal/service/antigravity_smart_retry_test.go index 999b408f..a7e0d296 100644 --- a/backend/internal/service/antigravity_smart_retry_test.go +++ b/backend/internal/service/antigravity_smart_retry_test.go @@ -75,7 +75,7 @@ func TestHandleSmartRetry_URLLevelRateLimit(t *testing.T) { accessToken: "token", action: "generateContent", body: []byte(`{"input":"test"}`), - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -127,7 +127,7 @@ func TestHandleSmartRetry_LongDelay_ReturnsSwitchError(t *testing.T) { body: []byte(`{"input":"test"}`), accountRepo: repo, isStickySession: true, - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -194,7 +194,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetrySuccess(t *testing.T) { action: "generateContent", body: []byte(`{"input":"test"}`), httpUpstream: upstream, - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body 
[]byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -269,7 +269,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test httpUpstream: upstream, accountRepo: repo, isStickySession: false, - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -331,7 +331,7 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi body: []byte(`{"input":"test"}`), accountRepo: repo, isStickySession: true, - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -387,7 +387,7 @@ func TestHandleSmartRetry_NonAntigravityAccount_ContinuesDefaultLogic(t *testing accessToken: "token", action: "generateContent", body: []byte(`{"input":"test"}`), - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers 
http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -436,7 +436,7 @@ func TestHandleSmartRetry_NonModelRateLimit_ContinuesDefaultLogic(t *testing.T) accessToken: "token", action: "generateContent", body: []byte(`{"input":"test"}`), - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -487,7 +487,7 @@ func TestHandleSmartRetry_ExactlyAtThreshold_ReturnsSwitchError(t *testing.T) { action: "generateContent", body: []byte(`{"input":"test"}`), accountRepo: repo, - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -548,7 +548,7 @@ func TestAntigravityRetryLoop_HandleSmartRetry_SwitchError_Propagates(t *testing httpUpstream: upstream, accountRepo: repo, isStickySession: true, - handleError: func(ctx context.Context, prefix string, account *Account, 
statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, }) @@ -604,7 +604,7 @@ func TestHandleSmartRetry_NetworkError_ExhaustsRetry(t *testing.T) { body: []byte(`{"input":"test"}`), httpUpstream: upstream, accountRepo: repo, - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -662,7 +662,7 @@ func TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit(t *testing.T) { body: []byte(`{"input":"test"}`), accountRepo: repo, isStickySession: true, - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -754,7 +754,7 @@ func TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_ClearsSession(t * isStickySession: true, groupID: 42, sessionHash: "sticky-hash-abc", - handleError: func(ctx context.Context, prefix string, account *Account, 
statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -842,7 +842,7 @@ func TestHandleSmartRetry_ShortDelay_NonStickySession_FailedRetry_NoDeleteSessio isStickySession: false, groupID: 42, sessionHash: "", // 非粘性会话,sessionHash 为空 - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -918,7 +918,7 @@ func TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_NilCache_NoPanic( isStickySession: true, groupID: 42, sessionHash: "sticky-hash-nil-cache", - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -983,7 +983,7 @@ func TestHandleSmartRetry_ShortDelay_StickySession_SuccessRetry_NoDeleteSession( isStickySession: true, groupID: 42, sessionHash: "sticky-hash-success", - handleError: func(ctx context.Context, prefix string, 
account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -1043,7 +1043,7 @@ func TestHandleSmartRetry_LongDelay_StickySession_NoDeleteInHandleSmartRetry(t * isStickySession: true, groupID: 42, sessionHash: "sticky-hash-long-delay", - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -1108,7 +1108,7 @@ func TestHandleSmartRetry_ShortDelay_NetworkError_StickySession_ClearsSession(t isStickySession: true, groupID: 99, sessionHash: "sticky-net-error", - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -1188,7 +1188,7 @@ func TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession isStickySession: true, groupID: 77, sessionHash: "sticky-503-short", - handleError: func(ctx context.Context, prefix 
string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
+		handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
 			return nil
 		},
 	}
@@ -1278,7 +1278,7 @@ func TestAntigravityRetryLoop_SmartRetryFailed_StickySession_SwitchErrorPropagat
 		isStickySession: true,
 		groupID:         55,
 		sessionHash:     "sticky-loop-test",
-		handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
+		handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
 			return nil
 		},
 	})
@@ -1296,4 +1296,4 @@ func TestAntigravityRetryLoop_SmartRetryFailed_StickySession_SwitchErrorPropagat
 	require.Len(t, cache.deleteCalls, 1, "should clear sticky session in handleSmartRetry")
 	require.Equal(t, int64(55), cache.deleteCalls[0].groupID)
 	require.Equal(t, "sticky-loop-test", cache.deleteCalls[0].sessionHash)
-}
\ No newline at end of file
+}
diff --git a/backend/internal/service/digest_session_store.go b/backend/internal/service/digest_session_store.go
new file mode 100644
index 00000000..3ac08936
--- /dev/null
+++ b/backend/internal/service/digest_session_store.go
@@ -0,0 +1,69 @@
+package service
+
+import (
+	"strconv"
+	"strings"
+	"time"
+
+	gocache "github.com/patrickmn/go-cache"
+)
+
+// digestSessionTTL is the default TTL for digest sessions
+const digestSessionTTL = 5 * time.Minute
+
+// sessionEntry is a flat-cache entry
+type sessionEntry struct {
+	uuid      string
+	accountID int64
+}
+
+// DigestSessionStore is an in-memory digest-session store (flat-cache implementation)
+// key: "{groupID}:{prefixHash}|{digestChain}" → *sessionEntry
+type DigestSessionStore struct {
+	cache *gocache.Cache
+}
+
+// NewDigestSessionStore creates an in-memory digest-session store
+func NewDigestSessionStore() *DigestSessionStore {
+	return &DigestSessionStore{
+		cache: gocache.New(digestSessionTTL, time.Minute),
+	}
+}
+
+// Save stores a digest session. oldDigestChain is the matchedChain returned by Find, used to delete the stale old key.
+func (s *DigestSessionStore) Save(groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) {
+	if digestChain == "" {
+		return
+	}
+	ns := buildNS(groupID, prefixHash)
+	s.cache.Set(ns+digestChain, &sessionEntry{uuid: uuid, accountID: accountID}, gocache.DefaultExpiration)
+	if oldDigestChain != "" && oldDigestChain != digestChain {
+		s.cache.Delete(ns + oldDigestChain)
+	}
+}
+
+// Find looks up a digest session by truncating the full chain segment by segment, returning the longest match and its matchedChain.
+func (s *DigestSessionStore) Find(groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) {
+	if digestChain == "" {
+		return "", 0, "", false
+	}
+	ns := buildNS(groupID, prefixHash)
+	chain := digestChain
+	for {
+		if val, ok := s.cache.Get(ns + chain); ok {
+			if e, ok := val.(*sessionEntry); ok {
+				return e.uuid, e.accountID, chain, true
+			}
+		}
+		i := strings.LastIndex(chain, "-")
+		if i < 0 {
+			return "", 0, "", false
+		}
+		chain = chain[:i]
+	}
+}
+
+// buildNS builds the namespace prefix
+func buildNS(groupID int64, prefixHash string) string {
+	return strconv.FormatInt(groupID, 10) + ":" + prefixHash + "|"
+}
diff --git a/backend/internal/service/digest_session_store_test.go b/backend/internal/service/digest_session_store_test.go
new file mode 100644
index 00000000..e505bf30
--- /dev/null
+++ b/backend/internal/service/digest_session_store_test.go
@@ -0,0 +1,312 @@
+//go:build unit
+
+package service
+
+import (
+	"fmt"
+	"sync"
+	"testing"
+	"time"
+
+	gocache "github.com/patrickmn/go-cache"
+	"github.com/stretchr/testify/assert"
+	
"github.com/stretchr/testify/require" +) + +func TestDigestSessionStore_SaveAndFind(t *testing.T) { + store := NewDigestSessionStore() + + store.Save(1, "prefix", "s:a1-u:b2-m:c3", "uuid-1", 100, "") + + uuid, accountID, _, found := store.Find(1, "prefix", "s:a1-u:b2-m:c3") + require.True(t, found) + assert.Equal(t, "uuid-1", uuid) + assert.Equal(t, int64(100), accountID) +} + +func TestDigestSessionStore_PrefixMatch(t *testing.T) { + store := NewDigestSessionStore() + + // 保存短链 + store.Save(1, "prefix", "u:a-m:b", "uuid-short", 10, "") + + // 用长链查找,应前缀匹配到短链 + uuid, accountID, matchedChain, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d") + require.True(t, found) + assert.Equal(t, "uuid-short", uuid) + assert.Equal(t, int64(10), accountID) + assert.Equal(t, "u:a-m:b", matchedChain) +} + +func TestDigestSessionStore_LongestPrefixMatch(t *testing.T) { + store := NewDigestSessionStore() + + store.Save(1, "prefix", "u:a", "uuid-1", 1, "") + store.Save(1, "prefix", "u:a-m:b", "uuid-2", 2, "") + store.Save(1, "prefix", "u:a-m:b-u:c", "uuid-3", 3, "") + + // 应匹配最深的 "u:a-m:b-u:c"(从完整 chain 逐段截断,先命中最长的) + uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d-u:e") + require.True(t, found) + assert.Equal(t, "uuid-3", uuid) + assert.Equal(t, int64(3), accountID) + + // 查找中等长度,应匹配到 "u:a-m:b" + uuid, accountID, _, found = store.Find(1, "prefix", "u:a-m:b-u:x") + require.True(t, found) + assert.Equal(t, "uuid-2", uuid) + assert.Equal(t, int64(2), accountID) +} + +func TestDigestSessionStore_SaveDeletesOldChain(t *testing.T) { + store := NewDigestSessionStore() + + // 第一轮:保存 "u:a-m:b" + store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "") + + // 第二轮:同一 uuid 保存更长的链,传入旧 chain + store.Save(1, "prefix", "u:a-m:b-u:c-m:d", "uuid-1", 100, "u:a-m:b") + + // 旧链 "u:a-m:b" 应已被删除 + _, _, _, found := store.Find(1, "prefix", "u:a-m:b") + assert.False(t, found, "old chain should be deleted") + + // 新链应能找到 + uuid, accountID, _, found := store.Find(1, "prefix", 
 "u:a-m:b-u:c-m:d")
+	require.True(t, found)
+	assert.Equal(t, "uuid-1", uuid)
+	assert.Equal(t, int64(100), accountID)
+}
+
+func TestDigestSessionStore_DifferentSessionsNoInterference(t *testing.T) {
+	store := NewDigestSessionStore()
+
+	// Same system prompt, different user prompts
+	store.Save(1, "prefix", "s:sys-u:user1", "uuid-1", 100, "")
+	store.Save(1, "prefix", "s:sys-u:user2", "uuid-2", 200, "")
+
+	uuid, accountID, _, found := store.Find(1, "prefix", "s:sys-u:user1-m:reply1")
+	require.True(t, found)
+	assert.Equal(t, "uuid-1", uuid)
+	assert.Equal(t, int64(100), accountID)
+
+	uuid, accountID, _, found = store.Find(1, "prefix", "s:sys-u:user2-m:reply2")
+	require.True(t, found)
+	assert.Equal(t, "uuid-2", uuid)
+	assert.Equal(t, int64(200), accountID)
+}
+
+func TestDigestSessionStore_NoMatch(t *testing.T) {
+	store := NewDigestSessionStore()
+
+	store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
+
+	// A completely different chain
+	_, _, _, found := store.Find(1, "prefix", "u:x-m:y")
+	assert.False(t, found)
+}
+
+func TestDigestSessionStore_DifferentPrefixHash(t *testing.T) {
+	store := NewDigestSessionStore()
+
+	store.Save(1, "prefix1", "u:a-m:b", "uuid-1", 100, "")
+
+	// Different prefixHash values should be isolated
+	_, _, _, found := store.Find(1, "prefix2", "u:a-m:b")
+	assert.False(t, found)
+}
+
+func TestDigestSessionStore_DifferentGroupID(t *testing.T) {
+	store := NewDigestSessionStore()
+
+	store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
+
+	// Different groupIDs should be isolated
+	_, _, _, found := store.Find(2, "prefix", "u:a-m:b")
+	assert.False(t, found)
+}
+
+func TestDigestSessionStore_EmptyDigestChain(t *testing.T) {
+	store := NewDigestSessionStore()
+
+	// An empty chain should not be saved
+	store.Save(1, "prefix", "", "uuid-1", 100, "")
+	_, _, _, found := store.Find(1, "prefix", "")
+	assert.False(t, found)
+}
+
+func TestDigestSessionStore_TTLExpiration(t *testing.T) {
+	store := &DigestSessionStore{
+		cache: gocache.New(100*time.Millisecond, 50*time.Millisecond),
+	}
+
+	store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
+
+	// Should be found immediately
+	_, _, _, found := store.Find(1, "prefix", "u:a-m:b")
+	require.True(t, found)
+
+	// Wait for expiration plus the cleanup cycle
+	time.Sleep(300 * time.Millisecond)
+
+	// Should no longer be found after expiry
+	_, _, _, found = store.Find(1, "prefix", "u:a-m:b")
+	assert.False(t, found)
+}
+
+func TestDigestSessionStore_ConcurrentSafety(t *testing.T) {
+	store := NewDigestSessionStore()
+
+	var wg sync.WaitGroup
+	const goroutines = 50
+	const operations = 100
+
+	wg.Add(goroutines)
+	for g := 0; g < goroutines; g++ {
+		go func(id int) {
+			defer wg.Done()
+			prefix := fmt.Sprintf("prefix-%d", id%5)
+			for i := 0; i < operations; i++ {
+				chain := fmt.Sprintf("u:%d-m:%d", id, i)
+				uuid := fmt.Sprintf("uuid-%d-%d", id, i)
+				store.Save(1, prefix, chain, uuid, int64(id), "")
+				store.Find(1, prefix, chain)
+			}
+		}(g)
+	}
+	wg.Wait()
+}
+
+func TestDigestSessionStore_MultipleSessions(t *testing.T) {
+	store := NewDigestSessionStore()
+
+	sessions := []struct {
+		chain     string
+		uuid      string
+		accountID int64
+	}{
+		{"u:session1", "uuid-1", 1},
+		{"u:session2-m:reply2", "uuid-2", 2},
+		{"u:session3-m:reply3-u:msg3", "uuid-3", 3},
+	}
+
+	for _, sess := range sessions {
+		store.Save(1, "prefix", sess.chain, sess.uuid, sess.accountID, "")
+	}
+
+	// Verify every session can be looked up correctly
+	for _, sess := range sessions {
+		uuid, accountID, _, found := store.Find(1, "prefix", sess.chain)
+		require.True(t, found, "should find session: %s", sess.chain)
+		assert.Equal(t, sess.uuid, uuid)
+		assert.Equal(t, sess.accountID, accountID)
+	}
+
+	// Verify the continued-conversation scenario
+	uuid, accountID, _, found := store.Find(1, "prefix", "u:session2-m:reply2-u:newmsg")
+	require.True(t, found)
+	assert.Equal(t, "uuid-2", uuid)
+	assert.Equal(t, int64(2), accountID)
+}
+
+func TestDigestSessionStore_Performance1000Sessions(t *testing.T) {
+	store := NewDigestSessionStore()
+
+	// Insert 1000 sessions
+	for i := 0; i < 1000; i++ {
+		chain := fmt.Sprintf("s:sys-u:user%d-m:reply%d", i, i)
+		store.Save(1, "prefix", chain, fmt.Sprintf("uuid-%d", i), int64(i), "")
+	}
+
+	// Lookup performance test
+	start := time.Now()
+	const lookups = 10000
+	for i := 0; i < lookups; i++ {
+		idx := i % 1000
+		chain := fmt.Sprintf("s:sys-u:user%d-m:reply%d-u:newmsg", idx, idx)
+		_, _, _, found := store.Find(1, "prefix", chain)
+		assert.True(t, found)
+	}
+	elapsed := time.Since(start)
+	t.Logf("%d lookups in %v (%.0f ns/op)", lookups, elapsed, float64(elapsed.Nanoseconds())/lookups)
+}
+
+func TestDigestSessionStore_FindReturnsMatchedChain(t *testing.T) {
+	store := NewDigestSessionStore()
+
+	store.Save(1, "prefix", "u:a-m:b-u:c", "uuid-1", 100, "")
+
+	// Exact match
+	_, _, matchedChain, found := store.Find(1, "prefix", "u:a-m:b-u:c")
+	require.True(t, found)
+	assert.Equal(t, "u:a-m:b-u:c", matchedChain)
+
+	// Prefix match (hit after truncation)
+	_, _, matchedChain, found = store.Find(1, "prefix", "u:a-m:b-u:c-m:d-u:e")
+	require.True(t, found)
+	assert.Equal(t, "u:a-m:b-u:c", matchedChain)
+}
+
+func TestDigestSessionStore_CacheItemCountStable(t *testing.T) {
+	store := NewDigestSessionStore()
+
+	// Simulate 100 independent conversations of 10 rounds each.
+	// With oldDigestChain passed correctly, each conversation keeps exactly 1 key.
+	for conv := 0; conv < 100; conv++ {
+		for round := 0; round < 10; round++ {
+			chain := fmt.Sprintf("s:sys-u:user%d", conv)
+			for r := 0; r < round; r++ {
+				chain += fmt.Sprintf("-m:a%d-u:q%d", r, r+1)
+			}
+			uuid := fmt.Sprintf("uuid-conv%d", conv)
+
+			_, _, matched, _ := store.Find(1, "prefix", chain)
+			store.Save(1, "prefix", chain, uuid, int64(conv), matched)
+		}
+	}
+
+	// 100 conversations × 1 key/conversation = at most 100 keys.
+	// A small concurrent residue would be tolerable, but never anywhere near 100×10=1000.
+	itemCount := store.cache.ItemCount()
+	assert.LessOrEqual(t, itemCount, 100, "cache should have at most 100 items (1 per conversation), got %d", itemCount)
+	t.Logf("Cache item count after 100 conversations × 10 rounds: %d", itemCount)
+}
+
+func TestDigestSessionStore_TTLPreventsUnboundedGrowth(t *testing.T) {
+	// Use a very short TTL to verify the cache gets cleaned up after a burst of writes
+	store := &DigestSessionStore{
+		cache: gocache.New(100*time.Millisecond,
 50*time.Millisecond),
+	}
+
+	// Insert 500 distinct keys (no oldDigestChain; worst case: every write is round 1 of a new conversation)
+	for i := 0; i < 500; i++ {
+		chain := fmt.Sprintf("u:user%d", i)
+		store.Save(1, "prefix", chain, fmt.Sprintf("uuid-%d", i), int64(i), "")
+	}
+
+	assert.Equal(t, 500, store.cache.ItemCount())
+
+	// Wait for TTL plus the cleanup cycle
+	time.Sleep(300 * time.Millisecond)
+
+	assert.Equal(t, 0, store.cache.ItemCount(), "all items should be expired and cleaned up")
+}
+
+func TestDigestSessionStore_SaveSameChainNoDelete(t *testing.T) {
+	store := NewDigestSessionStore()
+
+	// Save the chain
+	store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
+
+	// User resends the same message: oldDigestChain == digestChain must not delete the key just set
+	store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "u:a-m:b")
+
+	// Still findable
+	uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b")
+	require.True(t, found)
+	assert.Equal(t, "uuid-1", uuid)
+	assert.Equal(t, int64(100), accountID)
+}
diff --git a/backend/internal/service/error_policy_integration_test.go b/backend/internal/service/error_policy_integration_test.go
new file mode 100644
index 00000000..9f8ad938
--- /dev/null
+++ b/backend/internal/service/error_policy_integration_test.go
@@ -0,0 +1,366 @@
+//go:build unit
+
+package service
+
+import (
+	"context"
+	"io"
+	"net/http"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/Wei-Shaw/sub2api/internal/config"
+	"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
+	"github.com/stretchr/testify/require"
+)
+
+// ---------------------------------------------------------------------------
+// Mocks (scoped to this file by naming convention)
+// ---------------------------------------------------------------------------
+
+// epFixedUpstream returns a fixed response for every request.
+type epFixedUpstream struct { + statusCode int + body string + calls int +} + +func (u *epFixedUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { + u.calls++ + return &http.Response{ + StatusCode: u.statusCode, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(u.body)), + }, nil +} + +func (u *epFixedUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) { + return u.Do(req, proxyURL, accountID, accountConcurrency) +} + +// epAccountRepo records SetTempUnschedulable / SetError calls. +type epAccountRepo struct { + mockAccountRepoForGemini + tempCalls int + setErrCalls int +} + +func (r *epAccountRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error { + r.tempCalls++ + return nil +} + +func (r *epAccountRepo) SetError(_ context.Context, _ int64, _ string) error { + r.setErrCalls++ + return nil +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +func saveAndSetBaseURLs(t *testing.T) { + t.Helper() + oldBaseURLs := append([]string(nil), antigravity.BaseURLs...) 
+ oldAvail := antigravity.DefaultURLAvailability + antigravity.BaseURLs = []string{"https://ep-test.example"} + antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute) + t.Cleanup(func() { + antigravity.BaseURLs = oldBaseURLs + antigravity.DefaultURLAvailability = oldAvail + }) +} + +func newRetryParams(account *Account, upstream HTTPUpstream, handleError func(context.Context, string, *Account, int, http.Header, []byte, string, int64, string, bool) *handleModelRateLimitResult) antigravityRetryLoopParams { + return antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[ep-test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + requestedModel: "claude-sonnet-4-5", + handleError: handleError, + } +} + +// --------------------------------------------------------------------------- +// TestRetryLoop_ErrorPolicy_CustomErrorCodes +// --------------------------------------------------------------------------- + +func TestRetryLoop_ErrorPolicy_CustomErrorCodes(t *testing.T) { + tests := []struct { + name string + upstreamStatus int + upstreamBody string + customCodes []any + expectHandleError int + expectUpstream int + expectStatusCode int + }{ + { + name: "429_in_custom_codes_matched", + upstreamStatus: 429, + upstreamBody: `{"error":"rate limited"}`, + customCodes: []any{float64(429)}, + expectHandleError: 1, + expectUpstream: 1, + expectStatusCode: 429, + }, + { + name: "429_not_in_custom_codes_skipped", + upstreamStatus: 429, + upstreamBody: `{"error":"rate limited"}`, + customCodes: []any{float64(500)}, + expectHandleError: 0, + expectUpstream: 1, + expectStatusCode: 429, + }, + { + name: "500_in_custom_codes_matched", + upstreamStatus: 500, + upstreamBody: `{"error":"internal"}`, + customCodes: []any{float64(500)}, + expectHandleError: 1, + expectUpstream: 1, + expectStatusCode: 500, + }, + { + name: "500_not_in_custom_codes_skipped", + 
upstreamStatus: 500, + upstreamBody: `{"error":"internal"}`, + customCodes: []any{float64(429)}, + expectHandleError: 0, + expectUpstream: 1, + expectStatusCode: 500, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + saveAndSetBaseURLs(t) + + upstream := &epFixedUpstream{statusCode: tt.upstreamStatus, body: tt.upstreamBody} + repo := &epAccountRepo{} + rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + + account := &Account{ + ID: 100, + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": tt.customCodes, + }, + } + + svc := &AntigravityGatewayService{rateLimitService: rlSvc} + + var handleErrorCount int + p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult { + handleErrorCount++ + return nil + }) + + result, err := svc.antigravityRetryLoop(p) + + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.resp) + defer func() { _ = result.resp.Body.Close() }() + + require.Equal(t, tt.expectStatusCode, result.resp.StatusCode) + require.Equal(t, tt.expectHandleError, handleErrorCount, "handleError call count") + require.Equal(t, tt.expectUpstream, upstream.calls, "upstream call count") + }) + } +} + +// --------------------------------------------------------------------------- +// TestRetryLoop_ErrorPolicy_TempUnschedulable +// --------------------------------------------------------------------------- + +func TestRetryLoop_ErrorPolicy_TempUnschedulable(t *testing.T) { + tempRulesAccount := func(rules []any) *Account { + return &Account{ + ID: 200, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + 
"temp_unschedulable_enabled": true, + "temp_unschedulable_rules": rules, + }, + } + } + + overloadedRule := map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded"}, + "duration_minutes": float64(10), + } + + rateLimitRule := map[string]any{ + "error_code": float64(429), + "keywords": []any{"rate limited keyword"}, + "duration_minutes": float64(5), + } + + t.Run("503_overloaded_matches_rule", func(t *testing.T) { + saveAndSetBaseURLs(t) + + upstream := &epFixedUpstream{statusCode: 503, body: `overloaded`} + repo := &epAccountRepo{} + rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + svc := &AntigravityGatewayService{rateLimitService: rlSvc} + + account := tempRulesAccount([]any{overloadedRule}) + p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult { + t.Error("handleError should not be called for temp unschedulable") + return nil + }) + + result, err := svc.antigravityRetryLoop(p) + + require.Nil(t, result) + var switchErr *AntigravityAccountSwitchError + require.ErrorAs(t, err, &switchErr) + require.Equal(t, account.ID, switchErr.OriginalAccountID) + require.Equal(t, 1, upstream.calls, "should not retry") + }) + + t.Run("429_rate_limited_keyword_matches_rule", func(t *testing.T) { + saveAndSetBaseURLs(t) + + upstream := &epFixedUpstream{statusCode: 429, body: `rate limited keyword`} + repo := &epAccountRepo{} + rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + svc := &AntigravityGatewayService{rateLimitService: rlSvc} + + account := tempRulesAccount([]any{rateLimitRule}) + p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult { + t.Error("handleError should not be called for temp unschedulable") + return nil + }) + + result, err := 
svc.antigravityRetryLoop(p) + + require.Nil(t, result) + var switchErr *AntigravityAccountSwitchError + require.ErrorAs(t, err, &switchErr) + require.Equal(t, account.ID, switchErr.OriginalAccountID) + require.Equal(t, 1, upstream.calls, "should not retry") + }) + + t.Run("503_body_no_match_continues_default_retry", func(t *testing.T) { + saveAndSetBaseURLs(t) + + upstream := &epFixedUpstream{statusCode: 503, body: `random`} + repo := &epAccountRepo{} + rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + svc := &AntigravityGatewayService{rateLimitService: rlSvc} + + account := tempRulesAccount([]any{overloadedRule}) + + // Use a short-lived context: the backoff sleep (~1s) will be + // interrupted, proving the code entered the default retry path + // instead of breaking early via error policy. + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult { + return nil + }) + p.ctx = ctx + + result, err := svc.antigravityRetryLoop(p) + + // Context cancellation during backoff proves default retry was entered + require.Nil(t, result) + require.ErrorIs(t, err, context.DeadlineExceeded) + require.GreaterOrEqual(t, upstream.calls, 1, "should have called upstream at least once") + }) +} + +// --------------------------------------------------------------------------- +// TestRetryLoop_ErrorPolicy_NilRateLimitService +// --------------------------------------------------------------------------- + +func TestRetryLoop_ErrorPolicy_NilRateLimitService(t *testing.T) { + saveAndSetBaseURLs(t) + + upstream := &epFixedUpstream{statusCode: 429, body: `{"error":"rate limited"}`} + // rateLimitService is nil — must not panic + svc := &AntigravityGatewayService{rateLimitService: nil} + + account := &Account{ + ID: 300, + Type: AccountTypeOAuth, + 
Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + } + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult { + return nil + }) + p.ctx = ctx + + // Should not panic; enters the default retry path (eventually times out) + result, err := svc.antigravityRetryLoop(p) + + require.Nil(t, result) + require.ErrorIs(t, err, context.DeadlineExceeded) + require.GreaterOrEqual(t, upstream.calls, 1) +} + +// --------------------------------------------------------------------------- +// TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior +// --------------------------------------------------------------------------- + +func TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior(t *testing.T) { + saveAndSetBaseURLs(t) + + upstream := &epFixedUpstream{statusCode: 429, body: `{"error":"rate limited"}`} + repo := &epAccountRepo{} + rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + svc := &AntigravityGatewayService{rateLimitService: rlSvc} + + // Plain OAuth account with no error policy configured + account := &Account{ + ID: 400, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + } + + var handleErrorCount int + p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult { + handleErrorCount++ + return nil + }) + + result, err := svc.antigravityRetryLoop(p) + + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.resp) + defer func() { _ = result.resp.Body.Close() }() + + require.Equal(t, http.StatusTooManyRequests, result.resp.StatusCode) + require.Equal(t, 
antigravityMaxRetries, upstream.calls, "should exhaust all retries") + require.Equal(t, 1, handleErrorCount, "handleError should be called once after retries exhausted") +} diff --git a/backend/internal/service/error_policy_test.go b/backend/internal/service/error_policy_test.go new file mode 100644 index 00000000..a8b69c22 --- /dev/null +++ b/backend/internal/service/error_policy_test.go @@ -0,0 +1,289 @@ +//go:build unit + +package service + +import ( + "context" + "net/http" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// TestCheckErrorPolicy — 6 table-driven cases for the pure logic function +// --------------------------------------------------------------------------- + +func TestCheckErrorPolicy(t *testing.T) { + tests := []struct { + name string + account *Account + statusCode int + body []byte + expected ErrorPolicyResult + }{ + { + name: "no_policy_oauth_returns_none", + account: &Account{ + ID: 1, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + // no custom error codes, no temp rules + }, + statusCode: 500, + body: []byte(`"error"`), + expected: ErrorPolicyNone, + }, + { + name: "custom_error_codes_hit_returns_matched", + account: &Account{ + ID: 2, + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(429), float64(500)}, + }, + }, + statusCode: 500, + body: []byte(`"error"`), + expected: ErrorPolicyMatched, + }, + { + name: "custom_error_codes_miss_returns_skipped", + account: &Account{ + ID: 3, + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(429), float64(500)}, + }, + }, + statusCode: 503, + body: []byte(`"error"`), + expected: ErrorPolicySkipped, + }, + { + 
name: "temp_unschedulable_hit_returns_temp_unscheduled", + account: &Account{ + ID: 4, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded"}, + "duration_minutes": float64(10), + "description": "overloaded rule", + }, + }, + }, + }, + statusCode: 503, + body: []byte(`overloaded service`), + expected: ErrorPolicyTempUnscheduled, + }, + { + name: "temp_unschedulable_body_miss_returns_none", + account: &Account{ + ID: 5, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded"}, + "duration_minutes": float64(10), + "description": "overloaded rule", + }, + }, + }, + }, + statusCode: 503, + body: []byte(`random msg`), + expected: ErrorPolicyNone, + }, + { + name: "custom_error_codes_override_temp_unschedulable", + account: &Account{ + ID: 6, + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(503)}, + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded"}, + "duration_minutes": float64(10), + "description": "overloaded rule", + }, + }, + }, + }, + statusCode: 503, + body: []byte(`overloaded`), + expected: ErrorPolicyMatched, // custom codes take precedence + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + repo := &errorPolicyRepoStub{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + + result := svc.CheckErrorPolicy(context.Background(), tt.account, tt.statusCode, tt.body) + require.Equal(t, tt.expected, result, "unexpected 
ErrorPolicyResult") + }) + } +} + +// --------------------------------------------------------------------------- +// TestApplyErrorPolicy — 4 table-driven cases for the wrapper method +// --------------------------------------------------------------------------- + +func TestApplyErrorPolicy(t *testing.T) { + tests := []struct { + name string + account *Account + statusCode int + body []byte + expectedHandled bool + expectedSwitchErr bool // expect *AntigravityAccountSwitchError + handleErrorCalls int + }{ + { + name: "none_not_handled", + account: &Account{ + ID: 10, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + }, + statusCode: 500, + body: []byte(`"error"`), + expectedHandled: false, + handleErrorCalls: 0, + }, + { + name: "skipped_handled_no_handleError", + account: &Account{ + ID: 11, + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(429)}, + }, + }, + statusCode: 500, // not in custom codes + body: []byte(`"error"`), + expectedHandled: true, + handleErrorCalls: 0, + }, + { + name: "matched_handled_calls_handleError", + account: &Account{ + ID: 12, + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(500)}, + }, + }, + statusCode: 500, + body: []byte(`"error"`), + expectedHandled: true, + handleErrorCalls: 1, + }, + { + name: "temp_unscheduled_returns_switch_error", + account: &Account{ + ID: 13, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded"}, + "duration_minutes": float64(10), + }, + }, + }, + }, + statusCode: 503, + body: []byte(`overloaded`), + expectedHandled: true, + expectedSwitchErr: true, + handleErrorCalls: 0, + 
}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + repo := &errorPolicyRepoStub{} + rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + svc := &AntigravityGatewayService{ + rateLimitService: rlSvc, + } + + var handleErrorCount int + p := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: tt.account, + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleErrorCount++ + return nil + }, + isStickySession: true, + } + + handled, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body) + + require.Equal(t, tt.expectedHandled, handled, "handled mismatch") + require.Equal(t, tt.handleErrorCalls, handleErrorCount, "handleError call count mismatch") + + if tt.expectedSwitchErr { + var switchErr *AntigravityAccountSwitchError + require.ErrorAs(t, retErr, &switchErr) + require.Equal(t, tt.account.ID, switchErr.OriginalAccountID) + } else { + require.NoError(t, retErr) + } + }) + } +} + +// --------------------------------------------------------------------------- +// errorPolicyRepoStub — minimal AccountRepository stub for error policy tests +// --------------------------------------------------------------------------- + +type errorPolicyRepoStub struct { + mockAccountRepoForGemini + tempCalls int + setErrCalls int + lastErrorMsg string +} + +func (r *errorPolicyRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error { + r.tempCalls++ + return nil +} + +func (r *errorPolicyRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error { + r.setErrCalls++ + r.lastErrorMsg = errorMsg + return nil +} diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 
d9c852e0..e7544024 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -142,9 +142,6 @@ func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx co func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { return nil } -func (m *mockAccountRepoForPlatform) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error { - return nil -} func (m *mockAccountRepoForPlatform) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error { return nil } @@ -216,29 +213,6 @@ func (m *mockGatewayCacheForPlatform) DeleteSessionAccountID(ctx context.Context return nil } -func (m *mockGatewayCacheForPlatform) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) { - return 0, nil -} - -func (m *mockGatewayCacheForPlatform) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) { - return nil, nil -} - -func (m *mockGatewayCacheForPlatform) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - return "", 0, false -} - -func (m *mockGatewayCacheForPlatform) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { - return nil -} - -func (m *mockGatewayCacheForPlatform) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - return "", 0, false -} - -func (m *mockGatewayCacheForPlatform) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { - return nil -} type mockGroupRepoForGateway struct { groups map[int64]*Group diff --git a/backend/internal/service/gateway_request.go 
b/backend/internal/service/gateway_request.go
index 0ecd18aa..c039f030 100644
--- a/backend/internal/service/gateway_request.go
+++ b/backend/internal/service/gateway_request.go
@@ -6,9 +6,19 @@ import (
 	"fmt"
 	"math"
 
+	"github.com/Wei-Shaw/sub2api/internal/domain"
 	"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
 )
 
+// SessionContext is the sticky-session context used to tell requests from different origins apart.
+// It is mixed in only at the third-level fallback of GenerateSessionHash (message-content hash),
+// so that different users sending identical messages do not produce the same hash and pile onto one account.
+type SessionContext struct {
+	ClientIP  string
+	UserAgent string
+	APIKeyID  int64
+}
+
 // ParsedRequest holds the pre-parsed result of a gateway request
 //
 // Performance notes:
@@ -22,20 +32,22 @@ import (
 // 2. Pass the parsed ParsedRequest down to the service layer
 // 3. Avoid repeated json.Unmarshal, reducing CPU and memory overhead
 type ParsedRequest struct {
-	Body            []byte // Raw request body (kept for forwarding)
-	Model           string // Requested model name
-	Stream          bool   // Whether this is a streaming request
-	MetadataUserID  string // metadata.user_id (used for session affinity)
-	System          any    // Content of the system field
-	Messages        []any  // The messages array
-	HasSystem       bool   // Whether a system field was present (null also counts as explicitly provided)
-	ThinkingEnabled bool   // Whether thinking is enabled (affects the final model name on some platforms)
-	MaxTokens       int    // max_tokens value (used to intercept probe requests)
+	Body            []byte          // Raw request body (kept for forwarding)
+	Model           string          // Requested model name
+	Stream          bool            // Whether this is a streaming request
+	MetadataUserID  string          // metadata.user_id (used for session affinity)
+	System          any             // Content of the system field
+	Messages        []any           // The messages array
+	HasSystem       bool            // Whether a system field was present (null also counts as explicitly provided)
+	ThinkingEnabled bool            // Whether thinking is enabled (affects the final model name on some platforms)
+	MaxTokens       int             // max_tokens value (used to intercept probe requests)
+	SessionContext  *SessionContext // Optional: request-context discriminators (nil keeps the old behavior)
 }
 
-// ParseGatewayRequest parses a gateway request body and returns the structured result
-// Performance: a single parse extracts every needed field, avoiding repeated Unmarshal
-func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
+// ParseGatewayRequest parses a gateway request body and returns the structured result.
+// protocol selects the request protocol format (domain.PlatformAnthropic / domain.PlatformGemini);
+// different protocols use different system/messages field names.
+func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
 	var req map[string]any
 	if err := json.Unmarshal(body, &req); err != nil {
 		return nil, err
@@ -64,14 +76,29 @@ func ParseGatewayRequest(body []byte)
 (*ParsedRequest, error) {
 			parsed.MetadataUserID = userID
 		}
 	}
-	// The mere presence of a system field counts as explicitly provided (even if null),
-	// so the default system is not injected when a client passes null.
-	if system, ok := req["system"]; ok {
-		parsed.HasSystem = true
-		parsed.System = system
-	}
-	if messages, ok := req["messages"].([]any); ok {
-		parsed.Messages = messages
+
+	switch protocol {
+	case domain.PlatformGemini:
+		// Gemini native format: systemInstruction.parts / contents
+		if sysInst, ok := req["systemInstruction"].(map[string]any); ok {
+			if parts, ok := sysInst["parts"].([]any); ok {
+				parsed.System = parts
+			}
+		}
+		if contents, ok := req["contents"].([]any); ok {
+			parsed.Messages = contents
+		}
+	default:
+		// Anthropic / OpenAI format: system / messages
+		// The mere presence of a system field counts as explicitly provided (even if null),
+		// so the default system is not injected when a client passes null.
+		if system, ok := req["system"]; ok {
+			parsed.HasSystem = true
+			parsed.System = system
+		}
+		if messages, ok := req["messages"].([]any); ok {
+			parsed.Messages = messages
+		}
 	}
 
 	// thinking: {type: "enabled"}
diff --git a/backend/internal/service/gateway_request_test.go b/backend/internal/service/gateway_request_test.go
index 4e390b0a..cef41c91 100644
--- a/backend/internal/service/gateway_request_test.go
+++ b/backend/internal/service/gateway_request_test.go
@@ -4,12 +4,13 @@ import (
 	"encoding/json"
 	"testing"
 
+	"github.com/Wei-Shaw/sub2api/internal/domain"
 	"github.com/stretchr/testify/require"
 )
 
 func TestParseGatewayRequest(t *testing.T) {
 	body := []byte(`{"model":"claude-3-7-sonnet","stream":true,"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"system":[{"type":"text","text":"hello","cache_control":{"type":"ephemeral"}}],"messages":[{"content":"hi"}]}`)
-	parsed, err := ParseGatewayRequest(body)
+	parsed, err := ParseGatewayRequest(body, "")
 	require.NoError(t, err)
 	require.Equal(t, "claude-3-7-sonnet", parsed.Model)
 	require.True(t, parsed.Stream)
@@ -22,7 +23,7 @@ func TestParseGatewayRequest(t *testing.T) {
 
 func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) {
 	body :=
[]byte(`{"model":"claude-sonnet-4-5","thinking":{"type":"enabled"},"messages":[{"content":"hi"}]}`)
-	parsed, err := ParseGatewayRequest(body)
+	parsed, err := ParseGatewayRequest(body, "")
 	require.NoError(t, err)
 	require.Equal(t, "claude-sonnet-4-5", parsed.Model)
 	require.True(t, parsed.ThinkingEnabled)
@@ -30,21 +31,21 @@ func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) {
 
 func TestParseGatewayRequest_MaxTokens(t *testing.T) {
 	body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1}`)
-	parsed, err := ParseGatewayRequest(body)
+	parsed, err := ParseGatewayRequest(body, "")
 	require.NoError(t, err)
 	require.Equal(t, 1, parsed.MaxTokens)
 }
 
 func TestParseGatewayRequest_MaxTokensNonIntegralIgnored(t *testing.T) {
 	body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1.5}`)
-	parsed, err := ParseGatewayRequest(body)
+	parsed, err := ParseGatewayRequest(body, "")
 	require.NoError(t, err)
 	require.Equal(t, 0, parsed.MaxTokens)
 }
 
 func TestParseGatewayRequest_SystemNull(t *testing.T) {
 	body := []byte(`{"model":"claude-3","system":null}`)
-	parsed, err := ParseGatewayRequest(body)
+	parsed, err := ParseGatewayRequest(body, "")
 	require.NoError(t, err)
 	// Passing an explicit system:null should still count as "field present", so the default system prompt is not injected.
 	require.True(t, parsed.HasSystem)
@@ -53,16 +54,112 @@ func TestParseGatewayRequest_SystemNull(t *testing.T) {
 
 func TestParseGatewayRequest_InvalidModelType(t *testing.T) {
 	body := []byte(`{"model":123}`)
-	_, err := ParseGatewayRequest(body)
+	_, err := ParseGatewayRequest(body, "")
 	require.Error(t, err)
 }
 
 func TestParseGatewayRequest_InvalidStreamType(t *testing.T) {
 	body := []byte(`{"stream":"true"}`)
-	_, err := ParseGatewayRequest(body)
+	_, err := ParseGatewayRequest(body, "")
 	require.Error(t, err)
 }
 
+// ============ Gemini native-format parsing tests ============
+
+func TestParseGatewayRequest_GeminiContents(t *testing.T) {
+	body := []byte(`{
+		"contents": [
+			{"role": "user", "parts": [{"text": "Hello"}]},
+			{"role": "model", "parts": [{"text": "Hi there"}]},
+			{"role": 
"user", "parts": [{"text": "How are you?"}]} + ] + }`) + parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) + require.NoError(t, err) + require.Len(t, parsed.Messages, 3, "should parse contents as Messages") + require.False(t, parsed.HasSystem, "Gemini format should not set HasSystem") + require.Nil(t, parsed.System, "no systemInstruction means nil System") +} + +func TestParseGatewayRequest_GeminiSystemInstruction(t *testing.T) { + body := []byte(`{ + "systemInstruction": { + "parts": [{"text": "You are a helpful assistant."}] + }, + "contents": [ + {"role": "user", "parts": [{"text": "Hello"}]} + ] + }`) + parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) + require.NoError(t, err) + require.NotNil(t, parsed.System, "should parse systemInstruction.parts as System") + parts, ok := parsed.System.([]any) + require.True(t, ok) + require.Len(t, parts, 1) + partMap, ok := parts[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "You are a helpful assistant.", partMap["text"]) + require.Len(t, parsed.Messages, 1) +} + +func TestParseGatewayRequest_GeminiWithModel(t *testing.T) { + body := []byte(`{ + "model": "gemini-2.5-pro", + "contents": [{"role": "user", "parts": [{"text": "test"}]}] + }`) + parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) + require.NoError(t, err) + require.Equal(t, "gemini-2.5-pro", parsed.Model) + require.Len(t, parsed.Messages, 1) +} + +func TestParseGatewayRequest_GeminiIgnoresAnthropicFields(t *testing.T) { + // Gemini 格式下 system/messages 字段应被忽略 + body := []byte(`{ + "system": "should be ignored", + "messages": [{"role": "user", "content": "ignored"}], + "contents": [{"role": "user", "parts": [{"text": "real content"}]}] + }`) + parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) + require.NoError(t, err) + require.False(t, parsed.HasSystem, "Gemini protocol should not parse Anthropic system field") + require.Nil(t, parsed.System, "no systemInstruction = nil System") + 
require.Len(t, parsed.Messages, 1, "should use contents, not messages") +} + +func TestParseGatewayRequest_GeminiEmptyContents(t *testing.T) { + body := []byte(`{"contents": []}`) + parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) + require.NoError(t, err) + require.Empty(t, parsed.Messages) +} + +func TestParseGatewayRequest_GeminiNoContents(t *testing.T) { + body := []byte(`{"model": "gemini-2.5-flash"}`) + parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) + require.NoError(t, err) + require.Nil(t, parsed.Messages) + require.Equal(t, "gemini-2.5-flash", parsed.Model) +} + +func TestParseGatewayRequest_AnthropicIgnoresGeminiFields(t *testing.T) { + // Anthropic 格式下 contents/systemInstruction 字段应被忽略 + body := []byte(`{ + "system": "real system", + "messages": [{"role": "user", "content": "real content"}], + "contents": [{"role": "user", "parts": [{"text": "ignored"}]}], + "systemInstruction": {"parts": [{"text": "ignored"}]} + }`) + parsed, err := ParseGatewayRequest(body, domain.PlatformAnthropic) + require.NoError(t, err) + require.True(t, parsed.HasSystem) + require.Equal(t, "real system", parsed.System) + require.Len(t, parsed.Messages, 1) + msg, ok := parsed.Messages[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "real content", msg["content"]) +} + func TestFilterThinkingBlocks(t *testing.T) { containsThinkingBlock := func(body []byte) bool { var req map[string]any diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 480f5b67..4e723232 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -5,7 +5,6 @@ import ( "bytes" "context" "crypto/sha256" - "encoding/hex" "encoding/json" "errors" "fmt" @@ -17,6 +16,7 @@ import ( "os" "regexp" "sort" + "strconv" "strings" "sync/atomic" "time" @@ -26,6 +26,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" 
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" + "github.com/cespare/xxhash/v2" "github.com/google/uuid" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -245,9 +246,6 @@ var ( // ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问 var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients") -// ErrModelScopeNotSupported 表示请求的模型系列不在分组支持的范围内 -var ErrModelScopeNotSupported = errors.New("model scope not supported by this group") - // allowedHeaders 白名单headers(参考CRS项目) var allowedHeaders = map[string]bool{ "accept": true, @@ -273,13 +271,6 @@ var allowedHeaders = map[string]bool{ // GatewayCache 定义网关服务的缓存操作接口。 // 提供粘性会话(Sticky Session)的存储、查询、刷新和删除功能。 // -// ModelLoadInfo 模型负载信息(用于 Antigravity 调度) -// Model load info for Antigravity scheduling -type ModelLoadInfo struct { - CallCount int64 // 当前分钟调用次数 / Call count in current minute - LastUsedAt time.Time // 最后调度时间(零值表示未调度过)/ Last scheduling time (zero means never scheduled) -} - // GatewayCache defines cache operations for gateway service. // Provides sticky session storage, retrieval, refresh and deletion capabilities. 
type GatewayCache interface { @@ -295,32 +286,6 @@ type GatewayCache interface { // DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理 // Delete sticky session binding, used to proactively clean up when account becomes unavailable DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error - - // IncrModelCallCount 增加模型调用次数并更新最后调度时间(Antigravity 专用) - // Increment model call count and update last scheduling time (Antigravity only) - // 返回更新后的调用次数 - IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) - - // GetModelLoadBatch 批量获取账号的模型负载信息(Antigravity 专用) - // Batch get model load info for accounts (Antigravity only) - GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) - - // FindGeminiSession 查找 Gemini 会话(MGET 倒序匹配) - // Find Gemini session using MGET reverse order matching - // 返回最长匹配的会话信息(uuid, accountID) - FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) - - // SaveGeminiSession 保存 Gemini 会话 - // Save Gemini session binding - SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error - - // FindAnthropicSession 查找 Anthropic 会话(Trie 匹配) - // Find Anthropic session using Trie matching - FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) - - // SaveAnthropicSession 保存 Anthropic 会话 - // Save Anthropic session binding - SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error } // derefGroupID safely dereferences *int64 to int64, returning 0 if nil @@ -415,6 +380,7 @@ type GatewayService struct { userSubRepo UserSubscriptionRepository userGroupRateRepo UserGroupRateRepository cache GatewayCache + digestStore *DigestSessionStore cfg *config.Config schedulerSnapshot *SchedulerSnapshotService 
billingService *BillingService @@ -448,6 +414,7 @@ func NewGatewayService( deferredService *DeferredService, claudeTokenProvider *ClaudeTokenProvider, sessionLimitCache SessionLimitCache, + digestStore *DigestSessionStore, ) *GatewayService { return &GatewayService{ accountRepo: accountRepo, @@ -457,6 +424,7 @@ func NewGatewayService( userSubRepo: userSubRepo, userGroupRateRepo: userGroupRateRepo, cache: cache, + digestStore: digestStore, cfg: cfg, schedulerSnapshot: schedulerSnapshot, concurrencyService: concurrencyService, @@ -490,8 +458,17 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { return s.hashContent(cacheableContent) } - // 3. 最后 fallback: 使用 system + 所有消息的完整摘要串 + // 3. 最后 fallback: 使用 session上下文 + system + 所有消息的完整摘要串 var combined strings.Builder + // 混入请求上下文区分因子,避免不同用户相同消息产生相同 hash + if parsed.SessionContext != nil { + _, _ = combined.WriteString(parsed.SessionContext.ClientIP) + _, _ = combined.WriteString(":") + _, _ = combined.WriteString(parsed.SessionContext.UserAgent) + _, _ = combined.WriteString(":") + _, _ = combined.WriteString(strconv.FormatInt(parsed.SessionContext.APIKeyID, 10)) + _, _ = combined.WriteString("|") + } if parsed.System != nil { systemText := s.extractTextFromSystem(parsed.System) if systemText != "" { @@ -500,9 +477,20 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { } for _, msg := range parsed.Messages { if m, ok := msg.(map[string]any); ok { - msgText := s.extractTextFromContent(m["content"]) - if msgText != "" { - _, _ = combined.WriteString(msgText) + if content, exists := m["content"]; exists { + // Anthropic: messages[].content + if msgText := s.extractTextFromContent(content); msgText != "" { + _, _ = combined.WriteString(msgText) + } + } else if parts, ok := m["parts"].([]any); ok { + // Gemini: contents[].parts[].text + for _, part := range parts { + if partMap, ok := part.(map[string]any); ok { + if text, ok := partMap["text"].(string); ok { + _, _ = 
combined.WriteString(text) + } + } + } } } } @@ -536,35 +524,37 @@ func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID // FindGeminiSession 查找 Gemini 会话(基于内容摘要链的 Fallback 匹配) // 返回最长匹配的会话信息(uuid, accountID) -func (s *GatewayService) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - if digestChain == "" || s.cache == nil { - return "", 0, false +func (s *GatewayService) FindGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) { + if digestChain == "" || s.digestStore == nil { + return "", 0, "", false } - return s.cache.FindGeminiSession(ctx, groupID, prefixHash, digestChain) + return s.digestStore.Find(groupID, prefixHash, digestChain) } -// SaveGeminiSession 保存 Gemini 会话 -func (s *GatewayService) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { - if digestChain == "" || s.cache == nil { +// SaveGeminiSession 保存 Gemini 会话。oldDigestChain 为 Find 返回的 matchedChain,用于删旧 key。 +func (s *GatewayService) SaveGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error { + if digestChain == "" || s.digestStore == nil { return nil } - return s.cache.SaveGeminiSession(ctx, groupID, prefixHash, digestChain, uuid, accountID) + s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain) + return nil } // FindAnthropicSession 查找 Anthropic 会话(基于内容摘要链的 Fallback 匹配) -func (s *GatewayService) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - if digestChain == "" || s.cache == nil { - return "", 0, false +func (s *GatewayService) FindAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID 
int64, matchedChain string, found bool) { + if digestChain == "" || s.digestStore == nil { + return "", 0, "", false } - return s.cache.FindAnthropicSession(ctx, groupID, prefixHash, digestChain) + return s.digestStore.Find(groupID, prefixHash, digestChain) } // SaveAnthropicSession 保存 Anthropic 会话 -func (s *GatewayService) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { - if digestChain == "" || s.cache == nil { +func (s *GatewayService) SaveAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error { + if digestChain == "" || s.digestStore == nil { return nil } - return s.cache.SaveAnthropicSession(ctx, groupID, prefixHash, digestChain, uuid, accountID) + s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain) + return nil } func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string { @@ -649,8 +639,8 @@ func (s *GatewayService) extractTextFromContent(content any) string { } func (s *GatewayService) hashContent(content string) string { - hash := sha256.Sum256([]byte(content)) - return hex.EncodeToString(hash[:16]) // 32字符 + h := xxhash.Sum64String(content) + return strconv.FormatUint(h, 36) } // replaceModelInBody 替换请求体中的model字段 @@ -1009,13 +999,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro log.Printf("[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform) } - // Antigravity 模型系列检查(在账号选择前检查,确保所有代码路径都经过此检查) - if platform == PlatformAntigravity && groupID != nil && requestedModel != "" { - if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil { - return nil, err - } - } - accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) if err != nil { return nil, 
err @@ -1209,6 +1192,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro return a.account.LastUsedAt.Before(*b.account.LastUsedAt) } }) + shuffleWithinSortGroups(routingAvailable) // 4. 尝试获取槽位 for _, item := range routingAvailable { @@ -1362,10 +1346,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro return result, nil } } else { - // Antigravity 平台:获取模型负载信息 - var modelLoadMap map[int64]*ModelLoadInfo - isAntigravity := platform == PlatformAntigravity - var available []accountWithLoad for _, acc := range candidates { loadInfo := loadMap[acc.ID] @@ -1380,109 +1360,44 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } } - // Antigravity 平台:按账号实际映射后的模型名获取模型负载(与 Forward 的统计保持一致) - if isAntigravity && requestedModel != "" && s.cache != nil && len(available) > 0 { - modelLoadMap = make(map[int64]*ModelLoadInfo, len(available)) - modelToAccountIDs := make(map[string][]int64) - for _, item := range available { - mappedModel := mapAntigravityModel(item.account, requestedModel) - if mappedModel == "" { - continue - } - modelToAccountIDs[mappedModel] = append(modelToAccountIDs[mappedModel], item.account.ID) + // 分层过滤选择:优先级 → 负载率 → LRU + for len(available) > 0 { + // 1. 取优先级最小的集合 + candidates := filterByMinPriority(available) + // 2. 取负载率最低的集合 + candidates = filterByMinLoadRate(candidates) + // 3. LRU 选择最久未用的账号 + selected := selectByLRU(candidates, preferOAuth) + if selected == nil { + break } - for model, ids := range modelToAccountIDs { - batch, err := s.cache.GetModelLoadBatch(ctx, ids, model) - if err != nil { - continue - } - for id, info := range batch { - modelLoadMap[id] = info - } - } - if len(modelLoadMap) == 0 { - modelLoadMap = nil - } - } - // Antigravity 平台:优先级硬过滤 →(同优先级内)按调用次数选择(最少优先,新账号用平均值) - // 其他平台:分层过滤选择:优先级 → 负载率 → LRU - if isAntigravity { - for len(available) > 0 { - // 1. 取优先级最小的集合(硬过滤) - candidates := filterByMinPriority(available) - // 2. 
同优先级内按调用次数选择(调用次数最少优先,新账号使用平均值) - selected := selectByCallCount(candidates, modelLoadMap, preferOAuth) - if selected == nil { - break - } - - result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency) - if err == nil && result.Acquired { - // 会话数量限制检查 - if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) { - result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 - } else { - if sessionHash != "" && s.cache != nil { - _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL) - } - return &AccountSelectionResult{ - Account: selected.account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency) + if err == nil && result.Acquired { + // 会话数量限制检查 + if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) { + result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 + } else { + if sessionHash != "" && s.cache != nil { + _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL) } + return &AccountSelectionResult{ + Account: selected.account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil } - - // 移除已尝试的账号,重新选择 - selectedID := selected.account.ID - newAvailable := make([]accountWithLoad, 0, len(available)-1) - for _, acc := range available { - if acc.account.ID != selectedID { - newAvailable = append(newAvailable, acc) - } - } - available = newAvailable } - } else { - for len(available) > 0 { - // 1. 取优先级最小的集合 - candidates := filterByMinPriority(available) - // 2. 取负载率最低的集合 - candidates = filterByMinLoadRate(candidates) - // 3. 
LRU 选择最久未用的账号 - selected := selectByLRU(candidates, preferOAuth) - if selected == nil { - break - } - result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency) - if err == nil && result.Acquired { - // 会话数量限制检查 - if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) { - result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 - } else { - if sessionHash != "" && s.cache != nil { - _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL) - } - return &AccountSelectionResult{ - Account: selected.account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil - } + // 移除已尝试的账号,重新进行分层过滤 + selectedID := selected.account.ID + newAvailable := make([]accountWithLoad, 0, len(available)-1) + for _, acc := range available { + if acc.account.ID != selectedID { + newAvailable = append(newAvailable, acc) } - - // 移除已尝试的账号,重新进行分层过滤 - selectedID := selected.account.ID - newAvailable := make([]accountWithLoad, 0, len(available)-1) - for _, acc := range available { - if acc.account.ID != selectedID { - newAvailable = append(newAvailable, acc) - } - } - available = newAvailable } + available = newAvailable } } @@ -2018,87 +1933,79 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { return a.LastUsedAt.Before(*b.LastUsedAt) } }) + shuffleWithinPriorityAndLastUsed(accounts) } -// selectByCallCount 从候选账号中选择调用次数最少的账号(Antigravity 专用) -// 新账号(CallCount=0)使用平均调用次数作为虚拟值,避免冷启动被猛调 -// 如果有多个账号具有相同的最小调用次数,则随机选择一个 -func selectByCallCount(accounts []accountWithLoad, modelLoadMap map[int64]*ModelLoadInfo, preferOAuth bool) *accountWithLoad { - if len(accounts) == 0 { - return nil +// shuffleWithinSortGroups 对排序后的 accountWithLoad 切片,按 (Priority, LoadRate, LastUsedAt) 分组后组内随机打乱。 +// 防止并发请求读取同一快照时,确定性排序导致所有请求命中相同账号。 +func shuffleWithinSortGroups(accounts []accountWithLoad) { + if len(accounts) <= 1 { + return } - if len(accounts) == 1 { - return &accounts[0] - } - - // 
如果没有负载信息,回退到 LRU - if modelLoadMap == nil { - return selectByLRU(accounts, preferOAuth) - } - - // 1. 计算平均调用次数(用于新账号冷启动) - var totalCallCount int64 - var countWithCalls int - for _, acc := range accounts { - if info := modelLoadMap[acc.account.ID]; info != nil && info.CallCount > 0 { - totalCallCount += info.CallCount - countWithCalls++ + i := 0 + for i < len(accounts) { + j := i + 1 + for j < len(accounts) && sameAccountWithLoadGroup(accounts[i], accounts[j]) { + j++ } - } - - var avgCallCount int64 - if countWithCalls > 0 { - avgCallCount = totalCallCount / int64(countWithCalls) - } - - // 2. 获取每个账号的有效调用次数 - getEffectiveCallCount := func(acc accountWithLoad) int64 { - if acc.account == nil { - return 0 + if j-i > 1 { + mathrand.Shuffle(j-i, func(a, b int) { + accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a] + }) } - info := modelLoadMap[acc.account.ID] - if info == nil || info.CallCount == 0 { - return avgCallCount // 新账号使用平均值 - } - return info.CallCount + i = j } +} - // 3. 找到最小调用次数 - minCount := getEffectiveCallCount(accounts[0]) - for _, acc := range accounts[1:] { - if c := getEffectiveCallCount(acc); c < minCount { - minCount = c - } +// sameAccountWithLoadGroup 判断两个 accountWithLoad 是否属于同一排序组 +func sameAccountWithLoadGroup(a, b accountWithLoad) bool { + if a.account.Priority != b.account.Priority { + return false } - - // 4. 收集所有具有最小调用次数的账号 - var candidateIdxs []int - for i, acc := range accounts { - if getEffectiveCallCount(acc) == minCount { - candidateIdxs = append(candidateIdxs, i) - } + if a.loadInfo.LoadRate != b.loadInfo.LoadRate { + return false } + return sameLastUsedAt(a.account.LastUsedAt, b.account.LastUsedAt) +} - // 5. 如果只有一个候选,直接返回 - if len(candidateIdxs) == 1 { - return &accounts[candidateIdxs[0]] +// shuffleWithinPriorityAndLastUsed 对排序后的 []*Account 切片,按 (Priority, LastUsedAt) 分组后组内随机打乱。 +func shuffleWithinPriorityAndLastUsed(accounts []*Account) { + if len(accounts) <= 1 { + return } - - // 6. 
preferOAuth 处理 - if preferOAuth { - var oauthIdxs []int - for _, idx := range candidateIdxs { - if accounts[idx].account.Type == AccountTypeOAuth { - oauthIdxs = append(oauthIdxs, idx) - } + i := 0 + for i < len(accounts) { + j := i + 1 + for j < len(accounts) && sameAccountGroup(accounts[i], accounts[j]) { + j++ } - if len(oauthIdxs) > 0 { - candidateIdxs = oauthIdxs + if j-i > 1 { + mathrand.Shuffle(j-i, func(a, b int) { + accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a] + }) } + i = j } +} - // 7. 随机选择 - return &accounts[candidateIdxs[mathrand.Intn(len(candidateIdxs))]] +// sameAccountGroup 判断两个 Account 是否属于同一排序组(Priority + LastUsedAt) +func sameAccountGroup(a, b *Account) bool { + if a.Priority != b.Priority { + return false + } + return sameLastUsedAt(a.LastUsedAt, b.LastUsedAt) +} + +// sameLastUsedAt 判断两个 LastUsedAt 是否相同(精度到秒) +func sameLastUsedAt(a, b *time.Time) bool { + switch { + case a == nil && b == nil: + return true + case a == nil || b == nil: + return false + default: + return a.Unix() == b.Unix() + } } // sortCandidatesForFallback 根据配置选择排序策略 @@ -2153,13 +2060,6 @@ func shuffleWithinPriority(accounts []*Account) { // selectAccountForModelWithPlatform 选择单平台账户(完全隔离) func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) { - // 对 Antigravity 平台,检查请求的模型系列是否在分组支持范围内 - if platform == PlatformAntigravity && groupID != nil && requestedModel != "" { - if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil { - return nil, err - } - } - preferOAuth := platform == PlatformGemini routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform) @@ -5171,27 +5071,6 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) { return normalized, nil } -// checkAntigravityModelScope 检查 Antigravity 平台的模型系列是否在分组支持范围内 -func (s 
*GatewayService) checkAntigravityModelScope(ctx context.Context, groupID int64, requestedModel string) error { - scope, ok := ResolveAntigravityQuotaScope(requestedModel) - if !ok { - return nil // 无法解析 scope,跳过检查 - } - - group, err := s.resolveGroupByID(ctx, groupID) - if err != nil { - return nil // 查询失败时放行 - } - if group == nil { - return nil // 分组不存在时放行 - } - - if !IsScopeSupported(group.SupportedModelScopes, scope) { - return ErrModelScopeNotSupported - } - return nil -} - // GetAvailableModels returns the list of models available for a group // It aggregates model_mapping keys from all schedulable accounts in the group func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string { diff --git a/backend/internal/service/gateway_service_benchmark_test.go b/backend/internal/service/gateway_service_benchmark_test.go index f15a85d6..c9c4d3dd 100644 --- a/backend/internal/service/gateway_service_benchmark_test.go +++ b/backend/internal/service/gateway_service_benchmark_test.go @@ -14,7 +14,7 @@ func BenchmarkGenerateSessionHash_Metadata(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { - parsed, err := ParseGatewayRequest(body) + parsed, err := ParseGatewayRequest(body, "") if err != nil { b.Fatalf("解析请求失败: %v", err) } diff --git a/backend/internal/service/gemini_error_policy_test.go b/backend/internal/service/gemini_error_policy_test.go new file mode 100644 index 00000000..2ce8793a --- /dev/null +++ b/backend/internal/service/gemini_error_policy_test.go @@ -0,0 +1,384 @@ +//go:build unit + +package service + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// TestShouldFailoverGeminiUpstreamError — verifies the failover decision +// for the ErrorPolicyNone path (original logic 
preserved). +// --------------------------------------------------------------------------- + +func TestShouldFailoverGeminiUpstreamError(t *testing.T) { + svc := &GeminiMessagesCompatService{} + + tests := []struct { + name string + statusCode int + expected bool + }{ + {"401_failover", 401, true}, + {"403_failover", 403, true}, + {"429_failover", 429, true}, + {"529_failover", 529, true}, + {"500_failover", 500, true}, + {"502_failover", 502, true}, + {"503_failover", 503, true}, + {"400_no_failover", 400, false}, + {"404_no_failover", 404, false}, + {"422_no_failover", 422, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := svc.shouldFailoverGeminiUpstreamError(tt.statusCode) + require.Equal(t, tt.expected, got) + }) + } +} + +// --------------------------------------------------------------------------- +// TestCheckErrorPolicy_GeminiAccounts — verifies CheckErrorPolicy works +// correctly for Gemini platform accounts (API Key type). +// --------------------------------------------------------------------------- + +func TestCheckErrorPolicy_GeminiAccounts(t *testing.T) { + tests := []struct { + name string + account *Account + statusCode int + body []byte + expected ErrorPolicyResult + }{ + { + name: "gemini_apikey_custom_codes_hit", + account: &Account{ + ID: 100, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(429), float64(500)}, + }, + }, + statusCode: 429, + body: []byte(`{"error":"rate limited"}`), + expected: ErrorPolicyMatched, + }, + { + name: "gemini_apikey_custom_codes_miss", + account: &Account{ + ID: 101, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(429)}, + }, + }, + statusCode: 500, + body: []byte(`{"error":"internal"}`), + expected: ErrorPolicySkipped, + }, + { + name: 
"gemini_apikey_no_custom_codes_returns_none", + account: &Account{ + ID: 102, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + }, + statusCode: 500, + body: []byte(`{"error":"internal"}`), + expected: ErrorPolicyNone, + }, + { + name: "gemini_apikey_temp_unschedulable_hit", + account: &Account{ + ID: 103, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded"}, + "duration_minutes": float64(10), + }, + }, + }, + }, + statusCode: 503, + body: []byte(`overloaded service`), + expected: ErrorPolicyTempUnscheduled, + }, + { + name: "gemini_custom_codes_override_temp_unschedulable", + account: &Account{ + ID: 104, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(503)}, + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded"}, + "duration_minutes": float64(10), + }, + }, + }, + }, + statusCode: 503, + body: []byte(`overloaded`), + expected: ErrorPolicyMatched, // custom codes take precedence + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + repo := &errorPolicyRepoStub{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + + result := svc.CheckErrorPolicy(context.Background(), tt.account, tt.statusCode, tt.body) + require.Equal(t, tt.expected, result) + }) + } +} + +// --------------------------------------------------------------------------- +// TestGeminiErrorPolicyIntegration — verifies the Gemini error handling +// paths produce the correct behavior for each ErrorPolicyResult. 
+// +// These tests simulate the inline error policy switch in handleClaudeCompat +// and forwardNativeGemini by calling the same methods in the same order. +// --------------------------------------------------------------------------- + +func TestGeminiErrorPolicyIntegration(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + account *Account + statusCode int + respBody []byte + expectFailover bool // expect UpstreamFailoverError + expectHandleError bool // expect handleGeminiUpstreamError to be called + expectShouldFailover bool // for None path, whether shouldFailover triggers + }{ + { + name: "custom_codes_matched_429_failover", + account: &Account{ + ID: 200, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(429)}, + }, + }, + statusCode: 429, + respBody: []byte(`{"error":"rate limited"}`), + expectFailover: true, + expectHandleError: true, + }, + { + name: "custom_codes_skipped_500_no_failover", + account: &Account{ + ID: 201, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(429)}, + }, + }, + statusCode: 500, + respBody: []byte(`{"error":"internal"}`), + expectFailover: false, + expectHandleError: false, + }, + { + name: "temp_unschedulable_matched_failover", + account: &Account{ + ID: 202, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded"}, + "duration_minutes": float64(10), + }, + }, + }, + }, + statusCode: 503, + respBody: []byte(`overloaded`), + expectFailover: true, + expectHandleError: true, + }, + { + name: "no_policy_429_failover_via_shouldFailover", + account: &Account{ + ID: 203, + Type: AccountTypeAPIKey, 
+ Platform: PlatformGemini, + }, + statusCode: 429, + respBody: []byte(`{"error":"rate limited"}`), + expectFailover: true, + expectHandleError: true, + expectShouldFailover: true, + }, + { + name: "no_policy_400_no_failover", + account: &Account{ + ID: 204, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + }, + statusCode: 400, + respBody: []byte(`{"error":"bad request"}`), + expectFailover: false, + expectHandleError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + repo := &geminiErrorPolicyRepo{} + rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + svc := &GeminiMessagesCompatService{ + accountRepo: repo, + rateLimitService: rlSvc, + } + + writer := httptest.NewRecorder() + c, _ := gin.CreateTestContext(writer) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + // Simulate the Claude compat error handling path (same logic as native). + // This mirrors the inline switch in handleClaudeCompat. + var handleErrorCalled bool + var gotFailover bool + + ctx := context.Background() + statusCode := tt.statusCode + respBody := tt.respBody + account := tt.account + headers := http.Header{} + + if svc.rateLimitService != nil { + switch svc.rateLimitService.CheckErrorPolicy(ctx, account, statusCode, respBody) { + case ErrorPolicySkipped: + // Skipped → return error directly (no handleGeminiUpstreamError, no failover) + gotFailover = false + handleErrorCalled = false + goto verify + case ErrorPolicyMatched, ErrorPolicyTempUnscheduled: + svc.handleGeminiUpstreamError(ctx, account, statusCode, headers, respBody) + handleErrorCalled = true + gotFailover = true + goto verify + } + } + + // ErrorPolicyNone → original logic + svc.handleGeminiUpstreamError(ctx, account, statusCode, headers, respBody) + handleErrorCalled = true + if svc.shouldFailoverGeminiUpstreamError(statusCode) { + gotFailover = true + } + + verify: + require.Equal(t, tt.expectFailover, gotFailover, "failover mismatch") + 
require.Equal(t, tt.expectHandleError, handleErrorCalled, "handleGeminiUpstreamError call mismatch") + + if tt.expectShouldFailover { + require.True(t, svc.shouldFailoverGeminiUpstreamError(statusCode), + "shouldFailoverGeminiUpstreamError should return true for status %d", statusCode) + } + }) + } +} + +// --------------------------------------------------------------------------- +// TestGeminiErrorPolicy_NilRateLimitService — verifies nil safety +// --------------------------------------------------------------------------- + +func TestGeminiErrorPolicy_NilRateLimitService(t *testing.T) { + svc := &GeminiMessagesCompatService{ + rateLimitService: nil, + } + + // When rateLimitService is nil, error policy is skipped → falls through to + // shouldFailoverGeminiUpstreamError (original logic). + // Verify this doesn't panic and follows expected behavior. + + ctx := context.Background() + account := &Account{ + ID: 300, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(429)}, + }, + } + + // The nil check should prevent CheckErrorPolicy from being called + if svc.rateLimitService != nil { + t.Fatal("rateLimitService should be nil for this test") + } + + // shouldFailoverGeminiUpstreamError still works + require.True(t, svc.shouldFailoverGeminiUpstreamError(429)) + require.False(t, svc.shouldFailoverGeminiUpstreamError(400)) + + // handleGeminiUpstreamError should not panic with nil rateLimitService + require.NotPanics(t, func() { + svc.handleGeminiUpstreamError(ctx, account, 500, http.Header{}, []byte(`error`)) + }) +} + +// --------------------------------------------------------------------------- +// geminiErrorPolicyRepo — minimal AccountRepository stub for Gemini error +// policy tests. Embeds mockAccountRepoForGemini and adds tracking. 
+// --------------------------------------------------------------------------- + +type geminiErrorPolicyRepo struct { + mockAccountRepoForGemini + setErrorCalls int + setRateLimitedCalls int + setTempCalls int +} + +func (r *geminiErrorPolicyRepo) SetError(_ context.Context, _ int64, _ string) error { + r.setErrorCalls++ + return nil +} + +func (r *geminiErrorPolicyRepo) SetRateLimited(_ context.Context, _ int64, _ time.Time) error { + r.setRateLimitedCalls++ + return nil +} + +func (r *geminiErrorPolicyRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error { + r.setTempCalls++ + return nil +} diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 4e0442fd..d77f6f92 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -831,38 +831,47 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - tempMatched := false + // 统一错误策略:自定义错误码 + 临时不可调度 if s.rateLimitService != nil { - tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody) - } - s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) - if tempMatched { - upstreamReqID := resp.Header.Get(requestIDHeader) - if upstreamReqID == "" { - upstreamReqID = resp.Header.Get("x-goog-request-id") - } - upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - upstreamDetail := "" - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - if maxBytes <= 0 { - maxBytes = 2048 + switch s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, respBody) { + case ErrorPolicySkipped: + upstreamReqID := 
resp.Header.Get(requestIDHeader) + if upstreamReqID == "" { + upstreamReqID = resp.Header.Get("x-goog-request-id") } - upstreamDetail = truncateString(string(respBody), maxBytes) + return nil, s.writeGeminiMappedError(c, account, resp.StatusCode, upstreamReqID, respBody) + case ErrorPolicyMatched, ErrorPolicyTempUnscheduled: + s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + upstreamReqID := resp.Header.Get(requestIDHeader) + if upstreamReqID == "" { + upstreamReqID = resp.Header.Get("x-goog-request-id") + } + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: upstreamReqID, + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: upstreamReqID, - Kind: "failover", - Message: upstreamMsg, - Detail: upstreamDetail, - }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } + + // ErrorPolicyNone → 原有逻辑 + s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) { upstreamReqID := resp.Header.Get(requestIDHeader) if upstreamReqID == "" { @@ -1249,14 +1258,9 @@ func (s *GeminiMessagesCompatService) 
ForwardNative(ctx context.Context, c *gin. if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - tempMatched := false - if s.rateLimitService != nil { - tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody) - } - s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) - // Best-effort fallback for OAuth tokens missing AI Studio scopes when calling countTokens. // This avoids Gemini SDKs failing hard during preflight token counting. + // Checked before error policy so it always works regardless of custom error codes. if action == "countTokens" && isOAuth && isGeminiInsufficientScope(resp.Header, respBody) { estimated := estimateGeminiCountTokens(body) c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated}) @@ -1270,30 +1274,46 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. }, nil } - if tempMatched { - evBody := unwrapIfNeeded(isOAuth, respBody) - upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - upstreamDetail := "" - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - if maxBytes <= 0 { - maxBytes = 2048 + // 统一错误策略:自定义错误码 + 临时不可调度 + if s.rateLimitService != nil { + switch s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, respBody) { + case ErrorPolicySkipped: + respBody = unwrapIfNeeded(isOAuth, respBody) + contentType := resp.Header.Get("Content-Type") + if contentType == "" { + contentType = "application/json" } - upstreamDetail = truncateString(string(evBody), maxBytes) + c.Data(resp.StatusCode, contentType, respBody) + return nil, fmt.Errorf("gemini upstream error: %d (skipped by error policy)", resp.StatusCode) + case ErrorPolicyMatched, ErrorPolicyTempUnscheduled: + s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + 
evBody := unwrapIfNeeded(isOAuth, respBody) + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(evBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: requestID, + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: requestID, - Kind: "failover", - Message: upstreamMsg, - Detail: upstreamDetail, - }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } + + // ErrorPolicyNone → 原有逻辑 + s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) { evBody := unwrapIfNeeded(isOAuth, respBody) upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody)) diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index 0c54dc39..fee0f1cf 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -133,9 +133,6 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx cont func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { return nil } -func (m *mockAccountRepoForGemini) 
SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error { - return nil -} func (m *mockAccountRepoForGemini) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error { return nil } @@ -265,29 +262,6 @@ func (m *mockGatewayCacheForGemini) DeleteSessionAccountID(ctx context.Context, return nil } -func (m *mockGatewayCacheForGemini) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) { - return 0, nil -} - -func (m *mockGatewayCacheForGemini) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) { - return nil, nil -} - -func (m *mockGatewayCacheForGemini) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - return "", 0, false -} - -func (m *mockGatewayCacheForGemini) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { - return nil -} - -func (m *mockGatewayCacheForGemini) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - return "", 0, false -} - -func (m *mockGatewayCacheForGemini) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { - return nil -} // TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择 func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) { diff --git a/backend/internal/service/gemini_session.go b/backend/internal/service/gemini_session.go index 859ae9f3..1780d1da 100644 --- a/backend/internal/service/gemini_session.go +++ b/backend/internal/service/gemini_session.go @@ -6,26 +6,11 @@ import ( "encoding/json" "strconv" "strings" - "time" 
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/cespare/xxhash/v2" ) -// Gemini 会话 ID Fallback 相关常量 -const ( - // geminiSessionTTLSeconds Gemini 会话缓存 TTL(5 分钟) - geminiSessionTTLSeconds = 300 - - // geminiSessionKeyPrefix Gemini 会话 Redis key 前缀 - geminiSessionKeyPrefix = "gemini:sess:" -) - -// GeminiSessionTTL 返回 Gemini 会话缓存 TTL -func GeminiSessionTTL() time.Duration { - return geminiSessionTTLSeconds * time.Second -} - // shortHash 使用 XXHash64 + Base36 生成短 hash(16 字符) // XXHash64 比 SHA256 快约 10 倍,Base36 比 Hex 短约 20% func shortHash(data []byte) string { @@ -79,35 +64,6 @@ func GenerateGeminiPrefixHash(userID, apiKeyID int64, ip, userAgent, platform, m return base64.RawURLEncoding.EncodeToString(hash[:12]) } -// BuildGeminiSessionKey 构建 Gemini 会话 Redis key -// 格式: gemini:sess:{groupID}:{prefixHash}:{digestChain} -func BuildGeminiSessionKey(groupID int64, prefixHash, digestChain string) string { - return geminiSessionKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash + ":" + digestChain -} - -// GenerateDigestChainPrefixes 生成摘要链的所有前缀(从长到短) -// 用于 MGET 批量查询最长匹配 -func GenerateDigestChainPrefixes(chain string) []string { - if chain == "" { - return nil - } - - var prefixes []string - c := chain - - for c != "" { - prefixes = append(prefixes, c) - // 找到最后一个 "-" 的位置 - if i := strings.LastIndex(c, "-"); i > 0 { - c = c[:i] - } else { - break - } - } - - return prefixes -} - // ParseGeminiSessionValue 解析 Gemini 会话缓存值 // 格式: {uuid}:{accountID} func ParseGeminiSessionValue(value string) (uuid string, accountID int64, ok bool) { @@ -139,15 +95,6 @@ func FormatGeminiSessionValue(uuid string, accountID int64) string { // geminiDigestSessionKeyPrefix Gemini 摘要 fallback 会话 key 前缀 const geminiDigestSessionKeyPrefix = "gemini:digest:" -// geminiTrieKeyPrefix Gemini Trie 会话 key 前缀 -const geminiTrieKeyPrefix = "gemini:trie:" - -// BuildGeminiTrieKey 构建 Gemini Trie Redis key -// 格式: gemini:trie:{groupID}:{prefixHash} -func BuildGeminiTrieKey(groupID 
int64, prefixHash string) string { - return geminiTrieKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash -} - // GenerateGeminiDigestSessionKey 生成 Gemini 摘要 fallback 的 sessionKey // 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey // 用于在 SelectAccountWithLoadAwareness 中保持粘性会话 diff --git a/backend/internal/service/gemini_session_integration_test.go b/backend/internal/service/gemini_session_integration_test.go index 928c62cf..95b5f594 100644 --- a/backend/internal/service/gemini_session_integration_test.go +++ b/backend/internal/service/gemini_session_integration_test.go @@ -1,41 +1,14 @@ package service import ( - "context" "testing" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" ) -// mockGeminiSessionCache 模拟 Redis 缓存 -type mockGeminiSessionCache struct { - sessions map[string]string // key -> value -} - -func newMockGeminiSessionCache() *mockGeminiSessionCache { - return &mockGeminiSessionCache{sessions: make(map[string]string)} -} - -func (m *mockGeminiSessionCache) Save(groupID int64, prefixHash, digestChain, uuid string, accountID int64) { - key := BuildGeminiSessionKey(groupID, prefixHash, digestChain) - value := FormatGeminiSessionValue(uuid, accountID) - m.sessions[key] = value -} - -func (m *mockGeminiSessionCache) Find(groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - prefixes := GenerateDigestChainPrefixes(digestChain) - for _, p := range prefixes { - key := BuildGeminiSessionKey(groupID, prefixHash, p) - if val, ok := m.sessions[key]; ok { - return ParseGeminiSessionValue(val) - } - } - return "", 0, false -} - // TestGeminiSessionContinuousConversation 测试连续会话的摘要链匹配 func TestGeminiSessionContinuousConversation(t *testing.T) { - cache := newMockGeminiSessionCache() + store := NewDigestSessionStore() groupID := int64(1) prefixHash := "test_prefix_hash" sessionUUID := "session-uuid-12345" @@ -54,13 +27,13 @@ func TestGeminiSessionContinuousConversation(t *testing.T) { t.Logf("Round 1 
chain: %s", chain1) // 第一轮:没有找到会话,创建新会话 - _, _, found := cache.Find(groupID, prefixHash, chain1) + _, _, _, found := store.Find(groupID, prefixHash, chain1) if found { t.Error("Round 1: should not find existing session") } - // 保存第一轮会话 - cache.Save(groupID, prefixHash, chain1, sessionUUID, accountID) + // 保存第一轮会话(首轮无旧 chain) + store.Save(groupID, prefixHash, chain1, sessionUUID, accountID, "") // 模拟第二轮对话(用户继续对话) req2 := &antigravity.GeminiRequest{ @@ -77,7 +50,7 @@ func TestGeminiSessionContinuousConversation(t *testing.T) { t.Logf("Round 2 chain: %s", chain2) // 第二轮:应该能找到会话(通过前缀匹配) - foundUUID, foundAccID, found := cache.Find(groupID, prefixHash, chain2) + foundUUID, foundAccID, matchedChain, found := store.Find(groupID, prefixHash, chain2) if !found { t.Error("Round 2: should find session via prefix matching") } @@ -88,8 +61,8 @@ func TestGeminiSessionContinuousConversation(t *testing.T) { t.Errorf("Round 2: expected accountID %d, got %d", accountID, foundAccID) } - // 保存第二轮会话 - cache.Save(groupID, prefixHash, chain2, sessionUUID, accountID) + // 保存第二轮会话,传入 Find 返回的 matchedChain 以删旧 key + store.Save(groupID, prefixHash, chain2, sessionUUID, accountID, matchedChain) // 模拟第三轮对话 req3 := &antigravity.GeminiRequest{ @@ -108,7 +81,7 @@ func TestGeminiSessionContinuousConversation(t *testing.T) { t.Logf("Round 3 chain: %s", chain3) // 第三轮:应该能找到会话(通过第二轮的前缀匹配) - foundUUID, foundAccID, found = cache.Find(groupID, prefixHash, chain3) + foundUUID, foundAccID, _, found = store.Find(groupID, prefixHash, chain3) if !found { t.Error("Round 3: should find session via prefix matching") } @@ -118,13 +91,11 @@ func TestGeminiSessionContinuousConversation(t *testing.T) { if foundAccID != accountID { t.Errorf("Round 3: expected accountID %d, got %d", accountID, foundAccID) } - - t.Log("✓ Continuous conversation session matching works correctly!") } // TestGeminiSessionDifferentConversations 测试不同会话不会错误匹配 func TestGeminiSessionDifferentConversations(t *testing.T) { - cache := 
newMockGeminiSessionCache() + store := NewDigestSessionStore() groupID := int64(1) prefixHash := "test_prefix_hash" @@ -135,7 +106,7 @@ func TestGeminiSessionDifferentConversations(t *testing.T) { }, } chain1 := BuildGeminiDigestChain(req1) - cache.Save(groupID, prefixHash, chain1, "session-1", 100) + store.Save(groupID, prefixHash, chain1, "session-1", 100, "") // 第二个完全不同的会话 req2 := &antigravity.GeminiRequest{ @@ -146,61 +117,29 @@ func TestGeminiSessionDifferentConversations(t *testing.T) { chain2 := BuildGeminiDigestChain(req2) // 不同会话不应该匹配 - _, _, found := cache.Find(groupID, prefixHash, chain2) + _, _, _, found := store.Find(groupID, prefixHash, chain2) if found { t.Error("Different conversations should not match") } - - t.Log("✓ Different conversations are correctly isolated!") } // TestGeminiSessionPrefixMatchingOrder 测试前缀匹配的优先级(最长匹配优先) func TestGeminiSessionPrefixMatchingOrder(t *testing.T) { - cache := newMockGeminiSessionCache() + store := NewDigestSessionStore() groupID := int64(1) prefixHash := "test_prefix_hash" - // 创建一个三轮对话 - req := &antigravity.GeminiRequest{ - SystemInstruction: &antigravity.GeminiContent{ - Parts: []antigravity.GeminiPart{{Text: "System prompt"}}, - }, - Contents: []antigravity.GeminiContent{ - {Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q1"}}}, - {Role: "model", Parts: []antigravity.GeminiPart{{Text: "A1"}}}, - {Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q2"}}}, - }, - } - fullChain := BuildGeminiDigestChain(req) - prefixes := GenerateDigestChainPrefixes(fullChain) - - t.Logf("Full chain: %s", fullChain) - t.Logf("Prefixes (longest first): %v", prefixes) - - // 验证前缀生成顺序(从长到短) - if len(prefixes) != 4 { - t.Errorf("Expected 4 prefixes, got %d", len(prefixes)) - } - // 保存不同轮次的会话到不同账号 - // 第一轮(最短前缀)-> 账号 1 - cache.Save(groupID, prefixHash, prefixes[3], "session-round1", 1) - // 第二轮 -> 账号 2 - cache.Save(groupID, prefixHash, prefixes[2], "session-round2", 2) - // 第三轮(最长前缀,完整链)-> 账号 3 - cache.Save(groupID, 
prefixHash, prefixes[0], "session-round3", 3) + store.Save(groupID, prefixHash, "s:sys-u:q1", "session-round1", 1, "") + store.Save(groupID, prefixHash, "s:sys-u:q1-m:a1", "session-round2", 2, "") + store.Save(groupID, prefixHash, "s:sys-u:q1-m:a1-u:q2", "session-round3", 3, "") - // 查找应该返回最长匹配(账号 3) - _, accID, found := cache.Find(groupID, prefixHash, fullChain) + // 查找更长的链,应该返回最长匹配(账号 3) + _, accID, _, found := store.Find(groupID, prefixHash, "s:sys-u:q1-m:a1-u:q2-m:a2") if !found { t.Error("Should find session") } if accID != 3 { t.Errorf("Should match longest prefix (account 3), got account %d", accID) } - - t.Log("✓ Longest prefix matching works correctly!") } - -// 确保 context 包被使用(避免未使用的导入警告) -var _ = context.Background diff --git a/backend/internal/service/gemini_session_test.go b/backend/internal/service/gemini_session_test.go index 8c1908f7..a034cddd 100644 --- a/backend/internal/service/gemini_session_test.go +++ b/backend/internal/service/gemini_session_test.go @@ -152,61 +152,6 @@ func TestGenerateGeminiPrefixHash(t *testing.T) { } } -func TestGenerateDigestChainPrefixes(t *testing.T) { - tests := []struct { - name string - chain string - want []string - wantLen int - }{ - { - name: "empty", - chain: "", - wantLen: 0, - }, - { - name: "single part", - chain: "u:abc123", - want: []string{"u:abc123"}, - wantLen: 1, - }, - { - name: "two parts", - chain: "s:xyz-u:abc", - want: []string{"s:xyz-u:abc", "s:xyz"}, - wantLen: 2, - }, - { - name: "four parts", - chain: "s:a-u:b-m:c-u:d", - want: []string{"s:a-u:b-m:c-u:d", "s:a-u:b-m:c", "s:a-u:b", "s:a"}, - wantLen: 4, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := GenerateDigestChainPrefixes(tt.chain) - - if len(result) != tt.wantLen { - t.Errorf("expected %d prefixes, got %d: %v", tt.wantLen, len(result), result) - } - - if tt.want != nil { - for i, want := range tt.want { - if i >= len(result) { - t.Errorf("missing prefix at index %d", i) - continue - } - if 
result[i] != want { - t.Errorf("prefix[%d]: expected %s, got %s", i, want, result[i]) - } - } - } - }) - } -} - func TestParseGeminiSessionValue(t *testing.T) { tests := []struct { name string @@ -442,40 +387,3 @@ func TestGenerateGeminiDigestSessionKey(t *testing.T) { } }) } - -func TestBuildGeminiTrieKey(t *testing.T) { - tests := []struct { - name string - groupID int64 - prefixHash string - want string - }{ - { - name: "normal", - groupID: 123, - prefixHash: "abcdef12", - want: "gemini:trie:123:abcdef12", - }, - { - name: "zero group", - groupID: 0, - prefixHash: "xyz", - want: "gemini:trie:0:xyz", - }, - { - name: "empty prefix", - groupID: 1, - prefixHash: "", - want: "gemini:trie:1:", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := BuildGeminiTrieKey(tt.groupID, tt.prefixHash) - if got != tt.want { - t.Errorf("BuildGeminiTrieKey(%d, %q) = %q, want %q", tt.groupID, tt.prefixHash, got, tt.want) - } - }) - } -} diff --git a/backend/internal/service/generate_session_hash_test.go b/backend/internal/service/generate_session_hash_test.go new file mode 100644 index 00000000..8aa358a5 --- /dev/null +++ b/backend/internal/service/generate_session_hash_test.go @@ -0,0 +1,1213 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// ============ 基础优先级测试 ============ + +func TestGenerateSessionHash_NilParsedRequest(t *testing.T) { + svc := &GatewayService{} + require.Empty(t, svc.GenerateSessionHash(nil)) +} + +func TestGenerateSessionHash_EmptyRequest(t *testing.T) { + svc := &GatewayService{} + require.Empty(t, svc.GenerateSessionHash(&ParsedRequest{})) +} + +func TestGenerateSessionHash_MetadataHasHighestPriority(t *testing.T) { + svc := &GatewayService{} + + parsed := &ParsedRequest{ + MetadataUserID: "session_123e4567-e89b-12d3-a456-426614174000", + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": 
"hello"},
+		},
+	}
+
+	hash := svc.GenerateSessionHash(parsed)
+	require.Equal(t, "123e4567-e89b-12d3-a456-426614174000", hash, "metadata session_id should have highest priority")
+}
+
+// ============ System + Messages basics ============
+
+func TestGenerateSessionHash_SystemPlusMessages(t *testing.T) {
+	svc := &GatewayService{}
+
+	withSystem := &ParsedRequest{
+		System:    "You are a helpful assistant.",
+		HasSystem: true,
+		Messages: []any{
+			map[string]any{"role": "user", "content": "hello"},
+		},
+	}
+	withoutSystem := &ParsedRequest{
+		Messages: []any{
+			map[string]any{"role": "user", "content": "hello"},
+		},
+	}
+
+	h1 := svc.GenerateSessionHash(withSystem)
+	h2 := svc.GenerateSessionHash(withoutSystem)
+	require.NotEmpty(t, h1)
+	require.NotEmpty(t, h2)
+	require.NotEqual(t, h1, h2, "system prompt should be part of digest, producing different hash")
+}
+
+func TestGenerateSessionHash_SystemOnlyProducesHash(t *testing.T) {
+	svc := &GatewayService{}
+
+	parsed := &ParsedRequest{
+		System:    "You are a helpful assistant.",
+		HasSystem: true,
+	}
+	hash := svc.GenerateSessionHash(parsed)
+	require.NotEmpty(t, hash, "system prompt alone should produce a hash as part of full digest")
+}
+
+func TestGenerateSessionHash_DifferentSystemsSameMessages(t *testing.T) {
+	svc := &GatewayService{}
+
+	parsed1 := &ParsedRequest{
+		System:    "You are assistant A.",
+		HasSystem: true,
+		Messages: []any{
+			map[string]any{"role": "user", "content": "hello"},
+		},
+	}
+	parsed2 := &ParsedRequest{
+		System:    "You are assistant B.",
+		HasSystem: true,
+		Messages: []any{
+			map[string]any{"role": "user", "content": "hello"},
+		},
+	}
+
+	h1 := svc.GenerateSessionHash(parsed1)
+	h2 := svc.GenerateSessionHash(parsed2)
+	require.NotEqual(t, h1, h2, "different system prompts with same messages should produce different hashes")
+}
+
+func TestGenerateSessionHash_SameSystemSameMessages(t *testing.T) {
+	svc := &GatewayService{}
+
+	mk := func() *ParsedRequest {
+		return &ParsedRequest{
+			System:    "You are a helpful assistant.",
+			HasSystem: true,
+			Messages: []any{
+				map[string]any{"role": "user", "content": "hello"},
+				map[string]any{"role": "assistant", "content": "hi"},
+			},
+		}
+	}
+
+	h1 := svc.GenerateSessionHash(mk())
+	h2 := svc.GenerateSessionHash(mk())
+	require.Equal(t, h1, h2, "same system + same messages should produce identical hash")
+}
+
+func TestGenerateSessionHash_DifferentMessagesProduceDifferentHash(t *testing.T) {
+	svc := &GatewayService{}
+
+	parsed1 := &ParsedRequest{
+		System:    "You are a helpful assistant.",
+		HasSystem: true,
+		Messages: []any{
+			map[string]any{"role": "user", "content": "help me with Go"},
+		},
+	}
+	parsed2 := &ParsedRequest{
+		System:    "You are a helpful assistant.",
+		HasSystem: true,
+		Messages: []any{
+			map[string]any{"role": "user", "content": "help me with Python"},
+		},
+	}
+
+	h1 := svc.GenerateSessionHash(parsed1)
+	h2 := svc.GenerateSessionHash(parsed2)
+	require.NotEqual(t, h1, h2, "same system but different messages should produce different hashes")
+}
+
+// ============ SessionContext core tests ============
+
+func TestGenerateSessionHash_DifferentSessionContextProducesDifferentHash(t *testing.T) {
+	svc := &GatewayService{}
+
+	// Same messages + different SessionContext → different hashes (core scenario of the collision fix).
+	parsed1 := &ParsedRequest{
+		Messages: []any{
+			map[string]any{"role": "user", "content": "hello"},
+		},
+		SessionContext: &SessionContext{
+			ClientIP:  "192.168.1.1",
+			UserAgent: "Mozilla/5.0",
+			APIKeyID:  100,
+		},
+	}
+	parsed2 := &ParsedRequest{
+		Messages: []any{
+			map[string]any{"role": "user", "content": "hello"},
+		},
+		SessionContext: &SessionContext{
+			ClientIP:  "10.0.0.1",
+			UserAgent: "curl/7.0",
+			APIKeyID:  200,
+		},
+	}
+
+	h1 := svc.GenerateSessionHash(parsed1)
+	h2 := svc.GenerateSessionHash(parsed2)
+	require.NotEmpty(t, h1)
+	require.NotEmpty(t, h2)
+	require.NotEqual(t, h1, h2, "same messages but different SessionContext should produce different hashes")
+}
+
+func TestGenerateSessionHash_SameSessionContextProducesSameHash(t *testing.T) {
+	svc := &GatewayService{}
+
+	mk := func() *ParsedRequest {
+		return &ParsedRequest{
+			Messages: []any{
+				map[string]any{"role": "user", "content": "hello"},
+			},
+			SessionContext: &SessionContext{
+				ClientIP:  "192.168.1.1",
+				UserAgent: "Mozilla/5.0",
+				APIKeyID:  100,
+			},
+		}
+	}
+
+	h1 := svc.GenerateSessionHash(mk())
+	h2 := svc.GenerateSessionHash(mk())
+	require.Equal(t, h1, h2, "same messages + same SessionContext should produce identical hash")
+}
+
+func TestGenerateSessionHash_MetadataOverridesSessionContext(t *testing.T) {
+	svc := &GatewayService{}
+
+	parsed := &ParsedRequest{
+		MetadataUserID: "session_123e4567-e89b-12d3-a456-426614174000",
+		Messages: []any{
+			map[string]any{"role": "user", "content": "hello"},
+		},
+		SessionContext: &SessionContext{
+			ClientIP:  "192.168.1.1",
+			UserAgent: "Mozilla/5.0",
+			APIKeyID:  100,
+		},
+	}
+
+	hash := svc.GenerateSessionHash(parsed)
+	require.Equal(t, "123e4567-e89b-12d3-a456-426614174000", hash,
+		"metadata session_id should take priority over SessionContext")
+}
+
+func TestGenerateSessionHash_NilSessionContextBackwardCompatible(t *testing.T) {
+	svc := &GatewayService{}
+
+	withCtx := &ParsedRequest{
+		Messages: []any{
+			map[string]any{"role": "user", "content": "hello"},
+		},
+		SessionContext: nil,
+	}
+	withoutCtx := &ParsedRequest{
+		Messages: []any{
+			map[string]any{"role": "user", "content": "hello"},
+		},
+	}
+
+	h1 := svc.GenerateSessionHash(withCtx)
+	h2 := svc.GenerateSessionHash(withoutCtx)
+	require.Equal(t, h1, h2, "nil SessionContext should produce same hash as no SessionContext")
+}
+
+// ============ Multi-round continuous conversation tests ============
+
+func TestGenerateSessionHash_ContinuousConversation_HashChangesWithMessages(t *testing.T) {
+	svc := &GatewayService{}
+
+	ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1}
+
+	// Simulate a continuous conversation: each added round should change the hash (content accumulates).
+	round1 := &ParsedRequest{
+		System:    "You are a helpful assistant.",
+		HasSystem: true,
+		Messages: []any{
+			map[string]any{"role": "user", "content": "hello"},
+		},
+		SessionContext: ctx,
+	}
+
+	round2 := &ParsedRequest{
+		System:    "You are a helpful assistant.",
+		HasSystem: true,
+		Messages: []any{
+			map[string]any{"role": "user", "content": "hello"},
+			map[string]any{"role": "assistant", "content": "Hi there!"},
+			map[string]any{"role": "user", "content": "How are you?"},
+		},
+		SessionContext: ctx,
+	}
+
+	round3 := &ParsedRequest{
+		System:    "You are a helpful assistant.",
+		HasSystem: true,
+		Messages: []any{
+			map[string]any{"role": "user", "content": "hello"},
+			map[string]any{"role": "assistant", "content": "Hi there!"},
+			map[string]any{"role": "user", "content": "How are you?"},
+			map[string]any{"role": "assistant", "content": "I'm doing well!"},
+			map[string]any{"role": "user", "content": "Tell me a joke"},
+		},
+		SessionContext: ctx,
+	}
+
+	h1 := svc.GenerateSessionHash(round1)
+	h2 := svc.GenerateSessionHash(round2)
+	h3 := svc.GenerateSessionHash(round3)
+
+	require.NotEmpty(t, h1)
+	require.NotEmpty(t, h2)
+	require.NotEmpty(t, h3)
+	require.NotEqual(t, h1, h2, "different conversation rounds should produce different hashes")
+	require.NotEqual(t, h2, h3, "each new round should produce a different hash")
+	require.NotEqual(t, h1, h3, "round 1 and round 3 should differ")
+}
+
+func TestGenerateSessionHash_ContinuousConversation_SameRoundSameHash(t *testing.T) {
+	svc := &GatewayService{}
+
+	ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1}
+
+	// Repeating the same round (e.g. a retry) should produce an identical hash.
+	mk := func() *ParsedRequest {
+		return &ParsedRequest{
+			System:    "You are a helpful assistant.",
+			HasSystem: true,
+			Messages: []any{
+				map[string]any{"role": "user", "content": "hello"},
+				map[string]any{"role": "assistant", "content": "Hi there!"},
+				map[string]any{"role": "user", "content": "How are you?"},
+			},
+			SessionContext: ctx,
+		}
+	}
+
+	h1 := svc.GenerateSessionHash(mk())
+	h2 := 
svc.GenerateSessionHash(mk()) + require.Equal(t, h1, h2, "same conversation state should produce identical hash on retry") +} + +// ============ Message rollback tests ============ + +func TestGenerateSessionHash_MessageRollback(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1} + + // Simulate a rollback: the user deletes the last round and resends + original := &ParsedRequest{ + System: "System prompt", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "msg1"}, + map[string]any{"role": "assistant", "content": "reply1"}, + map[string]any{"role": "user", "content": "msg2"}, + map[string]any{"role": "assistant", "content": "reply2"}, + map[string]any{"role": "user", "content": "msg3"}, + }, + SessionContext: ctx, + } + + // After rolling back, a new msg3 replaces the original one + rollback := &ParsedRequest{ + System: "System prompt", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "msg1"}, + map[string]any{"role": "assistant", "content": "reply1"}, + map[string]any{"role": "user", "content": "msg2"}, + map[string]any{"role": "assistant", "content": "reply2"}, + map[string]any{"role": "user", "content": "different msg3"}, + }, + SessionContext: ctx, + } + + hOrig := svc.GenerateSessionHash(original) + hRollback := svc.GenerateSessionHash(rollback) + require.NotEqual(t, hOrig, hRollback, "rollback with different last message should produce different hash") +} + +func TestGenerateSessionHash_MessageRollbackSameContent(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1} + + // Resending identical content after a rollback → same hash (sensible sticky-session recovery) + mk := func() *ParsedRequest { + return &ParsedRequest{ + System: "System prompt", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "msg1"}, + map[string]any{"role": "assistant", "content": "reply1"}, + map[string]any{"role": "user", "content": "msg2"}, + }, + SessionContext: ctx, + } + } + + h1 :=
svc.GenerateSessionHash(mk()) + h2 := svc.GenerateSessionHash(mk()) + require.Equal(t, h1, h2, "rollback and resend same content should produce same hash") +} + +// ============ Same system prompt, different user messages ============ + +func TestGenerateSessionHash_SameSystemDifferentUsers(t *testing.T) { + svc := &GatewayService{} + + // Two different users share the same system prompt but send different messages + user1 := &ParsedRequest{ + System: "You are a code reviewer.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "Review this Go code"}, + }, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: "vscode", + APIKeyID: 1, + }, + } + user2 := &ParsedRequest{ + System: "You are a code reviewer.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "Review this Python code"}, + }, + SessionContext: &SessionContext{ + ClientIP: "2.2.2.2", + UserAgent: "vscode", + APIKeyID: 2, + }, + } + + h1 := svc.GenerateSessionHash(user1) + h2 := svc.GenerateSessionHash(user2) + require.NotEqual(t, h1, h2, "different users with different messages should get different hashes") +} + +func TestGenerateSessionHash_SameSystemSameMessageDifferentContext(t *testing.T) { + svc := &GatewayService{} + + // Core scenario for the fix: two different users send identical system + messages (e.g. "hello") + // With SessionContext in place they should get different hashes + user1 := &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: "Mozilla/5.0", + APIKeyID: 10, + }, + } + user2 := &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: &SessionContext{ + ClientIP: "2.2.2.2", + UserAgent: "Mozilla/5.0", + APIKeyID: 20, + }, + } + + h1 := svc.GenerateSessionHash(user1) + h2 := svc.GenerateSessionHash(user2) + require.NotEqual(t, h1, h2, "CRITICAL: same
system+messages but different users should get different hashes") +} + +// ============ Independent effect of each SessionContext field ============ + +func TestGenerateSessionHash_SessionContext_IPDifference(t *testing.T) { + svc := &GatewayService{} + + base := func(ip string) *ParsedRequest { + return &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "test"}, + }, + SessionContext: &SessionContext{ + ClientIP: ip, + UserAgent: "same-ua", + APIKeyID: 1, + }, + } + } + + h1 := svc.GenerateSessionHash(base("1.1.1.1")) + h2 := svc.GenerateSessionHash(base("2.2.2.2")) + require.NotEqual(t, h1, h2, "different IP should produce different hash") +} + +func TestGenerateSessionHash_SessionContext_UADifference(t *testing.T) { + svc := &GatewayService{} + + base := func(ua string) *ParsedRequest { + return &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "test"}, + }, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: ua, + APIKeyID: 1, + }, + } + } + + h1 := svc.GenerateSessionHash(base("Mozilla/5.0")) + h2 := svc.GenerateSessionHash(base("curl/7.0")) + require.NotEqual(t, h1, h2, "different User-Agent should produce different hash") +} + +func TestGenerateSessionHash_SessionContext_APIKeyIDDifference(t *testing.T) { + svc := &GatewayService{} + + base := func(keyID int64) *ParsedRequest { + return &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "test"}, + }, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: "same-ua", + APIKeyID: keyID, + }, + } + } + + h1 := svc.GenerateSessionHash(base(1)) + h2 := svc.GenerateSessionHash(base(2)) + require.NotEqual(t, h1, h2, "different APIKeyID should produce different hash") +} + +// ============ Concurrent users sending the same message ============ + +func TestGenerateSessionHash_MultipleUsersSameFirstMessage(t *testing.T) { + svc := &GatewayService{} + + // Simulate 5 different users sending "hello" at the same time → should yield 5 distinct hashes + hashes := make(map[string]bool) + for i := 0; i < 5;
i++ { + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: &SessionContext{ + ClientIP: "192.168.1." + string(rune('1'+i)), + UserAgent: "client-" + string(rune('A'+i)), + APIKeyID: int64(i + 1), + }, + } + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h) + require.False(t, hashes[h], "hash collision detected for user %d", i) + hashes[h] = true + } + require.Len(t, hashes, 5, "5 different users should produce 5 unique hashes") +} + +// ============ Conversation stickiness: multi-round dialogue, same user ============ + +func TestGenerateSessionHash_SameUserGrowingConversation(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "browser", APIKeyID: 42} + + // Simulate one user's continuous conversation: each round hashes differently, but a retry within a round stays consistent + messages := []map[string]any{ + {"role": "user", "content": "msg1"}, + {"role": "assistant", "content": "reply1"}, + {"role": "user", "content": "msg2"}, + {"role": "assistant", "content": "reply2"}, + {"role": "user", "content": "msg3"}, + {"role": "assistant", "content": "reply3"}, + {"role": "user", "content": "msg4"}, + } + + prevHash := "" + for round := 1; round <= len(messages); round += 2 { + // Build the message prefix of length round + msgs := make([]any, round) + for j := 0; j < round; j++ { + msgs[j] = messages[j] + } + parsed := &ParsedRequest{ + System: "System", + HasSystem: true, + Messages: msgs, + SessionContext: ctx, + } + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "round %d hash should not be empty", round) + + if prevHash != "" { + require.NotEqual(t, prevHash, h, "round %d hash should differ from previous round", round) + } + prevHash = h + + // A retry of the same round should match + h2 := svc.GenerateSessionHash(parsed) + require.Equal(t, h, h2, "retry of round %d should produce same hash", round) + } +} + +// ============ Multi-message content structure tests ============ + +func TestGenerateSessionHash_MultipleUserMessages(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4",
UserAgent: "test", APIKeyID: 1} + + // Five user messages (no assistant replies) + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "first"}, + map[string]any{"role": "user", "content": "second"}, + map[string]any{"role": "user", "content": "third"}, + map[string]any{"role": "user", "content": "fourth"}, + map[string]any{"role": "user", "content": "fifth"}, + }, + SessionContext: ctx, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h) + + // Changing a message in the middle should change the hash + parsed2 := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "first"}, + map[string]any{"role": "user", "content": "CHANGED"}, + map[string]any{"role": "user", "content": "third"}, + map[string]any{"role": "user", "content": "fourth"}, + map[string]any{"role": "user", "content": "fifth"}, + }, + SessionContext: ctx, + } + + h2 := svc.GenerateSessionHash(parsed2) + require.NotEqual(t, h, h2, "changing any message should change the hash") +} + +func TestGenerateSessionHash_MessageOrderMatters(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1} + + parsed1 := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "alpha"}, + map[string]any{"role": "user", "content": "beta"}, + }, + SessionContext: ctx, + } + parsed2 := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "beta"}, + map[string]any{"role": "user", "content": "alpha"}, + }, + SessionContext: ctx, + } + + h1 := svc.GenerateSessionHash(parsed1) + h2 := svc.GenerateSessionHash(parsed2) + require.NotEqual(t, h1, h2, "message order should affect the hash") +} + +// ============ Complex content format tests ============ + +func TestGenerateSessionHash_StructuredContent(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1} + + // Structured content (array form) + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{ +
"role": "user", + "content": []any{ + map[string]any{"type": "text", "text": "Look at this"}, + map[string]any{"type": "text", "text": "And this too"}, + }, + }, + }, + SessionContext: ctx, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "structured content should produce a hash") +} + +func TestGenerateSessionHash_ArraySystemPrompt(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1} + + // 数组格式的 system prompt + parsed := &ParsedRequest{ + System: []any{ + map[string]any{"type": "text", "text": "You are a helpful assistant."}, + map[string]any{"type": "text", "text": "Be concise."}, + }, + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: ctx, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "array system prompt should produce a hash") +} + +// ============ SessionContext 与 cache_control 优先级 ============ + +func TestGenerateSessionHash_CacheControlOverridesSessionContext(t *testing.T) { + svc := &GatewayService{} + + // 当有 cache_control: ephemeral 时,使用第 2 级优先级 + // SessionContext 不应影响结果 + parsed1 := &ParsedRequest{ + System: []any{ + map[string]any{ + "type": "text", + "text": "You are a tool-specific assistant.", + "cache_control": map[string]any{"type": "ephemeral"}, + }, + }, + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: "ua1", + APIKeyID: 100, + }, + } + parsed2 := &ParsedRequest{ + System: []any{ + map[string]any{ + "type": "text", + "text": "You are a tool-specific assistant.", + "cache_control": map[string]any{"type": "ephemeral"}, + }, + }, + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: &SessionContext{ + ClientIP: "2.2.2.2", + UserAgent: "ua2", + APIKeyID: 200, + }, + } + + h1 := 
svc.GenerateSessionHash(parsed1) + h2 := svc.GenerateSessionHash(parsed2) + require.Equal(t, h1, h2, "cache_control ephemeral has higher priority, SessionContext should not affect result") +} + +// ============ Edge cases ============ + +func TestGenerateSessionHash_EmptyMessages(t *testing.T) { + svc := &GatewayService{} + + parsed := &ParsedRequest{ + Messages: []any{}, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: "test", + APIKeyID: 1, + }, + } + + // With empty messages but a SessionContext, combined.Len() > 0 because the context is written in + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "empty messages with SessionContext should still produce a hash from context") +} + +func TestGenerateSessionHash_EmptyMessagesNoContext(t *testing.T) { + svc := &GatewayService{} + + parsed := &ParsedRequest{ + Messages: []any{}, + } + + h := svc.GenerateSessionHash(parsed) + require.Empty(t, h, "empty messages without SessionContext should produce empty hash") +} + +func TestGenerateSessionHash_SessionContextWithEmptyFields(t *testing.T) { + svc := &GatewayService{} + + // A SessionContext with empty-string and zero-value fields should still affect the hash + withEmptyCtx := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "test"}, + }, + SessionContext: &SessionContext{ + ClientIP: "", + UserAgent: "", + APIKeyID: 0, + }, + } + withoutCtx := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "test"}, + }, + } + + h1 := svc.GenerateSessionHash(withEmptyCtx) + h2 := svc.GenerateSessionHash(withoutCtx) + // A SessionContext (even with empty fields) still writes separators such as "::" + require.NotEqual(t, h1, h2, "empty-field SessionContext should still differ from nil SessionContext") +} + +// ============ Long conversation history tests ============ + +func TestGenerateSessionHash_LongConversation(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1} + + // Build a 20-round conversation + messages := make([]any, 0, 40) + for i := 0; i < 20; i++ { + messages =
append(messages, map[string]any{ + "role": "user", + "content": "user message " + string(rune('A'+i)), + }) + messages = append(messages, map[string]any{ + "role": "assistant", + "content": "assistant reply " + string(rune('A'+i)), + }) + } + + parsed := &ParsedRequest{ + System: "System prompt", + HasSystem: true, + Messages: messages, + SessionContext: ctx, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h) + + // Adding one more round should change the hash + moreMessages := make([]any, len(messages)+2) + copy(moreMessages, messages) + moreMessages[len(messages)] = map[string]any{"role": "user", "content": "one more"} + moreMessages[len(messages)+1] = map[string]any{"role": "assistant", "content": "ok"} + + parsed2 := &ParsedRequest{ + System: "System prompt", + HasSystem: true, + Messages: moreMessages, + SessionContext: ctx, + } + + h2 := svc.GenerateSessionHash(parsed2) + require.NotEqual(t, h, h2, "adding more messages to long conversation should change hash") +} + +// ============ Session hash tests for the native Gemini format ============ + +func TestGenerateSessionHash_GeminiContentsProducesHash(t *testing.T) { + svc := &GatewayService{} + + // Gemini format: contents[].parts[].text + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Hello from Gemini"}, + }, + }, + }, + SessionContext: &SessionContext{ + ClientIP: "1.2.3.4", + UserAgent: "gemini-cli", + APIKeyID: 1, + }, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "Gemini contents with parts should produce a non-empty hash") +} + +func TestGenerateSessionHash_GeminiDifferentContentsDifferentHash(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1} + + parsed1 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Hello"}, + }, + }, + }, + SessionContext: ctx, + } + parsed2 := &ParsedRequest{ + Messages:
[]any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Goodbye"}, + }, + }, + }, + SessionContext: ctx, + } + + h1 := svc.GenerateSessionHash(parsed1) + h2 := svc.GenerateSessionHash(parsed2) + require.NotEqual(t, h1, h2, "different Gemini contents should produce different hashes") +} + +func TestGenerateSessionHash_GeminiSameContentsSameHash(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1} + + mk := func() *ParsedRequest { + return &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Hello"}, + }, + }, + map[string]any{ + "role": "model", + "parts": []any{ + map[string]any{"text": "Hi there!"}, + }, + }, + }, + SessionContext: ctx, + } + } + + h1 := svc.GenerateSessionHash(mk()) + h2 := svc.GenerateSessionHash(mk()) + require.Equal(t, h1, h2, "same Gemini contents should produce identical hash") +} + +func TestGenerateSessionHash_GeminiMultiTurnHashChanges(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1} + + round1 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "hello"}}, + }, + }, + SessionContext: ctx, + } + + round2 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "hello"}}, + }, + map[string]any{ + "role": "model", + "parts": []any{map[string]any{"text": "Hi!"}}, + }, + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "How are you?"}}, + }, + }, + SessionContext: ctx, + } + + h1 := svc.GenerateSessionHash(round1) + h2 := svc.GenerateSessionHash(round2) + require.NotEmpty(t, h1) + require.NotEmpty(t, h2) + require.NotEqual(t, h1, h2, "Gemini multi-turn should produce different hashes per round") +} + +func 
TestGenerateSessionHash_GeminiDifferentUsersSameContentDifferentHash(t *testing.T) { + svc := &GatewayService{} + + // Core scenario: two different users sending identical Gemini-format messages should get different hashes + user1 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "hello"}}, + }, + }, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: "gemini-cli", + APIKeyID: 10, + }, + } + user2 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "hello"}}, + }, + }, + SessionContext: &SessionContext{ + ClientIP: "2.2.2.2", + UserAgent: "gemini-cli", + APIKeyID: 20, + }, + } + + h1 := svc.GenerateSessionHash(user1) + h2 := svc.GenerateSessionHash(user2) + require.NotEqual(t, h1, h2, "CRITICAL: different Gemini users with same content must get different hashes") +} + +func TestGenerateSessionHash_GeminiSystemInstructionAffectsHash(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1} + + // ParseGatewayRequest stores the parsed systemInstruction in parsed.System + withSys := &ParsedRequest{ + System: []any{ + map[string]any{"text": "You are a coding assistant."}, + }, + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "hello"}}, + }, + }, + SessionContext: ctx, + } + withoutSys := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "hello"}}, + }, + }, + SessionContext: ctx, + } + + h1 := svc.GenerateSessionHash(withSys) + h2 := svc.GenerateSessionHash(withoutSys) + require.NotEqual(t, h1, h2, "systemInstruction should affect the hash") +} + +func TestGenerateSessionHash_GeminiMultiPartMessage(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1} + + // A message with multiple parts + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", +
"parts": []any{ + map[string]any{"text": "Part 1"}, + map[string]any{"text": "Part 2"}, + map[string]any{"text": "Part 3"}, + }, + }, + }, + SessionContext: ctx, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "multi-part Gemini message should produce a hash") + + // 不同内容的多 parts + parsed2 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Part 1"}, + map[string]any{"text": "CHANGED"}, + map[string]any{"text": "Part 3"}, + }, + }, + }, + SessionContext: ctx, + } + + h2 := svc.GenerateSessionHash(parsed2) + require.NotEqual(t, h, h2, "changing a part should change the hash") +} + +func TestGenerateSessionHash_GeminiNonTextPartsIgnored(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1} + + // 含非 text 类型 parts(如 inline_data),应被跳过但不报错 + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Describe this image"}, + map[string]any{"inline_data": map[string]any{"mime_type": "image/png", "data": "base64..."}}, + }, + }, + }, + SessionContext: ctx, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "Gemini message with mixed parts should still produce a hash from text parts") +} + +func TestGenerateSessionHash_GeminiMultiTurnHashNotSticky(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "10.0.0.1", UserAgent: "gemini-cli", APIKeyID: 42} + + // 模拟同一 Gemini 会话的三轮请求,每轮 contents 累积增长。 + // 验证预期行为:每轮 hash 都不同,即 GenerateSessionHash 不具备跨轮粘性。 + // 这是 by-design 的——Gemini 的跨轮粘性由 Digest Fallback(BuildGeminiDigestChain)负责。 + round1Body := []byte(`{ + "systemInstruction": {"parts": [{"text": "You are a coding assistant."}]}, + "contents": [ + {"role": "user", "parts": [{"text": "Write a Go function"}]} + ] + }`) + round2Body := []byte(`{ + "systemInstruction": {"parts": [{"text": "You are a coding 
assistant."}]}, + "contents": [ + {"role": "user", "parts": [{"text": "Write a Go function"}]}, + {"role": "model", "parts": [{"text": "func hello() {}"}]}, + {"role": "user", "parts": [{"text": "Add error handling"}]} + ] + }`) + round3Body := []byte(`{ + "systemInstruction": {"parts": [{"text": "You are a coding assistant."}]}, + "contents": [ + {"role": "user", "parts": [{"text": "Write a Go function"}]}, + {"role": "model", "parts": [{"text": "func hello() {}"}]}, + {"role": "user", "parts": [{"text": "Add error handling"}]}, + {"role": "model", "parts": [{"text": "func hello() error { return nil }"}]}, + {"role": "user", "parts": [{"text": "Now add tests"}]} + ] + }`) + + hashes := make([]string, 3) + for i, body := range [][]byte{round1Body, round2Body, round3Body} { + parsed, err := ParseGatewayRequest(body, "gemini") + require.NoError(t, err) + parsed.SessionContext = ctx + hashes[i] = svc.GenerateSessionHash(parsed) + require.NotEmpty(t, hashes[i], "round %d hash should not be empty", i+1) + } + + // 每轮 hash 都不同——这是预期行为 + require.NotEqual(t, hashes[0], hashes[1], "round 1 vs 2 hash should differ (contents grow)") + require.NotEqual(t, hashes[1], hashes[2], "round 2 vs 3 hash should differ (contents grow)") + require.NotEqual(t, hashes[0], hashes[2], "round 1 vs 3 hash should differ") + + // 同一轮重试应产生相同 hash + parsed1Again, err := ParseGatewayRequest(round2Body, "gemini") + require.NoError(t, err) + parsed1Again.SessionContext = ctx + h2Again := svc.GenerateSessionHash(parsed1Again) + require.Equal(t, hashes[1], h2Again, "retry of same round should produce same hash") +} + +func TestGenerateSessionHash_GeminiEndToEnd(t *testing.T) { + svc := &GatewayService{} + + // 端到端测试:模拟 ParseGatewayRequest + GenerateSessionHash 完整流程 + body := []byte(`{ + "model": "gemini-2.5-pro", + "systemInstruction": { + "parts": [{"text": "You are a coding assistant."}] + }, + "contents": [ + {"role": "user", "parts": [{"text": "Write a Go function"}]}, + {"role": "model", "parts": 
[{"text": "Here is a function..."}]}, + {"role": "user", "parts": [{"text": "Now add error handling"}]} + ] + }`) + + parsed, err := ParseGatewayRequest(body, "gemini") + require.NoError(t, err) + parsed.SessionContext = &SessionContext{ + ClientIP: "10.0.0.1", + UserAgent: "gemini-cli/1.0", + APIKeyID: 42, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "end-to-end Gemini flow should produce a hash") + + // 同一请求再次解析应产生相同 hash + parsed2, err := ParseGatewayRequest(body, "gemini") + require.NoError(t, err) + parsed2.SessionContext = &SessionContext{ + ClientIP: "10.0.0.1", + UserAgent: "gemini-cli/1.0", + APIKeyID: 42, + } + + h2 := svc.GenerateSessionHash(parsed2) + require.Equal(t, h, h2, "same request should produce same hash") + + // 不同用户发送相同请求应产生不同 hash + parsed3, err := ParseGatewayRequest(body, "gemini") + require.NoError(t, err) + parsed3.SessionContext = &SessionContext{ + ClientIP: "10.0.0.2", + UserAgent: "gemini-cli/1.0", + APIKeyID: 99, + } + + h3 := svc.GenerateSessionHash(parsed3) + require.NotEqual(t, h, h3, "different user with same Gemini request should get different hash") +} diff --git a/backend/internal/service/model_rate_limit_test.go b/backend/internal/service/model_rate_limit_test.go index a51e6909..b79b9688 100644 --- a/backend/internal/service/model_rate_limit_test.go +++ b/backend/internal/service/model_rate_limit_test.go @@ -318,110 +318,6 @@ func TestGetModelRateLimitRemainingTime(t *testing.T) { } } -func TestGetQuotaScopeRateLimitRemainingTime(t *testing.T) { - now := time.Now() - future10m := now.Add(10 * time.Minute).Format(time.RFC3339) - past := now.Add(-10 * time.Minute).Format(time.RFC3339) - - tests := []struct { - name string - account *Account - requestedModel string - minExpected time.Duration - maxExpected time.Duration - }{ - { - name: "nil account", - account: nil, - requestedModel: "claude-sonnet-4-5", - minExpected: 0, - maxExpected: 0, - }, - { - name: "non-antigravity platform", - account: 
&Account{ - Platform: PlatformAnthropic, - Extra: map[string]any{ - antigravityQuotaScopesKey: map[string]any{ - "claude": map[string]any{ - "rate_limit_reset_at": future10m, - }, - }, - }, - }, - requestedModel: "claude-sonnet-4-5", - minExpected: 0, - maxExpected: 0, - }, - { - name: "claude scope rate limited", - account: &Account{ - Platform: PlatformAntigravity, - Extra: map[string]any{ - antigravityQuotaScopesKey: map[string]any{ - "claude": map[string]any{ - "rate_limit_reset_at": future10m, - }, - }, - }, - }, - requestedModel: "claude-sonnet-4-5", - minExpected: 9 * time.Minute, - maxExpected: 11 * time.Minute, - }, - { - name: "gemini_text scope rate limited", - account: &Account{ - Platform: PlatformAntigravity, - Extra: map[string]any{ - antigravityQuotaScopesKey: map[string]any{ - "gemini_text": map[string]any{ - "rate_limit_reset_at": future10m, - }, - }, - }, - }, - requestedModel: "gemini-3-flash", - minExpected: 9 * time.Minute, - maxExpected: 11 * time.Minute, - }, - { - name: "expired scope rate limit", - account: &Account{ - Platform: PlatformAntigravity, - Extra: map[string]any{ - antigravityQuotaScopesKey: map[string]any{ - "claude": map[string]any{ - "rate_limit_reset_at": past, - }, - }, - }, - }, - requestedModel: "claude-sonnet-4-5", - minExpected: 0, - maxExpected: 0, - }, - { - name: "unsupported model", - account: &Account{ - Platform: PlatformAntigravity, - }, - requestedModel: "gpt-4", - minExpected: 0, - maxExpected: 0, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := tt.account.GetQuotaScopeRateLimitRemainingTime(tt.requestedModel) - if result < tt.minExpected || result > tt.maxExpected { - t.Errorf("GetQuotaScopeRateLimitRemainingTime() = %v, want between %v and %v", result, tt.minExpected, tt.maxExpected) - } - }) - } -} - func TestGetRateLimitRemainingTime(t *testing.T) { now := time.Now() future15m := now.Add(15 * time.Minute).Format(time.RFC3339) @@ -442,45 +338,19 @@ func 
TestGetRateLimitRemainingTime(t *testing.T) { maxExpected: 0, }, { - name: "model remaining > scope remaining - returns model", + name: "model rate limited - 15 minutes", account: &Account{ Platform: PlatformAntigravity, Extra: map[string]any{ modelRateLimitsKey: map[string]any{ "claude-sonnet-4-5": map[string]any{ - "rate_limit_reset_at": future15m, // 15 分钟 - }, - }, - antigravityQuotaScopesKey: map[string]any{ - "claude": map[string]any{ - "rate_limit_reset_at": future5m, // 5 分钟 + "rate_limit_reset_at": future15m, }, }, }, }, requestedModel: "claude-sonnet-4-5", - minExpected: 14 * time.Minute, // 应返回较大的 15 分钟 - maxExpected: 16 * time.Minute, - }, - { - name: "scope remaining > model remaining - returns scope", - account: &Account{ - Platform: PlatformAntigravity, - Extra: map[string]any{ - modelRateLimitsKey: map[string]any{ - "claude-sonnet-4-5": map[string]any{ - "rate_limit_reset_at": future5m, // 5 分钟 - }, - }, - antigravityQuotaScopesKey: map[string]any{ - "claude": map[string]any{ - "rate_limit_reset_at": future15m, // 15 分钟 - }, - }, - }, - }, - requestedModel: "claude-sonnet-4-5", - minExpected: 14 * time.Minute, // 应返回较大的 15 分钟 + minExpected: 14 * time.Minute, maxExpected: 16 * time.Minute, }, { @@ -499,22 +369,6 @@ func TestGetRateLimitRemainingTime(t *testing.T) { minExpected: 4 * time.Minute, maxExpected: 6 * time.Minute, }, - { - name: "only scope rate limited", - account: &Account{ - Platform: PlatformAntigravity, - Extra: map[string]any{ - antigravityQuotaScopesKey: map[string]any{ - "claude": map[string]any{ - "rate_limit_reset_at": future5m, - }, - }, - }, - }, - requestedModel: "claude-sonnet-4-5", - minExpected: 4 * time.Minute, - maxExpected: 6 * time.Minute, - }, { name: "neither rate limited", account: &Account{ diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index fbe81cb4..6c4fe256 100644 --- a/backend/internal/service/openai_gateway_service.go +++ 
b/backend/internal/service/openai_gateway_service.go @@ -580,10 +580,6 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex } } } else { - type accountWithLoad struct { - account *Account - loadInfo *AccountLoadInfo - } var available []accountWithLoad for _, acc := range candidates { loadInfo := loadMap[acc.ID] @@ -618,6 +614,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex return a.account.LastUsedAt.Before(*b.account.LastUsedAt) } }) + shuffleWithinSortGroups(available) for _, item := range available { result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 159b0afb..ae69a986 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -204,30 +204,6 @@ func (c *stubGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID i return nil } -func (c *stubGatewayCache) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) { - return 0, nil -} - -func (c *stubGatewayCache) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) { - return nil, nil -} - -func (c *stubGatewayCache) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - return "", 0, false -} - -func (c *stubGatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { - return nil -} - -func (c *stubGatewayCache) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - return "", 0, false -} - -func (c *stubGatewayCache) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, 
digestChain, uuid string, accountID int64) error { - return nil -} - func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) { now := time.Now() resetAt := now.Add(10 * time.Minute) diff --git a/backend/internal/service/ops_account_availability.go b/backend/internal/service/ops_account_availability.go index a649e7b5..da66ec4d 100644 --- a/backend/internal/service/ops_account_availability.go +++ b/backend/internal/service/ops_account_availability.go @@ -66,7 +66,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi } isAvailable := acc.Status == StatusActive && acc.Schedulable && !isRateLimited && !isOverloaded && !isTempUnsched - scopeRateLimits := acc.GetAntigravityScopeRateLimits() if acc.Platform != "" { if _, ok := platform[acc.Platform]; !ok { @@ -85,14 +84,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi if hasError { p.ErrorCount++ } - if len(scopeRateLimits) > 0 { - if p.ScopeRateLimitCount == nil { - p.ScopeRateLimitCount = make(map[string]int64) - } - for scope := range scopeRateLimits { - p.ScopeRateLimitCount[scope]++ - } - } } for _, grp := range acc.Groups { @@ -117,14 +108,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi if hasError { g.ErrorCount++ } - if len(scopeRateLimits) > 0 { - if g.ScopeRateLimitCount == nil { - g.ScopeRateLimitCount = make(map[string]int64) - } - for scope := range scopeRateLimits { - g.ScopeRateLimitCount[scope]++ - } - } } displayGroupID := int64(0) @@ -157,9 +140,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi item.RateLimitRemainingSec = &remainingSec } } - if len(scopeRateLimits) > 0 { - item.ScopeRateLimits = scopeRateLimits - } if isOverloaded && acc.OverloadUntil != nil { item.OverloadUntil = acc.OverloadUntil remainingSec := int64(time.Until(*acc.OverloadUntil).Seconds()) diff --git a/backend/internal/service/ops_realtime_models.go 
b/backend/internal/service/ops_realtime_models.go index 33029f59..a19ab355 100644 --- a/backend/internal/service/ops_realtime_models.go +++ b/backend/internal/service/ops_realtime_models.go @@ -50,24 +50,22 @@ type UserConcurrencyInfo struct { // PlatformAvailability aggregates account availability by platform. type PlatformAvailability struct { - Platform string `json:"platform"` - TotalAccounts int64 `json:"total_accounts"` - AvailableCount int64 `json:"available_count"` - RateLimitCount int64 `json:"rate_limit_count"` - ScopeRateLimitCount map[string]int64 `json:"scope_rate_limit_count,omitempty"` - ErrorCount int64 `json:"error_count"` + Platform string `json:"platform"` + TotalAccounts int64 `json:"total_accounts"` + AvailableCount int64 `json:"available_count"` + RateLimitCount int64 `json:"rate_limit_count"` + ErrorCount int64 `json:"error_count"` } // GroupAvailability aggregates account availability by group. type GroupAvailability struct { - GroupID int64 `json:"group_id"` - GroupName string `json:"group_name"` - Platform string `json:"platform"` - TotalAccounts int64 `json:"total_accounts"` - AvailableCount int64 `json:"available_count"` - RateLimitCount int64 `json:"rate_limit_count"` - ScopeRateLimitCount map[string]int64 `json:"scope_rate_limit_count,omitempty"` - ErrorCount int64 `json:"error_count"` + GroupID int64 `json:"group_id"` + GroupName string `json:"group_name"` + Platform string `json:"platform"` + TotalAccounts int64 `json:"total_accounts"` + AvailableCount int64 `json:"available_count"` + RateLimitCount int64 `json:"rate_limit_count"` + ErrorCount int64 `json:"error_count"` } // AccountAvailability represents current availability for a single account. 
@@ -85,11 +83,10 @@ type AccountAvailability struct { IsOverloaded bool `json:"is_overloaded"` HasError bool `json:"has_error"` - RateLimitResetAt *time.Time `json:"rate_limit_reset_at"` - RateLimitRemainingSec *int64 `json:"rate_limit_remaining_sec"` - ScopeRateLimits map[string]int64 `json:"scope_rate_limits,omitempty"` - OverloadUntil *time.Time `json:"overload_until"` - OverloadRemainingSec *int64 `json:"overload_remaining_sec"` - ErrorMessage string `json:"error_message"` - TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until,omitempty"` + RateLimitResetAt *time.Time `json:"rate_limit_reset_at"` + RateLimitRemainingSec *int64 `json:"rate_limit_remaining_sec"` + OverloadUntil *time.Time `json:"overload_until"` + OverloadRemainingSec *int64 `json:"overload_remaining_sec"` + ErrorMessage string `json:"error_message"` + TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until,omitempty"` } diff --git a/backend/internal/service/ops_retry.go b/backend/internal/service/ops_retry.go index fbc800f2..23a524ad 100644 --- a/backend/internal/service/ops_retry.go +++ b/backend/internal/service/ops_retry.go @@ -12,6 +12,7 @@ import ( "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/gin-gonic/gin" @@ -528,7 +529,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry func extractRetryModelAndStream(reqType opsRetryRequestType, errorLog *OpsErrorLogDetail, body []byte) (model string, stream bool, err error) { switch reqType { case opsRetryTypeMessages: - parsed, parseErr := ParseGatewayRequest(body) + parsed, parseErr := ParseGatewayRequest(body, domain.PlatformAnthropic) if parseErr != nil { return "", false, fmt.Errorf("failed to parse messages request body: %w", parseErr) } @@ -596,7 +597,7 @@ func (s *OpsService) executeWithAccount(ctx context.Context, reqType opsRetryReq if 
s.gatewayService == nil { return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "gateway service not available"} } - parsedReq, parseErr := ParseGatewayRequest(body) + parsedReq, parseErr := ParseGatewayRequest(body, domain.PlatformAnthropic) if parseErr != nil { return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "failed to parse request body"} } diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 47286deb..63732dee 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -62,6 +62,32 @@ func (s *RateLimitService) SetTokenCacheInvalidator(invalidator TokenCacheInvali s.tokenCacheInvalidator = invalidator } +// ErrorPolicyResult represents the result of an error-policy check +type ErrorPolicyResult int + +const ( + ErrorPolicyNone ErrorPolicyResult = iota // no policy matched; continue with the default logic + ErrorPolicySkipped // custom error codes enabled but not matched; skip handling + ErrorPolicyMatched // custom error code matched; scheduling should stop + ErrorPolicyTempUnscheduled // temporary-unschedulable rule matched +) + +// CheckErrorPolicy checks custom error codes and temporary-unschedulable rules. +// When custom error codes are enabled, they override all subsequent logic (including temporary unschedulability). +func (s *RateLimitService) CheckErrorPolicy(ctx context.Context, account *Account, statusCode int, responseBody []byte) ErrorPolicyResult { + if account.IsCustomErrorCodesEnabled() { + if account.ShouldHandleErrorCode(statusCode) { + return ErrorPolicyMatched + } + slog.Info("account_error_code_skipped", "account_id", account.ID, "status_code", statusCode) + return ErrorPolicySkipped + } + if s.tryTempUnschedulable(ctx, account, statusCode, responseBody) { + return ErrorPolicyTempUnscheduled + } + return ErrorPolicyNone +} + // HandleUpstreamError handles an upstream error response and marks the account status // It returns whether scheduling for this account should stop func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) { diff --git a/backend/internal/service/scheduler_shuffle_test.go b/backend/internal/service/scheduler_shuffle_test.go new file mode 100644
index 00000000..78ac5f57 --- /dev/null +++ b/backend/internal/service/scheduler_shuffle_test.go @@ -0,0 +1,318 @@ +//go:build unit + +package service + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// ============ shuffleWithinSortGroups tests ============ + +func TestShuffleWithinSortGroups_Empty(t *testing.T) { + shuffleWithinSortGroups(nil) + shuffleWithinSortGroups([]accountWithLoad{}) +} + +func TestShuffleWithinSortGroups_SingleElement(t *testing.T) { + accounts := []accountWithLoad{ + {account: &Account{ID: 1, Priority: 1}, loadInfo: &AccountLoadInfo{LoadRate: 10}}, + } + shuffleWithinSortGroups(accounts) + require.Equal(t, int64(1), accounts[0].account.ID) +} + +func TestShuffleWithinSortGroups_DifferentGroups_OrderPreserved(t *testing.T) { + now := time.Now() + earlier := now.Add(-1 * time.Hour) + + accounts := []accountWithLoad{ + {account: &Account{ID: 1, Priority: 1, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 10}}, + {account: &Account{ID: 2, Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 20}}, + {account: &Account{ID: 3, Priority: 2, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 10}}, + } + + // Every element belongs to a different group (Priority, LoadRate, or LastUsedAt differs), so order is preserved + for i := 0; i < 20; i++ { + cpy := make([]accountWithLoad, len(accounts)) + copy(cpy, accounts) + shuffleWithinSortGroups(cpy) + require.Equal(t, int64(1), cpy[0].account.ID) + require.Equal(t, int64(2), cpy[1].account.ID) + require.Equal(t, int64(3), cpy[2].account.ID) + } +} + +func TestShuffleWithinSortGroups_SameGroup_Shuffled(t *testing.T) { + now := time.Now() + // Timestamps within the same second are treated as the same group + sameSecond := time.Unix(now.Unix(), 0) + sameSecond2 := time.Unix(now.Unix(), 500_000_000) // same second, different nanoseconds + + accounts := []accountWithLoad{ + {account: &Account{ID: 1, Priority: 1, LastUsedAt: &sameSecond}, loadInfo: &AccountLoadInfo{LoadRate: 10}}, + {account: &Account{ID: 2, Priority: 1, LastUsedAt: &sameSecond2}, loadInfo:
&AccountLoadInfo{LoadRate: 10}}, + {account: &Account{ID: 3, Priority: 1, LastUsedAt: &sameSecond}, loadInfo: &AccountLoadInfo{LoadRate: 10}}, + } + + // Run repeatedly and record which IDs appear in the first position (to confirm the slice really is shuffled) + seen := map[int64]bool{} + for i := 0; i < 100; i++ { + cpy := make([]accountWithLoad, len(accounts)) + copy(cpy, accounts) + shuffleWithinSortGroups(cpy) + seen[cpy[0].account.ID] = true + // However the slice is shuffled, every ID should remain among the candidates + ids := map[int64]bool{} + for _, a := range cpy { + ids[a.account.ID] = true + } + require.True(t, ids[1] && ids[2] && ids[3]) + } + // At least 2 distinct IDs should appear in the first position (randomness check) + require.GreaterOrEqual(t, len(seen), 2, "shuffle should produce different orderings") +} + +func TestShuffleWithinSortGroups_NilLastUsedAt_SameGroup(t *testing.T) { + accounts := []accountWithLoad{ + {account: &Account{ID: 1, Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}}, + {account: &Account{ID: 2, Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}}, + {account: &Account{ID: 3, Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}}, + } + + seen := map[int64]bool{} + for i := 0; i < 100; i++ { + cpy := make([]accountWithLoad, len(accounts)) + copy(cpy, accounts) + shuffleWithinSortGroups(cpy) + seen[cpy[0].account.ID] = true + } + require.GreaterOrEqual(t, len(seen), 2, "nil LastUsedAt accounts should be shuffled") +} + +func TestShuffleWithinSortGroups_MixedGroups(t *testing.T) { + now := time.Now() + earlier := now.Add(-1 * time.Hour) + sameAsNow := time.Unix(now.Unix(), 0) + + // Group 1: Priority=1, LoadRate=10, LastUsedAt=earlier (ID 1) - single-element group + // Group 2: Priority=1, LoadRate=20, LastUsedAt=now (ID 2, 3) - two-element group + // Group 3: Priority=2, LoadRate=10, LastUsedAt=earlier (ID 4) - single-element group + accounts := []accountWithLoad{ + {account: &Account{ID: 1, Priority: 1, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 10}}, + {account: &Account{ID: 2, Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 20}}, + {account: &Account{ID: 3,
Priority: 1, LastUsedAt: &sameAsNow}, loadInfo: &AccountLoadInfo{LoadRate: 20}}, + {account: &Account{ID: 4, Priority: 2, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 10}}, + } + + for i := 0; i < 20; i++ { + cpy := make([]accountWithLoad, len(accounts)) + copy(cpy, accounts) + shuffleWithinSortGroups(cpy) + + // Order across groups is unchanged + require.Equal(t, int64(1), cpy[0].account.ID, "group 1 position fixed") + require.Equal(t, int64(4), cpy[3].account.ID, "group 3 position fixed") + + // Group 2 may be shuffled internally but stays in positions 1 and 2 + mid := map[int64]bool{cpy[1].account.ID: true, cpy[2].account.ID: true} + require.True(t, mid[2] && mid[3], "group 2 elements should stay in positions 1-2") + } +} + +// ============ shuffleWithinPriorityAndLastUsed tests ============ + +func TestShuffleWithinPriorityAndLastUsed_Empty(t *testing.T) { + shuffleWithinPriorityAndLastUsed(nil) + shuffleWithinPriorityAndLastUsed([]*Account{}) +} + +func TestShuffleWithinPriorityAndLastUsed_SingleElement(t *testing.T) { + accounts := []*Account{{ID: 1, Priority: 1}} + shuffleWithinPriorityAndLastUsed(accounts) + require.Equal(t, int64(1), accounts[0].ID) +} + +func TestShuffleWithinPriorityAndLastUsed_SameGroup_Shuffled(t *testing.T) { + accounts := []*Account{ + {ID: 1, Priority: 1, LastUsedAt: nil}, + {ID: 2, Priority: 1, LastUsedAt: nil}, + {ID: 3, Priority: 1, LastUsedAt: nil}, + } + + seen := map[int64]bool{} + for i := 0; i < 100; i++ { + cpy := make([]*Account, len(accounts)) + copy(cpy, accounts) + shuffleWithinPriorityAndLastUsed(cpy) + seen[cpy[0].ID] = true + } + require.GreaterOrEqual(t, len(seen), 2, "same group should be shuffled") +} + +func TestShuffleWithinPriorityAndLastUsed_DifferentPriority_OrderPreserved(t *testing.T) { + accounts := []*Account{ + {ID: 1, Priority: 1, LastUsedAt: nil}, + {ID: 2, Priority: 2, LastUsedAt: nil}, + {ID: 3, Priority: 3, LastUsedAt: nil}, + } + + for i := 0; i < 20; i++ { + cpy := make([]*Account, len(accounts)) + copy(cpy, accounts) + 
shuffleWithinPriorityAndLastUsed(cpy) + require.Equal(t, int64(1), cpy[0].ID) + require.Equal(t, int64(2), cpy[1].ID) + require.Equal(t, int64(3), cpy[2].ID) + } +} + +func TestShuffleWithinPriorityAndLastUsed_DifferentLastUsedAt_OrderPreserved(t *testing.T) { + now := time.Now() + earlier := now.Add(-1 * time.Hour) + + accounts := []*Account{ + {ID: 1, Priority: 1, LastUsedAt: nil}, + {ID: 2, Priority: 1, LastUsedAt: &earlier}, + {ID: 3, Priority: 1, LastUsedAt: &now}, + } + + for i := 0; i < 20; i++ { + cpy := make([]*Account, len(accounts)) + copy(cpy, accounts) + shuffleWithinPriorityAndLastUsed(cpy) + require.Equal(t, int64(1), cpy[0].ID) + require.Equal(t, int64(2), cpy[1].ID) + require.Equal(t, int64(3), cpy[2].ID) + } +} + +// ============ sameLastUsedAt tests ============ + +func TestSameLastUsedAt(t *testing.T) { + now := time.Now() + sameSecond := time.Unix(now.Unix(), 0) + sameSecondDiffNano := time.Unix(now.Unix(), 999_999_999) + differentSecond := now.Add(1 * time.Second) + + t.Run("both nil", func(t *testing.T) { + require.True(t, sameLastUsedAt(nil, nil)) + }) + + t.Run("one nil one not", func(t *testing.T) { + require.False(t, sameLastUsedAt(nil, &now)) + require.False(t, sameLastUsedAt(&now, nil)) + }) + + t.Run("same second different nanoseconds", func(t *testing.T) { + require.True(t, sameLastUsedAt(&sameSecond, &sameSecondDiffNano)) + }) + + t.Run("different seconds", func(t *testing.T) { + require.False(t, sameLastUsedAt(&now, &differentSecond)) + }) + + t.Run("exact same time", func(t *testing.T) { + require.True(t, sameLastUsedAt(&now, &now)) + }) +} + +// ============ sameAccountWithLoadGroup tests ============ + +func TestSameAccountWithLoadGroup(t *testing.T) { + now := time.Now() + sameSecond := time.Unix(now.Unix(), 0) + + t.Run("same group", func(t *testing.T) { + a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}} + b := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: 
&sameSecond}, loadInfo: &AccountLoadInfo{LoadRate: 10}} + require.True(t, sameAccountWithLoadGroup(a, b)) + }) + + t.Run("different priority", func(t *testing.T) { + a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}} + b := accountWithLoad{account: &Account{Priority: 2, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}} + require.False(t, sameAccountWithLoadGroup(a, b)) + }) + + t.Run("different load rate", func(t *testing.T) { + a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}} + b := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 20}} + require.False(t, sameAccountWithLoadGroup(a, b)) + }) + + t.Run("different last used at", func(t *testing.T) { + later := now.Add(1 * time.Second) + a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}} + b := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &later}, loadInfo: &AccountLoadInfo{LoadRate: 10}} + require.False(t, sameAccountWithLoadGroup(a, b)) + }) + + t.Run("both nil LastUsedAt", func(t *testing.T) { + a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}} + b := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}} + require.True(t, sameAccountWithLoadGroup(a, b)) + }) +} + +// ============ sameAccountGroup tests ============ + +func TestSameAccountGroup(t *testing.T) { + now := time.Now() + + t.Run("same group", func(t *testing.T) { + a := &Account{Priority: 1, LastUsedAt: nil} + b := &Account{Priority: 1, LastUsedAt: nil} + require.True(t, sameAccountGroup(a, b)) + }) + + t.Run("different priority", func(t *testing.T) { + a := &Account{Priority: 1, LastUsedAt: nil} + b := &Account{Priority: 2, LastUsedAt: nil} + require.False(t, sameAccountGroup(a, b)) + 
}) + + t.Run("different LastUsedAt", func(t *testing.T) { + later := now.Add(1 * time.Second) + a := &Account{Priority: 1, LastUsedAt: &now} + b := &Account{Priority: 1, LastUsedAt: &later} + require.False(t, sameAccountGroup(a, b)) + }) +} + +// ============ sortAccountsByPriorityAndLastUsed integration randomization tests ============ + +func TestSortAccountsByPriorityAndLastUsed_WithShuffle(t *testing.T) { + t.Run("same priority and nil LastUsedAt are shuffled", func(t *testing.T) { + accounts := []*Account{ + {ID: 1, Priority: 1, LastUsedAt: nil}, + {ID: 2, Priority: 1, LastUsedAt: nil}, + {ID: 3, Priority: 1, LastUsedAt: nil}, + } + + seen := map[int64]bool{} + for i := 0; i < 100; i++ { + cpy := make([]*Account, len(accounts)) + copy(cpy, accounts) + sortAccountsByPriorityAndLastUsed(cpy, false) + seen[cpy[0].ID] = true + } + require.GreaterOrEqual(t, len(seen), 2, "identical sort keys should produce different orderings after shuffle") + }) + + t.Run("different priorities still sorted correctly", func(t *testing.T) { + now := time.Now() + accounts := []*Account{ + {ID: 3, Priority: 3, LastUsedAt: &now}, + {ID: 1, Priority: 1, LastUsedAt: &now}, + {ID: 2, Priority: 2, LastUsedAt: &now}, + } + + sortAccountsByPriorityAndLastUsed(accounts, false) + require.Equal(t, int64(1), accounts[0].ID) + require.Equal(t, int64(2), accounts[1].ID) + require.Equal(t, int64(3), accounts[2].ID) + }) +} diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 05371022..87ca7897 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -275,4 +275,5 @@ var ProviderSet = wire.NewSet( NewUsageCache, NewTotpService, NewErrorPassthroughService, + NewDigestSessionStore, ) diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index 033731ac..f1d19f84 100644 --- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -47,13 +47,15 @@ services: # ======================================================================= # Database
Configuration (PostgreSQL) + # Default: uses local postgres container + # External DB: set DATABASE_HOST and DATABASE_SSLMODE in .env # ======================================================================= - - DATABASE_HOST=postgres - - DATABASE_PORT=5432 + - DATABASE_HOST=${DATABASE_HOST:-postgres} + - DATABASE_PORT=${DATABASE_PORT:-5432} - DATABASE_USER=${POSTGRES_USER:-sub2api} - DATABASE_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required} - DATABASE_DBNAME=${POSTGRES_DB:-sub2api} - - DATABASE_SSLMODE=disable + - DATABASE_SSLMODE=${DATABASE_SSLMODE:-disable} # ======================================================================= # Redis Configuration @@ -128,8 +130,6 @@ services: # Examples: http://host:port, socks5://host:port - UPDATE_PROXY_URL=${UPDATE_PROXY_URL:-} depends_on: - postgres: - condition: service_healthy redis: condition: service_healthy networks: @@ -141,35 +141,6 @@ services: retries: 3 start_period: 30s - # =========================================================================== - # PostgreSQL Database - # =========================================================================== - postgres: - image: postgres:18-alpine - container_name: sub2api-postgres - restart: unless-stopped - ulimits: - nofile: - soft: 100000 - hard: 100000 - volumes: - - postgres_data:/var/lib/postgresql/data - environment: - - POSTGRES_USER=${POSTGRES_USER:-sub2api} - - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required} - - POSTGRES_DB=${POSTGRES_DB:-sub2api} - - TZ=${TZ:-Asia/Shanghai} - networks: - - sub2api-network - healthcheck: - test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-sub2api} -d ${POSTGRES_DB:-sub2api}"] - interval: 10s - timeout: 5s - retries: 5 - start_period: 10s - # Note: no port is exposed to the host; the app connects over the internal network - # For debugging, temporarily add: ports: ["127.0.0.1:5433:5432"] - # =========================================================================== # Redis Cache # =========================================================================== @@
-209,8 +180,6 @@ services: volumes: sub2api_data: driver: local - postgres_data: - driver: local redis_data: driver: local diff --git a/frontend/public/wechat-qr.jpg b/frontend/public/wechat-qr.jpg new file mode 100644 index 00000000..659068d8 Binary files /dev/null and b/frontend/public/wechat-qr.jpg differ diff --git a/frontend/src/api/admin/ops.ts b/frontend/src/api/admin/ops.ts index 5b96feda..9f980a12 100644 --- a/frontend/src/api/admin/ops.ts +++ b/frontend/src/api/admin/ops.ts @@ -376,7 +376,6 @@ export interface PlatformAvailability { total_accounts: number available_count: number rate_limit_count: number - scope_rate_limit_count?: Record error_count: number } @@ -387,7 +386,6 @@ export interface GroupAvailability { total_accounts: number available_count: number rate_limit_count: number - scope_rate_limit_count?: Record error_count: number } @@ -402,7 +400,6 @@ export interface AccountAvailability { is_rate_limited: boolean rate_limit_reset_at?: string rate_limit_remaining_sec?: number - scope_rate_limits?: Record is_overloaded: boolean overload_until?: string overload_remaining_sec?: number diff --git a/frontend/src/components/account/AccountStatusIndicator.vue b/frontend/src/components/account/AccountStatusIndicator.vue index 3474da44..a7e0735b 100644 --- a/frontend/src/components/account/AccountStatusIndicator.vue +++ b/frontend/src/components/account/AccountStatusIndicator.vue @@ -76,26 +76,6 @@ - - - @@ -410,6 +534,7 @@ import { useI18n } from 'vue-i18n' import { useAuthStore, useAppStore } from '@/stores' import LocaleSwitcher from '@/components/common/LocaleSwitcher.vue' import Icon from '@/components/icons/Icon.vue' +import WechatServiceButton from '@/components/common/WechatServiceButton.vue' const { t } = useI18n() @@ -419,7 +544,6 @@ const appStore = useAppStore() // Site settings - directly from appStore (already initialized from injected config) const siteName = computed(() => appStore.cachedPublicSettings?.site_name || appStore.siteName || 
'Sub2API') const siteLogo = computed(() => appStore.cachedPublicSettings?.site_logo || appStore.siteLogo || '') -const siteSubtitle = computed(() => appStore.cachedPublicSettings?.site_subtitle || 'AI API Gateway Platform') const docUrl = computed(() => appStore.cachedPublicSettings?.doc_url || appStore.docUrl || '') const homeContent = computed(() => appStore.cachedPublicSettings?.home_content || '') @@ -432,9 +556,6 @@ const isHomeContentUrl = computed(() => { // Theme const isDark = ref(document.documentElement.classList.contains('dark')) -// GitHub URL -const githubUrl = 'https://github.com/Wei-Shaw/sub2api' - // Auth state const isAuthenticated = computed(() => authStore.isAuthenticated) const isAdmin = computed(() => authStore.isAdmin) diff --git a/frontend/src/views/admin/ops/components/OpsConcurrencyCard.vue b/frontend/src/views/admin/ops/components/OpsConcurrencyCard.vue index c4e3b5b9..63dd1c9a 100644 --- a/frontend/src/views/admin/ops/components/OpsConcurrencyCard.vue +++ b/frontend/src/views/admin/ops/components/OpsConcurrencyCard.vue @@ -56,7 +56,6 @@ interface SummaryRow { total_accounts: number available_accounts: number rate_limited_accounts: number - scope_rate_limit_count?: Record error_accounts: number // Concurrency stats total_concurrency: number @@ -122,7 +121,6 @@ const platformRows = computed((): SummaryRow[] => { total_accounts: totalAccounts, available_accounts: availableAccounts, rate_limited_accounts: safeNumber(avail.rate_limit_count), - scope_rate_limit_count: avail.scope_rate_limit_count, error_accounts: safeNumber(avail.error_count), total_concurrency: totalConcurrency, used_concurrency: usedConcurrency, @@ -162,7 +160,6 @@ const groupRows = computed((): SummaryRow[] => { total_accounts: totalAccounts, available_accounts: availableAccounts, rate_limited_accounts: safeNumber(avail.rate_limit_count), - scope_rate_limit_count: avail.scope_rate_limit_count, error_accounts: safeNumber(avail.error_count), total_concurrency: totalConcurrency,
used_concurrency: usedConcurrency, @@ -329,15 +326,6 @@ function formatDuration(seconds: number): string { return `${hours}h` } -function formatScopeName(scope: string): string { - const names: Record = { - claude: 'Claude', - gemini_text: 'Gemini', - gemini_image: 'Image' - } - return names[scope] || scope -} - watch( () => realtimeEnabled.value, async (enabled) => { @@ -505,18 +493,6 @@ watch( {{ t('admin.ops.concurrency.rateLimited', { count: row.rate_limited_accounts }) }} - - - "$output_file" 2>&1 + + echo "[Session $session_id Round $round] done" +} + +# Session 1: math calculator (accumulating sequence) +run_session_1() { + local sys_prompt="你是一个数学计算器,只返回计算结果数字,不要任何解释" + + # Round 1: 1+1=? + send_request 1 1 "$sys_prompt" '[{"role":"user","parts":[{"text":"1+1=?"}]}]' + + # Round 2: continue with 2+2=? (accumulated history) + send_request 1 2 "$sys_prompt" '[{"role":"user","parts":[{"text":"1+1=?"}]},{"role":"model","parts":[{"text":"2"}]},{"role":"user","parts":[{"text":"2+2=?"}]}]' + + # Round 3: continue with 3+3=? + send_request 1 3 "$sys_prompt" '[{"role":"user","parts":[{"text":"1+1=?"}]},{"role":"model","parts":[{"text":"2"}]},{"role":"user","parts":[{"text":"2+2=?"}]},{"role":"model","parts":[{"text":"4"}]},{"role":"user","parts":[{"text":"3+3=?"}]}]' + + # Round 4: batch calculation 10+10, 20+20, 30+30 + send_request 1 4 "$sys_prompt" '[{"role":"user","parts":[{"text":"1+1=?"}]},{"role":"model","parts":[{"text":"2"}]},{"role":"user","parts":[{"text":"2+2=?"}]},{"role":"model","parts":[{"text":"4"}]},{"role":"user","parts":[{"text":"3+3=?"}]},{"role":"model","parts":[{"text":"6"}]},{"role":"user","parts":[{"text":"计算: 10+10=? 20+20=?
30+30=?"}]}]' +} + +# Session 2: English translator (different system prompt = different session) +run_session_2() { + local sys_prompt="你是一个英文翻译器,将中文翻译成英文,只返回翻译结果" + + send_request 2 1 "$sys_prompt" '[{"role":"user","parts":[{"text":"你好"}]}]' + send_request 2 2 "$sys_prompt" '[{"role":"user","parts":[{"text":"你好"}]},{"role":"model","parts":[{"text":"Hello"}]},{"role":"user","parts":[{"text":"世界"}]}]' + send_request 2 3 "$sys_prompt" '[{"role":"user","parts":[{"text":"你好"}]},{"role":"model","parts":[{"text":"Hello"}]},{"role":"user","parts":[{"text":"世界"}]},{"role":"model","parts":[{"text":"World"}]},{"role":"user","parts":[{"text":"早上好"}]}]' +} + +# Session 3: Japanese translator +run_session_3() { + local sys_prompt="你是一个日文翻译器,将中文翻译成日文,只返回翻译结果" + + send_request 3 1 "$sys_prompt" '[{"role":"user","parts":[{"text":"你好"}]}]' + send_request 3 2 "$sys_prompt" '[{"role":"user","parts":[{"text":"你好"}]},{"role":"model","parts":[{"text":"こんにちは"}]},{"role":"user","parts":[{"text":"谢谢"}]}]' + send_request 3 3 "$sys_prompt" '[{"role":"user","parts":[{"text":"你好"}]},{"role":"model","parts":[{"text":"こんにちは"}]},{"role":"user","parts":[{"text":"谢谢"}]},{"role":"model","parts":[{"text":"ありがとう"}]},{"role":"user","parts":[{"text":"再见"}]}]' +} + +# Session 4: multiplication calculator (another math session, but with a different system prompt) +run_session_4() { + local sys_prompt="你是一个乘法专用计算器,只计算乘法,返回数字结果" + + send_request 4 1 "$sys_prompt" '[{"role":"user","parts":[{"text":"2*3=?"}]}]' + send_request 4 2 "$sys_prompt" '[{"role":"user","parts":[{"text":"2*3=?"}]},{"role":"model","parts":[{"text":"6"}]},{"role":"user","parts":[{"text":"4*5=?"}]}]' + send_request 4 3 "$sys_prompt" '[{"role":"user","parts":[{"text":"2*3=?"}]},{"role":"model","parts":[{"text":"6"}]},{"role":"user","parts":[{"text":"4*5=?"}]},{"role":"model","parts":[{"text":"20"}]},{"role":"user","parts":[{"text":"计算: 10*10=? 
20*20=?"}]}]' +} + +# Session 5: poet (a completely different role) +run_session_5() { + local sys_prompt="你是一位诗人,用简短的诗句回应每个话题,每次只写一句诗" + + send_request 5 1 "$sys_prompt" '[{"role":"user","parts":[{"text":"春天"}]}]' + send_request 5 2 "$sys_prompt" '[{"role":"user","parts":[{"text":"春天"}]},{"role":"model","parts":[{"text":"春风拂面花满枝"}]},{"role":"user","parts":[{"text":"夏天"}]}]' + send_request 5 3 "$sys_prompt" '[{"role":"user","parts":[{"text":"春天"}]},{"role":"model","parts":[{"text":"春风拂面花满枝"}]},{"role":"user","parts":[{"text":"夏天"}]},{"role":"model","parts":[{"text":"蝉鸣蛙声伴荷香"}]},{"role":"user","parts":[{"text":"秋天"}]}]' +} + +echo "" +echo "Starting concurrent test of 5 independent sessions..." +echo "" + +# Run all sessions concurrently +run_session_1 & +run_session_2 & +run_session_3 & +run_session_4 & +run_session_5 & + +# Wait for all background jobs to finish +wait + +echo "" +echo "==========================================" +echo "All requests finished; results saved in: $RESULT_DIR" +echo "==========================================" + +# Print a summary of responses +echo "" +echo "Response summary:" +for f in "$RESULT_DIR"/*.json; do + filename=$(basename "$f") + response=$(head -c 200 "$f") + echo "[$filename]: ${response}..." +done + +echo "" +echo "Check the server logs to verify account assignment"