mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-03 06:52:13 +08:00
feat: squash merge all changes from develop-0.1.75
Squash of 124 commits from the legacy develop branch (develop-0.1.75) onto a clean v0.1.75 upstream base, to simplify future upstream merges. Key changes included: - Refactor scope-level rate limiting to model-level rate limiting - Antigravity gateway service improvements (smart retry, error policy) - Digest session store (flat cache replacing Trie-based store) - Client disconnect detection during streaming - Gemini messages compatibility service enhancements - Scheduler shuffle for thundering herd prevention - Session hash generation improvements - Frontend customizations (WeChat service, HomeView, etc.) - Ops monitoring scope cleanup
This commit is contained in:
2
.github/workflows/backend-ci.yml
vendored
2
.github/workflows/backend-ci.yml
vendored
@@ -17,6 +17,7 @@ jobs:
|
|||||||
go-version-file: backend/go.mod
|
go-version-file: backend/go.mod
|
||||||
check-latest: false
|
check-latest: false
|
||||||
cache: true
|
cache: true
|
||||||
|
cache-dependency-path: backend/go.sum
|
||||||
- name: Verify Go version
|
- name: Verify Go version
|
||||||
run: |
|
run: |
|
||||||
go version | grep -q 'go1.25.7'
|
go version | grep -q 'go1.25.7'
|
||||||
@@ -36,6 +37,7 @@ jobs:
|
|||||||
go-version-file: backend/go.mod
|
go-version-file: backend/go.mod
|
||||||
check-latest: false
|
check-latest: false
|
||||||
cache: true
|
cache: true
|
||||||
|
cache-dependency-path: backend/go.sum
|
||||||
- name: Verify Go version
|
- name: Verify Go version
|
||||||
run: |
|
run: |
|
||||||
go version | grep -q 'go1.25.7'
|
go version | grep -q 'go1.25.7'
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -78,6 +78,7 @@ Desktop.ini
|
|||||||
# ===================
|
# ===================
|
||||||
tmp/
|
tmp/
|
||||||
temp/
|
temp/
|
||||||
|
logs/
|
||||||
*.tmp
|
*.tmp
|
||||||
*.temp
|
*.temp
|
||||||
*.log
|
*.log
|
||||||
|
|||||||
723
AGENTS.md
Normal file
723
AGENTS.md
Normal file
@@ -0,0 +1,723 @@
|
|||||||
|
# Sub2API 开发说明
|
||||||
|
|
||||||
|
## 版本管理策略
|
||||||
|
|
||||||
|
### 版本号规则
|
||||||
|
|
||||||
|
我们在官方版本号后面添加自己的小版本号:
|
||||||
|
|
||||||
|
- 官方版本:`v0.1.68`
|
||||||
|
- 我们的版本:`v0.1.68.1`、`v0.1.68.2`(递增)
|
||||||
|
|
||||||
|
### 分支策略
|
||||||
|
|
||||||
|
| 分支 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| `main` | 我们的主分支,包含所有定制功能 |
|
||||||
|
| `release/custom-X.Y.Z` | 基于官方 `vX.Y.Z` 的发布分支 |
|
||||||
|
| `upstream/main` | 上游官方仓库 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 发布流程(基于新官方版本)
|
||||||
|
|
||||||
|
当官方发布新版本(如 `v0.1.69`)时:
|
||||||
|
|
||||||
|
### 1. 同步上游并创建发布分支
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 获取上游最新代码
|
||||||
|
git fetch upstream --tags
|
||||||
|
|
||||||
|
# 基于官方标签创建新的发布分支
|
||||||
|
git checkout v0.1.69 -b release/custom-0.1.69
|
||||||
|
|
||||||
|
# 合并我们的 main 分支(包含所有定制功能)
|
||||||
|
git merge main --no-edit
|
||||||
|
|
||||||
|
# 解决可能的冲突后继续
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. 更新版本号并打标签
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 更新版本号文件
|
||||||
|
echo "0.1.69.1" > backend/cmd/server/VERSION
|
||||||
|
git add backend/cmd/server/VERSION
|
||||||
|
git commit -m "chore: bump version to 0.1.69.1"
|
||||||
|
|
||||||
|
# 打上我们自己的标签
|
||||||
|
git tag v0.1.69.1
|
||||||
|
|
||||||
|
# 推送分支和标签
|
||||||
|
git push origin release/custom-0.1.69
|
||||||
|
git push origin v0.1.69.1
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. 更新 main 分支
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 将发布分支合并回 main,保持 main 包含最新定制功能
|
||||||
|
git checkout main
|
||||||
|
git merge release/custom-0.1.69
|
||||||
|
git push origin main
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 热修复发布(在现有版本上修复)
|
||||||
|
|
||||||
|
当需要在当前版本上发布修复时:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 在当前发布分支上修复
|
||||||
|
git checkout release/custom-0.1.68
|
||||||
|
# ... 进行修复 ...
|
||||||
|
git commit -m "fix: 修复描述"
|
||||||
|
|
||||||
|
# 递增小版本号
|
||||||
|
echo "0.1.68.2" > backend/cmd/server/VERSION
|
||||||
|
git add backend/cmd/server/VERSION
|
||||||
|
git commit -m "chore: bump version to 0.1.68.2"
|
||||||
|
|
||||||
|
# 打标签并推送
|
||||||
|
git tag v0.1.68.2
|
||||||
|
git push origin release/custom-0.1.68
|
||||||
|
git push origin v0.1.68.2
|
||||||
|
|
||||||
|
# 同步修复到 main
|
||||||
|
git checkout main
|
||||||
|
git cherry-pick <fix-commit-hash>
|
||||||
|
git push origin main
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 服务器部署流程
|
||||||
|
|
||||||
|
### 前置条件
|
||||||
|
|
||||||
|
- 本地已配置 SSH 别名 `clicodeplus` 连接到服务器
|
||||||
|
- 服务器部署目录:`/root/sub2api`(正式)、`/root/sub2api-beta`(测试)
|
||||||
|
- 服务器使用 Docker Compose 部署
|
||||||
|
|
||||||
|
### 部署环境说明
|
||||||
|
|
||||||
|
| 环境 | 目录 | 端口 | 数据库 | 容器名 |
|
||||||
|
|------|------|------|--------|--------|
|
||||||
|
| 正式 | `/root/sub2api` | 8080 | `sub2api` | `sub2api` |
|
||||||
|
| Beta | `/root/sub2api-beta` | 8084 | `beta` | `sub2api-beta` |
|
||||||
|
|
||||||
|
### 外部数据库
|
||||||
|
|
||||||
|
正式和 Beta 环境**共用外部 PostgreSQL 数据库**(非容器内数据库),配置在 `.env` 文件中:
|
||||||
|
- `DATABASE_HOST`:外部数据库地址
|
||||||
|
- `DATABASE_SSLMODE`:SSL 模式(通常为 `require`)
|
||||||
|
- `POSTGRES_USER` / `POSTGRES_DB`:用户名和数据库名
|
||||||
|
|
||||||
|
#### 数据库操作命令
|
||||||
|
|
||||||
|
通过 SSH 在服务器上执行数据库操作:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 正式环境 - 查询迁移记录
|
||||||
|
ssh clicodeplus "source /root/sub2api/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c 'SELECT * FROM schema_migrations ORDER BY applied_at DESC LIMIT 5;'"
|
||||||
|
|
||||||
|
# Beta 环境 - 查询迁移记录
|
||||||
|
ssh clicodeplus "source /root/sub2api-beta/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c 'SELECT * FROM schema_migrations ORDER BY applied_at DESC LIMIT 5;'"
|
||||||
|
|
||||||
|
# Beta 环境 - 清除指定迁移记录(重新执行迁移)
|
||||||
|
ssh clicodeplus "source /root/sub2api-beta/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c \"DELETE FROM schema_migrations WHERE filename LIKE '%049%';\""
|
||||||
|
|
||||||
|
# Beta 环境 - 更新账号数据
|
||||||
|
ssh clicodeplus "source /root/sub2api-beta/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c \"UPDATE accounts SET credentials = credentials - 'model_mapping' WHERE platform = 'antigravity';\""
|
||||||
|
```
|
||||||
|
|
||||||
|
> **注意**:使用 `source .env` 加载环境变量,避免在命令行中暴露密码。
|
||||||
|
|
||||||
|
### 部署步骤
|
||||||
|
|
||||||
|
**重要:每次部署都必须递增版本号!**
|
||||||
|
|
||||||
|
#### 0. 递增版本号(本地操作)
|
||||||
|
|
||||||
|
每次部署前,先在本地递增小版本号:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 查看当前版本号
|
||||||
|
cat backend/cmd/server/VERSION
|
||||||
|
# 假设当前是 0.1.69.1
|
||||||
|
|
||||||
|
# 递增版本号
|
||||||
|
echo "0.1.69.2" > backend/cmd/server/VERSION
|
||||||
|
git add backend/cmd/server/VERSION
|
||||||
|
git commit -m "chore: bump version to 0.1.69.2"
|
||||||
|
git push origin release/custom-0.1.69
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 1. 服务器拉取代码
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ssh clicodeplus "cd /root/sub2api && git fetch fork && git checkout -B release/custom-0.1.69 fork/release/custom-0.1.69"
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 2. 服务器构建镜像
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ssh clicodeplus "cd /root/sub2api && docker build --no-cache -t sub2api:latest -f Dockerfile ."
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 3. 更新镜像标签并重启服务
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ssh clicodeplus "docker tag sub2api:latest weishaw/sub2api:latest"
|
||||||
|
ssh clicodeplus "cd /root/sub2api/deploy && docker compose up -d --force-recreate sub2api"
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 4. 验证部署
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 查看启动日志
|
||||||
|
ssh clicodeplus "docker logs sub2api --tail 20"
|
||||||
|
|
||||||
|
# 确认版本号(必须与步骤 0 中设置的版本号一致)
|
||||||
|
ssh clicodeplus "cat /root/sub2api/backend/cmd/server/VERSION"
|
||||||
|
|
||||||
|
# 检查容器状态
|
||||||
|
ssh clicodeplus "docker ps | grep sub2api"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Beta 并行部署(不影响现网)
|
||||||
|
|
||||||
|
目标:在同一台服务器上并行启动一个 beta 实例(例如端口 `8084`),**严禁改动/重启**现网实例(默认目录 `/root/sub2api`)。
|
||||||
|
|
||||||
|
### 设计原则
|
||||||
|
|
||||||
|
- **新目录**:beta 使用独立目录,例如 `/root/sub2api-beta`。
|
||||||
|
- **敏感信息只放 `.env`**:beta 的数据库密码、JWT_SECRET 等只写入 `/root/sub2api-beta/deploy/.env`,不要提交到 git。
|
||||||
|
- **独立 Compose Project**:通过 `docker compose -p sub2api-beta ...` 启动,确保 network/volume 隔离。
|
||||||
|
- **独立端口**:通过 `.env` 的 `SERVER_PORT` 映射宿主机端口(例如 `8084:8080`)。
|
||||||
|
|
||||||
|
### 前置检查
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 1) 确保 8084 未被占用
|
||||||
|
ssh clicodeplus "ss -ltnp | grep :8084 || echo '8084 is free'"
|
||||||
|
|
||||||
|
# 2) 确认现网容器还在(只读检查)
|
||||||
|
ssh clicodeplus "docker ps --format 'table {{.Names}}\t{{.Image}}\t{{.Ports}}' | sed -n '1,200p'"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 首次部署步骤
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 0) 进入服务器
|
||||||
|
ssh clicodeplus
|
||||||
|
|
||||||
|
# 1) 克隆代码到新目录(示例使用你的 fork)
|
||||||
|
cd /root
|
||||||
|
git clone https://github.com/touwaeriol/sub2api.git sub2api-beta
|
||||||
|
cd /root/sub2api-beta
|
||||||
|
git checkout release/custom-0.1.71
|
||||||
|
|
||||||
|
# 2) 准备 beta 的 .env(敏感信息只写这里)
|
||||||
|
cd /root/sub2api-beta/deploy
|
||||||
|
|
||||||
|
# 推荐:从现网 .env 复制,保证除 DB 名/用户/端口外完全一致
|
||||||
|
cp -f /root/sub2api/deploy/.env ./.env
|
||||||
|
|
||||||
|
# 仅修改以下三项(其他保持不变)
|
||||||
|
perl -pi -e 's/^SERVER_PORT=.*/SERVER_PORT=8084/' ./.env
|
||||||
|
perl -pi -e 's/^POSTGRES_USER=.*/POSTGRES_USER=beta/' ./.env
|
||||||
|
perl -pi -e 's/^POSTGRES_DB=.*/POSTGRES_DB=beta/' ./.env
|
||||||
|
|
||||||
|
# 3) 写 compose override(避免与现网容器名冲突,镜像使用本地构建的 sub2api:beta)
|
||||||
|
cat > docker-compose.override.yml <<'YAML'
|
||||||
|
services:
|
||||||
|
sub2api:
|
||||||
|
image: sub2api:beta
|
||||||
|
container_name: sub2api-beta
|
||||||
|
redis:
|
||||||
|
container_name: sub2api-beta-redis
|
||||||
|
YAML
|
||||||
|
|
||||||
|
# 4) 构建 beta 镜像(基于当前代码)
|
||||||
|
cd /root/sub2api-beta
|
||||||
|
docker build -t sub2api:beta -f Dockerfile .
|
||||||
|
|
||||||
|
# 5) 启动 beta(独立 project,确保不影响现网)
|
||||||
|
cd /root/sub2api-beta/deploy
|
||||||
|
docker compose -p sub2api-beta --env-file .env -f docker-compose.yml -f docker-compose.override.yml up -d
|
||||||
|
|
||||||
|
# 6) 验证 beta
|
||||||
|
curl -fsS http://127.0.0.1:8084/health
|
||||||
|
docker logs sub2api-beta --tail 50
|
||||||
|
```
|
||||||
|
|
||||||
|
### 数据库配置约定(beta)
|
||||||
|
|
||||||
|
- 数据库地址/SSL/密码:与现网一致(从现网 `.env` 复制即可)。
|
||||||
|
- 仅修改:
|
||||||
|
- `POSTGRES_USER=beta`
|
||||||
|
- `POSTGRES_DB=beta`
|
||||||
|
|
||||||
|
注意:需要数据库侧已存在 `beta` 用户与 `beta` 数据库,并授予权限;否则容器会启动失败并不断重启。
|
||||||
|
|
||||||
|
### 更新 beta(拉代码 + 仅重建 beta 容器)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ssh clicodeplus "set -e; cd /root/sub2api-beta && git fetch --all --tags && git checkout -f release/custom-0.1.71 && git reset --hard origin/release/custom-0.1.71"
|
||||||
|
ssh clicodeplus "cd /root/sub2api-beta && docker build -t sub2api:beta -f Dockerfile ."
|
||||||
|
ssh clicodeplus "cd /root/sub2api-beta/deploy && docker compose -p sub2api-beta --env-file .env -f docker-compose.yml -f docker-compose.override.yml up -d --no-deps --force-recreate sub2api"
|
||||||
|
ssh clicodeplus "curl -fsS http://127.0.0.1:8084/health"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 停止/回滚 beta(只影响 beta)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ssh clicodeplus "cd /root/sub2api-beta/deploy && docker compose -p sub2api-beta -f docker-compose.yml -f docker-compose.override.yml down"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 服务器首次部署
|
||||||
|
|
||||||
|
### 1. 克隆代码并配置远程仓库
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ssh clicodeplus
|
||||||
|
cd /root
|
||||||
|
git clone https://github.com/Wei-Shaw/sub2api.git
|
||||||
|
cd sub2api
|
||||||
|
|
||||||
|
# 添加 fork 仓库
|
||||||
|
git remote add fork https://github.com/touwaeriol/sub2api.git
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. 切换到定制分支并配置环境
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git fetch fork
|
||||||
|
git checkout -B release/custom-0.1.69 fork/release/custom-0.1.69
|
||||||
|
|
||||||
|
cd deploy
|
||||||
|
cp .env.example .env
|
||||||
|
vim .env # 配置 DATABASE_URL, REDIS_URL, JWT_SECRET 等
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. 构建并启动
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd /root/sub2api
|
||||||
|
docker build -t sub2api:latest -f Dockerfile .
|
||||||
|
docker tag sub2api:latest weishaw/sub2api:latest
|
||||||
|
cd deploy && docker compose up -d
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6. 启动服务
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 进入 deploy 目录
|
||||||
|
cd deploy
|
||||||
|
|
||||||
|
# 启动所有服务(PostgreSQL、Redis、sub2api)
|
||||||
|
docker compose up -d
|
||||||
|
|
||||||
|
# 查看服务状态
|
||||||
|
docker compose ps
|
||||||
|
```
|
||||||
|
|
||||||
|
### 7. 验证部署
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 查看应用日志
|
||||||
|
docker logs sub2api --tail 50
|
||||||
|
|
||||||
|
# 检查健康状态
|
||||||
|
curl http://localhost:8080/health
|
||||||
|
|
||||||
|
# 确认版本号
|
||||||
|
cat /root/sub2api/backend/cmd/server/VERSION
|
||||||
|
```
|
||||||
|
|
||||||
|
### 8. 常用运维命令
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 查看实时日志
|
||||||
|
docker logs -f sub2api
|
||||||
|
|
||||||
|
# 重启服务
|
||||||
|
docker compose restart sub2api
|
||||||
|
|
||||||
|
# 停止所有服务
|
||||||
|
docker compose down
|
||||||
|
|
||||||
|
# 停止并删除数据卷(慎用!会删除数据库数据)
|
||||||
|
docker compose down -v
|
||||||
|
|
||||||
|
# 查看资源使用情况
|
||||||
|
docker stats sub2api
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 定制功能说明
|
||||||
|
|
||||||
|
当前定制分支包含以下功能(相对于官方版本):
|
||||||
|
|
||||||
|
### UI/UX 定制
|
||||||
|
|
||||||
|
| 功能 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| 首页优化 | 面向用户的价值主张设计 |
|
||||||
|
| 移除 GitHub 链接 | 用户菜单中不显示 GitHub 导航 |
|
||||||
|
| 微信客服按钮 | 首页悬浮微信客服入口 |
|
||||||
|
| 限流时间精确显示 | 账号限流时间显示精确到秒 |
|
||||||
|
|
||||||
|
### Antigravity 平台增强
|
||||||
|
|
||||||
|
| 功能 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| Scope 级别限流 | 按配额域(claude/gemini_text/gemini_image)独立限流,避免整个账号被锁定 |
|
||||||
|
| 模型级别限流 | 按具体模型(如 claude-opus-4-5)独立限流,更精细的限流控制 |
|
||||||
|
| 限流预检查 | 调度时预检查账号/模型限流状态,避免选中已限流账号 |
|
||||||
|
| 秒级冷却时间 | 支持 429 响应的秒级精确冷却时间 |
|
||||||
|
| 身份注入优化 | 模型身份信息注入 + 静默边界防止身份泄露 |
|
||||||
|
| thoughtSignature 修复 | Gemini 3 函数调用 400 错误修复 |
|
||||||
|
| max_tokens 自动修正 | 自动修正 max_tokens <= budget_tokens 导致的 400 错误 |
|
||||||
|
|
||||||
|
### 调度算法优化
|
||||||
|
|
||||||
|
| 功能 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| 分层过滤选择 | 调度算法从全排序改为分层过滤,提升性能 |
|
||||||
|
| LRU 随机选择 | 相同 LRU 时间时随机选择,避免账号集中 |
|
||||||
|
| 限流等待阈值配置化 | 可配置的限流等待阈值 |
|
||||||
|
|
||||||
|
### 运维增强
|
||||||
|
|
||||||
|
| 功能 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| Scope 限流统计 | 运维界面展示 Antigravity 账号 scope 级别限流统计 |
|
||||||
|
| 账号限流状态显示 | 账号列表显示 scope 和模型级别限流状态 |
|
||||||
|
| 清除限流按钮增强 | 有 scope/模型限流时也显示清除限流按钮 |
|
||||||
|
|
||||||
|
### 其他修复
|
||||||
|
|
||||||
|
| 功能 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| .gitattributes | 确保迁移文件使用 LF 换行符(解决 Windows 下 SQL 摘要不一致) |
|
||||||
|
| 部署配置优化 | DATABASE_HOST 和 DATABASE_SSLMODE 可通过 .env 配置 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 注意事项
|
||||||
|
|
||||||
|
1. **前端必须打包进镜像**:使用 `docker build` 在服务器上构建,Dockerfile 会自动编译前端并 embed 到后端二进制中
|
||||||
|
|
||||||
|
2. **镜像标签**:docker-compose.yml 使用 `weishaw/sub2api:latest`,本地构建后需要 `docker tag` 覆盖
|
||||||
|
|
||||||
|
3. **Windows 换行符问题**:已通过 `.gitattributes` 解决,确保 `*.sql` 文件始终使用 LF
|
||||||
|
|
||||||
|
4. **版本号管理**:每次发布必须更新 `backend/cmd/server/VERSION` 并打标签
|
||||||
|
|
||||||
|
5. **合并冲突**:合并上游新版本时,重点关注以下文件可能的冲突:
|
||||||
|
- `backend/internal/service/antigravity_gateway_service.go`
|
||||||
|
- `backend/internal/service/gateway_service.go`
|
||||||
|
- `backend/internal/pkg/antigravity/request_transformer.go`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Go 代码规范
|
||||||
|
|
||||||
|
### 1. 函数设计
|
||||||
|
|
||||||
|
#### 单一职责原则
|
||||||
|
- **函数行数**:单个函数常规不应超过 **30 行**,超过时应拆分为子函数。若某段逻辑确实不可拆分(如复杂的状态机、协议解析等),可以例外,但需添加注释说明原因
|
||||||
|
- **嵌套层级**:避免超过 3 层嵌套,使用 early return 减少嵌套
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ❌ 不推荐:深层嵌套
|
||||||
|
func process(data []Item) {
|
||||||
|
for _, item := range data {
|
||||||
|
if item.Valid {
|
||||||
|
if item.Type == "A" {
|
||||||
|
if item.Status == "active" {
|
||||||
|
// 业务逻辑...
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ✅ 推荐:early return
|
||||||
|
func process(data []Item) {
|
||||||
|
for _, item := range data {
|
||||||
|
if !item.Valid {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if item.Type != "A" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if item.Status != "active" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// 业务逻辑...
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 复杂逻辑提取
|
||||||
|
将复杂的条件判断或处理逻辑提取为独立函数:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ❌ 不推荐:内联复杂逻辑
|
||||||
|
if resp.StatusCode == 429 || resp.StatusCode == 503 {
|
||||||
|
// 80+ 行处理逻辑...
|
||||||
|
}
|
||||||
|
|
||||||
|
// ✅ 推荐:提取为独立函数
|
||||||
|
result := handleRateLimitResponse(resp, params)
|
||||||
|
switch result.action {
|
||||||
|
case actionRetry:
|
||||||
|
continue
|
||||||
|
case actionBreak:
|
||||||
|
return result.resp, nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. 重复代码消除
|
||||||
|
|
||||||
|
#### 配置获取模式
|
||||||
|
将重复的配置获取逻辑提取为方法:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ❌ 不推荐:重复代码
|
||||||
|
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
|
||||||
|
maxBytes := 2048
|
||||||
|
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
|
||||||
|
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
}
|
||||||
|
|
||||||
|
// ✅ 推荐:提取为方法
|
||||||
|
func (s *Service) getLogConfig() (logBody bool, maxBytes int) {
|
||||||
|
maxBytes = 2048
|
||||||
|
if s.settingService == nil || s.settingService.cfg == nil {
|
||||||
|
return false, maxBytes
|
||||||
|
}
|
||||||
|
cfg := s.settingService.cfg.Gateway
|
||||||
|
if cfg.LogUpstreamErrorBodyMaxBytes > 0 {
|
||||||
|
maxBytes = cfg.LogUpstreamErrorBodyMaxBytes
|
||||||
|
}
|
||||||
|
return cfg.LogUpstreamErrorBody, maxBytes
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. 常量管理
|
||||||
|
|
||||||
|
#### 避免魔法数字
|
||||||
|
所有硬编码的数值都应定义为常量:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ❌ 不推荐
|
||||||
|
if retryDelay >= 10*time.Second {
|
||||||
|
resetAt := time.Now().Add(30 * time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ✅ 推荐
|
||||||
|
const (
|
||||||
|
rateLimitThreshold = 10 * time.Second
|
||||||
|
defaultRateLimitDuration = 30 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
if retryDelay >= rateLimitThreshold {
|
||||||
|
resetAt := time.Now().Add(defaultRateLimitDuration)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 注释引用常量名
|
||||||
|
在注释中引用常量名而非硬编码值:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ❌ 不推荐
|
||||||
|
// < 10s: 等待后重试
|
||||||
|
|
||||||
|
// ✅ 推荐
|
||||||
|
// < rateLimitThreshold: 等待后重试
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. 错误处理
|
||||||
|
|
||||||
|
#### 使用结构化日志
|
||||||
|
优先使用 `slog` 进行结构化日志记录:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ❌ 不推荐
|
||||||
|
log.Printf("%s status=%d model_rate_limit_failed model=%s error=%v", prefix, statusCode, modelName, err)
|
||||||
|
|
||||||
|
// ✅ 推荐
|
||||||
|
slog.Error("failed to set model rate limit",
|
||||||
|
"prefix", prefix,
|
||||||
|
"status_code", statusCode,
|
||||||
|
"model", modelName,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5. 测试规范
|
||||||
|
|
||||||
|
#### Mock 函数签名同步
|
||||||
|
修改函数签名时,必须同步更新所有测试中的 mock 函数:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 如果修改了 handleError 签名
|
||||||
|
handleError func(..., groupID int64, sessionHash string) *Result
|
||||||
|
|
||||||
|
// 必须同步更新测试中的 mock
|
||||||
|
handleError: func(..., groupID int64, sessionHash string) *Result {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 测试构建标签
|
||||||
|
统一使用测试构建标签:
|
||||||
|
|
||||||
|
```go
|
||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6. 时间格式解析
|
||||||
|
|
||||||
|
#### 使用标准库
|
||||||
|
优先使用 `time.ParseDuration`,支持所有 Go duration 格式:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ❌ 不推荐:手动限制格式
|
||||||
|
if !strings.HasSuffix(delay, "s") || strings.Contains(delay, "m") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// ✅ 推荐:使用标准库
|
||||||
|
dur, err := time.ParseDuration(delay) // 支持 "0.5s", "4m50s", "1h30m" 等
|
||||||
|
```
|
||||||
|
|
||||||
|
### 7. 接口设计
|
||||||
|
|
||||||
|
#### 接口隔离原则
|
||||||
|
定义最小化接口,只包含必需的方法:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ❌ 不推荐:使用过于宽泛的接口
|
||||||
|
type AccountRepository interface {
|
||||||
|
// 20+ 个方法...
|
||||||
|
}
|
||||||
|
|
||||||
|
// ✅ 推荐:定义最小化接口
|
||||||
|
type ModelRateLimiter interface {
|
||||||
|
SetModelRateLimit(ctx context.Context, id int64, modelKey string, resetAt time.Time) error
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 8. 并发安全
|
||||||
|
|
||||||
|
#### 共享数据保护
|
||||||
|
访问可能被并发修改的数据时,确保线程安全:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 如果 Account.Extra 可能被并发修改
|
||||||
|
// 需要使用互斥锁或原子操作保护读取
|
||||||
|
func (a *Account) GetRateLimitRemainingTime(model string) time.Duration {
|
||||||
|
a.mu.RLock()
|
||||||
|
defer a.mu.RUnlock()
|
||||||
|
// 读取 Extra 字段...
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 9. 命名规范
|
||||||
|
|
||||||
|
#### 一致的命名风格
|
||||||
|
- 常量使用 camelCase:`rateLimitThreshold`
|
||||||
|
- 类型使用 PascalCase:`AntigravityQuotaScope`
|
||||||
|
- 同一概念使用统一命名:`Threshold` 或 `Limit`,不要混用
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ❌ 不推荐:命名不一致
|
||||||
|
antigravitySmartRetryMinWait // 使用 Min
|
||||||
|
antigravityRateLimitThreshold // 使用 Threshold
|
||||||
|
|
||||||
|
// ✅ 推荐:统一风格
|
||||||
|
antigravityMinRetryWait
|
||||||
|
antigravityRateLimitThreshold
|
||||||
|
```
|
||||||
|
|
||||||
|
### 10. 代码审查清单
|
||||||
|
|
||||||
|
在提交代码前,检查以下项目:
|
||||||
|
|
||||||
|
- [ ] 函数是否超过 30 行?(不可拆分的逻辑除外,需注释说明)
|
||||||
|
- [ ] 嵌套是否超过 3 层?
|
||||||
|
- [ ] 是否有重复代码可以提取?
|
||||||
|
- [ ] 是否使用了魔法数字?
|
||||||
|
- [ ] Mock 函数签名是否与实际函数一致?
|
||||||
|
- [ ] 测试是否覆盖了新增逻辑?
|
||||||
|
- [ ] 日志是否包含足够的上下文信息?
|
||||||
|
- [ ] 是否考虑了并发安全?
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## CI 检查与发布门禁
|
||||||
|
|
||||||
|
### GitHub Actions 检查项
|
||||||
|
|
||||||
|
本项目有 4 个 CI 任务,**任何代码推送或发布前都必须全部通过**:
|
||||||
|
|
||||||
|
| Workflow | Job | 说明 | 本地验证命令 |
|
||||||
|
|----------|-----|------|-------------|
|
||||||
|
| CI | `test` | 单元测试 + 集成测试 | `cd backend && make test-unit && make test-integration` |
|
||||||
|
| CI | `golangci-lint` | Go 代码静态检查(golangci-lint v2.7) | `cd backend && golangci-lint run --timeout=5m` |
|
||||||
|
| Security Scan | `backend-security` | govulncheck + gosec 安全扫描 | `cd backend && govulncheck ./... && gosec -severity high -confidence high ./...` |
|
||||||
|
| Security Scan | `frontend-security` | pnpm audit 前端依赖安全检查 | `cd frontend && pnpm audit --prod --audit-level=high` |
|
||||||
|
|
||||||
|
### 向上游提交 PR
|
||||||
|
|
||||||
|
PR 目标是上游官方仓库,**只包含通用功能改动**(bug fix、新功能、性能优化等)。
|
||||||
|
|
||||||
|
**以下文件禁止出现在 PR 中**(属于我们 fork 的定制化内容):
|
||||||
|
- `CLAUDE.md`、`AGENTS.md` — 我们的开发文档
|
||||||
|
- `backend/cmd/server/VERSION` — 我们的版本号文件
|
||||||
|
- UI 定制改动(GitHub 链接移除、微信客服按钮、首页定制等)
|
||||||
|
- 部署配置(`deploy/` 目录下的定制修改)
|
||||||
|
|
||||||
|
**PR 流程**:
|
||||||
|
1. 从 `develop` 创建功能分支,只包含要提交给上游的改动
|
||||||
|
2. 推送分支后,**等待 4 个 CI job 全部通过**
|
||||||
|
3. 确认通过后再创建 PR
|
||||||
|
4. 使用 `gh run list --repo touwaeriol/sub2api --branch <branch>` 检查状态
|
||||||
|
|
||||||
|
### 自有分支推送(develop / main)
|
||||||
|
|
||||||
|
推送到我们自己的 `develop` 或 `main` 分支时,包含所有改动(定制化 + 通用功能)。
|
||||||
|
|
||||||
|
**推送流程**:
|
||||||
|
1. 本地运行 `cd backend && make test-unit` 确保单元测试通过
|
||||||
|
2. 本地运行 `cd backend && gofmt -l ./...` 确保格式正确
|
||||||
|
3. 推送后确认 CI 和 Security Scan 两个 workflow 的 4 个 job 全部绿色 ✅
|
||||||
|
4. 任何 job 失败必须立即修复,**禁止在 CI 未通过的状态下继续后续操作**
|
||||||
|
|
||||||
|
### 发布版本
|
||||||
|
|
||||||
|
1. 确保 `main` 分支最新提交的 4 个 CI job 全部通过
|
||||||
|
2. 递增 `backend/cmd/server/VERSION`,提交并推送
|
||||||
|
3. 打 tag 推送后,确认 tag 触发的 3 个 workflow(CI、Security Scan、Release)全部通过
|
||||||
|
4. **Release workflow 失败时禁止部署** — 必须先修复问题,删除旧 tag,重新打 tag
|
||||||
|
5. 使用 `gh run list --repo touwaeriol/sub2api --limit 10` 确认状态
|
||||||
|
|
||||||
|
### 常见 CI 失败原因及修复
|
||||||
|
- **gofmt**:struct 字段对齐不一致 → 运行 `gofmt -w <file>` 修复
|
||||||
|
- **golangci-lint**:未使用的变量/导入 → 删除或使用 `_` 忽略
|
||||||
|
- **test 失败**:mock 函数签名不一致 → 同步更新 mock
|
||||||
|
- **gosec**:安全漏洞 → 根据提示修复或添加例外
|
||||||
723
CLAUDE.md
Normal file
723
CLAUDE.md
Normal file
@@ -0,0 +1,723 @@
|
|||||||
|
# Sub2API 开发说明
|
||||||
|
|
||||||
|
## 版本管理策略
|
||||||
|
|
||||||
|
### 版本号规则
|
||||||
|
|
||||||
|
我们在官方版本号后面添加自己的小版本号:
|
||||||
|
|
||||||
|
- 官方版本:`v0.1.68`
|
||||||
|
- 我们的版本:`v0.1.68.1`、`v0.1.68.2`(递增)
|
||||||
|
|
||||||
|
### 分支策略
|
||||||
|
|
||||||
|
| 分支 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| `main` | 我们的主分支,包含所有定制功能 |
|
||||||
|
| `release/custom-X.Y.Z` | 基于官方 `vX.Y.Z` 的发布分支 |
|
||||||
|
| `upstream/main` | 上游官方仓库 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 发布流程(基于新官方版本)
|
||||||
|
|
||||||
|
当官方发布新版本(如 `v0.1.69`)时:
|
||||||
|
|
||||||
|
### 1. 同步上游并创建发布分支
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 获取上游最新代码
|
||||||
|
git fetch upstream --tags
|
||||||
|
|
||||||
|
# 基于官方标签创建新的发布分支
|
||||||
|
git checkout v0.1.69 -b release/custom-0.1.69
|
||||||
|
|
||||||
|
# 合并我们的 main 分支(包含所有定制功能)
|
||||||
|
git merge main --no-edit
|
||||||
|
|
||||||
|
# 解决可能的冲突后继续
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. 更新版本号并打标签
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 更新版本号文件
|
||||||
|
echo "0.1.69.1" > backend/cmd/server/VERSION
|
||||||
|
git add backend/cmd/server/VERSION
|
||||||
|
git commit -m "chore: bump version to 0.1.69.1"
|
||||||
|
|
||||||
|
# 打上我们自己的标签
|
||||||
|
git tag v0.1.69.1
|
||||||
|
|
||||||
|
# 推送分支和标签
|
||||||
|
git push origin release/custom-0.1.69
|
||||||
|
git push origin v0.1.69.1
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. 更新 main 分支
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 将发布分支合并回 main,保持 main 包含最新定制功能
|
||||||
|
git checkout main
|
||||||
|
git merge release/custom-0.1.69
|
||||||
|
git push origin main
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 热修复发布(在现有版本上修复)
|
||||||
|
|
||||||
|
当需要在当前版本上发布修复时:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 在当前发布分支上修复
|
||||||
|
git checkout release/custom-0.1.68
|
||||||
|
# ... 进行修复 ...
|
||||||
|
git commit -m "fix: 修复描述"
|
||||||
|
|
||||||
|
# 递增小版本号
|
||||||
|
echo "0.1.68.2" > backend/cmd/server/VERSION
|
||||||
|
git add backend/cmd/server/VERSION
|
||||||
|
git commit -m "chore: bump version to 0.1.68.2"
|
||||||
|
|
||||||
|
# 打标签并推送
|
||||||
|
git tag v0.1.68.2
|
||||||
|
git push origin release/custom-0.1.68
|
||||||
|
git push origin v0.1.68.2
|
||||||
|
|
||||||
|
# 同步修复到 main
|
||||||
|
git checkout main
|
||||||
|
git cherry-pick <fix-commit-hash>
|
||||||
|
git push origin main
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 服务器部署流程
|
||||||
|
|
||||||
|
### 前置条件
|
||||||
|
|
||||||
|
- 本地已配置 SSH 别名 `clicodeplus` 连接到服务器
|
||||||
|
- 服务器部署目录:`/root/sub2api`(正式)、`/root/sub2api-beta`(测试)
|
||||||
|
- 服务器使用 Docker Compose 部署
|
||||||
|
|
||||||
|
### 部署环境说明
|
||||||
|
|
||||||
|
| 环境 | 目录 | 端口 | 数据库 | 容器名 |
|
||||||
|
|------|------|------|--------|--------|
|
||||||
|
| 正式 | `/root/sub2api` | 8080 | `sub2api` | `sub2api` |
|
||||||
|
| Beta | `/root/sub2api-beta` | 8084 | `beta` | `sub2api-beta` |
|
||||||
|
|
||||||
|
### 外部数据库
|
||||||
|
|
||||||
|
正式和 Beta 环境**共用外部 PostgreSQL 数据库**(非容器内数据库),配置在 `.env` 文件中:
|
||||||
|
- `DATABASE_HOST`:外部数据库地址
|
||||||
|
- `DATABASE_SSLMODE`:SSL 模式(通常为 `require`)
|
||||||
|
- `POSTGRES_USER` / `POSTGRES_DB`:用户名和数据库名
|
||||||
|
|
||||||
|
#### 数据库操作命令
|
||||||
|
|
||||||
|
通过 SSH 在服务器上执行数据库操作:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 正式环境 - 查询迁移记录
|
||||||
|
ssh clicodeplus "source /root/sub2api/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c 'SELECT * FROM schema_migrations ORDER BY applied_at DESC LIMIT 5;'"
|
||||||
|
|
||||||
|
# Beta 环境 - 查询迁移记录
|
||||||
|
ssh clicodeplus "source /root/sub2api-beta/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c 'SELECT * FROM schema_migrations ORDER BY applied_at DESC LIMIT 5;'"
|
||||||
|
|
||||||
|
# Beta 环境 - 清除指定迁移记录(重新执行迁移)
|
||||||
|
ssh clicodeplus "source /root/sub2api-beta/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c \"DELETE FROM schema_migrations WHERE filename LIKE '%049%';\""
|
||||||
|
|
||||||
|
# Beta 环境 - 更新账号数据
|
||||||
|
ssh clicodeplus "source /root/sub2api-beta/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c \"UPDATE accounts SET credentials = credentials - 'model_mapping' WHERE platform = 'antigravity';\""
|
||||||
|
```
|
||||||
|
|
||||||
|
> **注意**:使用 `source .env` 加载环境变量,避免在命令行中暴露密码。
|
||||||
|
|
||||||
|
### 部署步骤
|
||||||
|
|
||||||
|
**重要:每次部署都必须递增版本号!**
|
||||||
|
|
||||||
|
#### 0. 递增版本号(本地操作)
|
||||||
|
|
||||||
|
每次部署前,先在本地递增小版本号:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 查看当前版本号
|
||||||
|
cat backend/cmd/server/VERSION
|
||||||
|
# 假设当前是 0.1.69.1
|
||||||
|
|
||||||
|
# 递增版本号
|
||||||
|
echo "0.1.69.2" > backend/cmd/server/VERSION
|
||||||
|
git add backend/cmd/server/VERSION
|
||||||
|
git commit -m "chore: bump version to 0.1.69.2"
|
||||||
|
git push origin release/custom-0.1.69
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 1. 服务器拉取代码
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ssh clicodeplus "cd /root/sub2api && git fetch fork && git checkout -B release/custom-0.1.69 fork/release/custom-0.1.69"
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 2. 服务器构建镜像
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ssh clicodeplus "cd /root/sub2api && docker build --no-cache -t sub2api:latest -f Dockerfile ."
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 3. 更新镜像标签并重启服务
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ssh clicodeplus "docker tag sub2api:latest weishaw/sub2api:latest"
|
||||||
|
ssh clicodeplus "cd /root/sub2api/deploy && docker compose up -d --force-recreate sub2api"
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 4. 验证部署
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 查看启动日志
|
||||||
|
ssh clicodeplus "docker logs sub2api --tail 20"
|
||||||
|
|
||||||
|
# 确认版本号(必须与步骤 0 中设置的版本号一致)
|
||||||
|
ssh clicodeplus "cat /root/sub2api/backend/cmd/server/VERSION"
|
||||||
|
|
||||||
|
# 检查容器状态
|
||||||
|
ssh clicodeplus "docker ps | grep sub2api"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Beta 并行部署(不影响现网)
|
||||||
|
|
||||||
|
目标:在同一台服务器上并行启动一个 beta 实例(例如端口 `8084`),**严禁改动/重启**现网实例(默认目录 `/root/sub2api`)。
|
||||||
|
|
||||||
|
### 设计原则
|
||||||
|
|
||||||
|
- **新目录**:beta 使用独立目录,例如 `/root/sub2api-beta`。
|
||||||
|
- **敏感信息只放 `.env`**:beta 的数据库密码、JWT_SECRET 等只写入 `/root/sub2api-beta/deploy/.env`,不要提交到 git。
|
||||||
|
- **独立 Compose Project**:通过 `docker compose -p sub2api-beta ...` 启动,确保 network/volume 隔离。
|
||||||
|
- **独立端口**:通过 `.env` 的 `SERVER_PORT` 映射宿主机端口(例如 `8084:8080`)。
|
||||||
|
|
||||||
|
### 前置检查
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 1) 确保 8084 未被占用
|
||||||
|
ssh clicodeplus "ss -ltnp | grep :8084 || echo '8084 is free'"
|
||||||
|
|
||||||
|
# 2) 确认现网容器还在(只读检查)
|
||||||
|
ssh clicodeplus "docker ps --format 'table {{.Names}}\t{{.Image}}\t{{.Ports}}' | sed -n '1,200p'"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 首次部署步骤
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 0) 进入服务器
|
||||||
|
ssh clicodeplus
|
||||||
|
|
||||||
|
# 1) 克隆代码到新目录(示例使用你的 fork)
|
||||||
|
cd /root
|
||||||
|
git clone https://github.com/touwaeriol/sub2api.git sub2api-beta
|
||||||
|
cd /root/sub2api-beta
|
||||||
|
git checkout release/custom-0.1.71
|
||||||
|
|
||||||
|
# 2) 准备 beta 的 .env(敏感信息只写这里)
|
||||||
|
cd /root/sub2api-beta/deploy
|
||||||
|
|
||||||
|
# 推荐:从现网 .env 复制,保证除 DB 名/用户/端口外完全一致
|
||||||
|
cp -f /root/sub2api/deploy/.env ./.env
|
||||||
|
|
||||||
|
# 仅修改以下三项(其他保持不变)
|
||||||
|
perl -pi -e 's/^SERVER_PORT=.*/SERVER_PORT=8084/' ./.env
|
||||||
|
perl -pi -e 's/^POSTGRES_USER=.*/POSTGRES_USER=beta/' ./.env
|
||||||
|
perl -pi -e 's/^POSTGRES_DB=.*/POSTGRES_DB=beta/' ./.env
|
||||||
|
|
||||||
|
# 3) 写 compose override(避免与现网容器名冲突,镜像使用本地构建的 sub2api:beta)
|
||||||
|
cat > docker-compose.override.yml <<'YAML'
|
||||||
|
services:
|
||||||
|
sub2api:
|
||||||
|
image: sub2api:beta
|
||||||
|
container_name: sub2api-beta
|
||||||
|
redis:
|
||||||
|
container_name: sub2api-beta-redis
|
||||||
|
YAML
|
||||||
|
|
||||||
|
# 4) 构建 beta 镜像(基于当前代码)
|
||||||
|
cd /root/sub2api-beta
|
||||||
|
docker build -t sub2api:beta -f Dockerfile .
|
||||||
|
|
||||||
|
# 5) 启动 beta(独立 project,确保不影响现网)
|
||||||
|
cd /root/sub2api-beta/deploy
|
||||||
|
docker compose -p sub2api-beta --env-file .env -f docker-compose.yml -f docker-compose.override.yml up -d
|
||||||
|
|
||||||
|
# 6) 验证 beta
|
||||||
|
curl -fsS http://127.0.0.1:8084/health
|
||||||
|
docker logs sub2api-beta --tail 50
|
||||||
|
```
|
||||||
|
|
||||||
|
### 数据库配置约定(beta)
|
||||||
|
|
||||||
|
- 数据库地址/SSL/密码:与现网一致(从现网 `.env` 复制即可)。
|
||||||
|
- 仅修改:
|
||||||
|
- `POSTGRES_USER=beta`
|
||||||
|
- `POSTGRES_DB=beta`
|
||||||
|
|
||||||
|
注意:需要数据库侧已存在 `beta` 用户与 `beta` 数据库,并授予权限;否则容器会启动失败并不断重启。
|
||||||
|
|
||||||
|
### 更新 beta(拉代码 + 仅重建 beta 容器)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ssh clicodeplus "set -e; cd /root/sub2api-beta && git fetch --all --tags && git checkout -f release/custom-0.1.71 && git reset --hard origin/release/custom-0.1.71"
|
||||||
|
ssh clicodeplus "cd /root/sub2api-beta && docker build -t sub2api:beta -f Dockerfile ."
|
||||||
|
ssh clicodeplus "cd /root/sub2api-beta/deploy && docker compose -p sub2api-beta --env-file .env -f docker-compose.yml -f docker-compose.override.yml up -d --no-deps --force-recreate sub2api"
|
||||||
|
ssh clicodeplus "curl -fsS http://127.0.0.1:8084/health"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 停止/回滚 beta(只影响 beta)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ssh clicodeplus "cd /root/sub2api-beta/deploy && docker compose -p sub2api-beta -f docker-compose.yml -f docker-compose.override.yml down"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 服务器首次部署
|
||||||
|
|
||||||
|
### 1. 克隆代码并配置远程仓库
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ssh clicodeplus
|
||||||
|
cd /root
|
||||||
|
git clone https://github.com/Wei-Shaw/sub2api.git
|
||||||
|
cd sub2api
|
||||||
|
|
||||||
|
# 添加 fork 仓库
|
||||||
|
git remote add fork https://github.com/touwaeriol/sub2api.git
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. 切换到定制分支并配置环境
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git fetch fork
|
||||||
|
git checkout -B release/custom-0.1.69 fork/release/custom-0.1.69
|
||||||
|
|
||||||
|
cd deploy
|
||||||
|
cp .env.example .env
|
||||||
|
vim .env # 配置 DATABASE_URL, REDIS_URL, JWT_SECRET 等
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. 构建并启动
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd /root/sub2api
|
||||||
|
docker build -t sub2api:latest -f Dockerfile .
|
||||||
|
docker tag sub2api:latest weishaw/sub2api:latest
|
||||||
|
cd deploy && docker compose up -d
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6. 启动服务
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 进入 deploy 目录
|
||||||
|
cd deploy
|
||||||
|
|
||||||
|
# 启动所有服务(PostgreSQL、Redis、sub2api)
|
||||||
|
docker compose up -d
|
||||||
|
|
||||||
|
# 查看服务状态
|
||||||
|
docker compose ps
|
||||||
|
```
|
||||||
|
|
||||||
|
### 7. 验证部署
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 查看应用日志
|
||||||
|
docker logs sub2api --tail 50
|
||||||
|
|
||||||
|
# 检查健康状态
|
||||||
|
curl http://localhost:8080/health
|
||||||
|
|
||||||
|
# 确认版本号
|
||||||
|
cat /root/sub2api/backend/cmd/server/VERSION
|
||||||
|
```
|
||||||
|
|
||||||
|
### 8. 常用运维命令
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 查看实时日志
|
||||||
|
docker logs -f sub2api
|
||||||
|
|
||||||
|
# 重启服务
|
||||||
|
docker compose restart sub2api
|
||||||
|
|
||||||
|
# 停止所有服务
|
||||||
|
docker compose down
|
||||||
|
|
||||||
|
# 停止并删除数据卷(慎用!会删除数据库数据)
|
||||||
|
docker compose down -v
|
||||||
|
|
||||||
|
# 查看资源使用情况
|
||||||
|
docker stats sub2api
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 定制功能说明
|
||||||
|
|
||||||
|
当前定制分支包含以下功能(相对于官方版本):
|
||||||
|
|
||||||
|
### UI/UX 定制
|
||||||
|
|
||||||
|
| 功能 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| 首页优化 | 面向用户的价值主张设计 |
|
||||||
|
| 移除 GitHub 链接 | 用户菜单中不显示 GitHub 导航 |
|
||||||
|
| 微信客服按钮 | 首页悬浮微信客服入口 |
|
||||||
|
| 限流时间精确显示 | 账号限流时间显示精确到秒 |
|
||||||
|
|
||||||
|
### Antigravity 平台增强
|
||||||
|
|
||||||
|
| 功能 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| Scope 级别限流 | 按配额域(claude/gemini_text/gemini_image)独立限流,避免整个账号被锁定 |
|
||||||
|
| 模型级别限流 | 按具体模型(如 claude-opus-4-5)独立限流,更精细的限流控制 |
|
||||||
|
| 限流预检查 | 调度时预检查账号/模型限流状态,避免选中已限流账号 |
|
||||||
|
| 秒级冷却时间 | 支持 429 响应的秒级精确冷却时间 |
|
||||||
|
| 身份注入优化 | 模型身份信息注入 + 静默边界防止身份泄露 |
|
||||||
|
| thoughtSignature 修复 | Gemini 3 函数调用 400 错误修复 |
|
||||||
|
| max_tokens 自动修正 | 自动修正 max_tokens <= budget_tokens 导致的 400 错误 |
|
||||||
|
|
||||||
|
### 调度算法优化
|
||||||
|
|
||||||
|
| 功能 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| 分层过滤选择 | 调度算法从全排序改为分层过滤,提升性能 |
|
||||||
|
| LRU 随机选择 | 相同 LRU 时间时随机选择,避免账号集中 |
|
||||||
|
| 限流等待阈值配置化 | 可配置的限流等待阈值 |
|
||||||
|
|
||||||
|
### 运维增强
|
||||||
|
|
||||||
|
| 功能 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| Scope 限流统计 | 运维界面展示 Antigravity 账号 scope 级别限流统计 |
|
||||||
|
| 账号限流状态显示 | 账号列表显示 scope 和模型级别限流状态 |
|
||||||
|
| 清除限流按钮增强 | 有 scope/模型限流时也显示清除限流按钮 |
|
||||||
|
|
||||||
|
### 其他修复
|
||||||
|
|
||||||
|
| 功能 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| .gitattributes | 确保迁移文件使用 LF 换行符(解决 Windows 下 SQL 摘要不一致) |
|
||||||
|
| 部署配置优化 | DATABASE_HOST 和 DATABASE_SSLMODE 可通过 .env 配置 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 注意事项
|
||||||
|
|
||||||
|
1. **前端必须打包进镜像**:使用 `docker build` 在服务器上构建,Dockerfile 会自动编译前端并 embed 到后端二进制中
|
||||||
|
|
||||||
|
2. **镜像标签**:docker-compose.yml 使用 `weishaw/sub2api:latest`,本地构建后需要 `docker tag` 覆盖
|
||||||
|
|
||||||
|
3. **Windows 换行符问题**:已通过 `.gitattributes` 解决,确保 `*.sql` 文件始终使用 LF
|
||||||
|
|
||||||
|
4. **版本号管理**:每次发布必须更新 `backend/cmd/server/VERSION` 并打标签
|
||||||
|
|
||||||
|
5. **合并冲突**:合并上游新版本时,重点关注以下文件可能的冲突:
|
||||||
|
- `backend/internal/service/antigravity_gateway_service.go`
|
||||||
|
- `backend/internal/service/gateway_service.go`
|
||||||
|
- `backend/internal/pkg/antigravity/request_transformer.go`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Go 代码规范
|
||||||
|
|
||||||
|
### 1. 函数设计
|
||||||
|
|
||||||
|
#### 单一职责原则
|
||||||
|
- **函数行数**:单个函数常规不应超过 **30 行**,超过时应拆分为子函数。若某段逻辑确实不可拆分(如复杂的状态机、协议解析等),可以例外,但需添加注释说明原因
|
||||||
|
- **嵌套层级**:避免超过 3 层嵌套,使用 early return 减少嵌套
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ❌ 不推荐:深层嵌套
|
||||||
|
func process(data []Item) {
|
||||||
|
for _, item := range data {
|
||||||
|
if item.Valid {
|
||||||
|
if item.Type == "A" {
|
||||||
|
if item.Status == "active" {
|
||||||
|
// 业务逻辑...
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ✅ 推荐:early return
|
||||||
|
func process(data []Item) {
|
||||||
|
for _, item := range data {
|
||||||
|
if !item.Valid {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if item.Type != "A" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if item.Status != "active" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// 业务逻辑...
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 复杂逻辑提取
|
||||||
|
将复杂的条件判断或处理逻辑提取为独立函数:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ❌ 不推荐:内联复杂逻辑
|
||||||
|
if resp.StatusCode == 429 || resp.StatusCode == 503 {
|
||||||
|
// 80+ 行处理逻辑...
|
||||||
|
}
|
||||||
|
|
||||||
|
// ✅ 推荐:提取为独立函数
|
||||||
|
result := handleRateLimitResponse(resp, params)
|
||||||
|
switch result.action {
|
||||||
|
case actionRetry:
|
||||||
|
continue
|
||||||
|
case actionBreak:
|
||||||
|
return result.resp, nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. 重复代码消除
|
||||||
|
|
||||||
|
#### 配置获取模式
|
||||||
|
将重复的配置获取逻辑提取为方法:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ❌ 不推荐:重复代码
|
||||||
|
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
|
||||||
|
maxBytes := 2048
|
||||||
|
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
|
||||||
|
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
}
|
||||||
|
|
||||||
|
// ✅ 推荐:提取为方法
|
||||||
|
func (s *Service) getLogConfig() (logBody bool, maxBytes int) {
|
||||||
|
maxBytes = 2048
|
||||||
|
if s.settingService == nil || s.settingService.cfg == nil {
|
||||||
|
return false, maxBytes
|
||||||
|
}
|
||||||
|
cfg := s.settingService.cfg.Gateway
|
||||||
|
if cfg.LogUpstreamErrorBodyMaxBytes > 0 {
|
||||||
|
maxBytes = cfg.LogUpstreamErrorBodyMaxBytes
|
||||||
|
}
|
||||||
|
return cfg.LogUpstreamErrorBody, maxBytes
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. 常量管理
|
||||||
|
|
||||||
|
#### 避免魔法数字
|
||||||
|
所有硬编码的数值都应定义为常量:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ❌ 不推荐
|
||||||
|
if retryDelay >= 10*time.Second {
|
||||||
|
resetAt := time.Now().Add(30 * time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ✅ 推荐
|
||||||
|
const (
|
||||||
|
rateLimitThreshold = 10 * time.Second
|
||||||
|
defaultRateLimitDuration = 30 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
if retryDelay >= rateLimitThreshold {
|
||||||
|
resetAt := time.Now().Add(defaultRateLimitDuration)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 注释引用常量名
|
||||||
|
在注释中引用常量名而非硬编码值:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ❌ 不推荐
|
||||||
|
// < 10s: 等待后重试
|
||||||
|
|
||||||
|
// ✅ 推荐
|
||||||
|
// < rateLimitThreshold: 等待后重试
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. 错误处理
|
||||||
|
|
||||||
|
#### 使用结构化日志
|
||||||
|
优先使用 `slog` 进行结构化日志记录:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ❌ 不推荐
|
||||||
|
log.Printf("%s status=%d model_rate_limit_failed model=%s error=%v", prefix, statusCode, modelName, err)
|
||||||
|
|
||||||
|
// ✅ 推荐
|
||||||
|
slog.Error("failed to set model rate limit",
|
||||||
|
"prefix", prefix,
|
||||||
|
"status_code", statusCode,
|
||||||
|
"model", modelName,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5. 测试规范
|
||||||
|
|
||||||
|
#### Mock 函数签名同步
|
||||||
|
修改函数签名时,必须同步更新所有测试中的 mock 函数:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 如果修改了 handleError 签名
|
||||||
|
handleError func(..., groupID int64, sessionHash string) *Result
|
||||||
|
|
||||||
|
// 必须同步更新测试中的 mock
|
||||||
|
handleError: func(..., groupID int64, sessionHash string) *Result {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 测试构建标签
|
||||||
|
统一使用测试构建标签:
|
||||||
|
|
||||||
|
```go
|
||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6. 时间格式解析
|
||||||
|
|
||||||
|
#### 使用标准库
|
||||||
|
优先使用 `time.ParseDuration`,支持所有 Go duration 格式:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ❌ 不推荐:手动限制格式
|
||||||
|
if !strings.HasSuffix(delay, "s") || strings.Contains(delay, "m") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// ✅ 推荐:使用标准库
|
||||||
|
dur, err := time.ParseDuration(delay) // 支持 "0.5s", "4m50s", "1h30m" 等
|
||||||
|
```
|
||||||
|
|
||||||
|
### 7. 接口设计
|
||||||
|
|
||||||
|
#### 接口隔离原则
|
||||||
|
定义最小化接口,只包含必需的方法:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ❌ 不推荐:使用过于宽泛的接口
|
||||||
|
type AccountRepository interface {
|
||||||
|
// 20+ 个方法...
|
||||||
|
}
|
||||||
|
|
||||||
|
// ✅ 推荐:定义最小化接口
|
||||||
|
type ModelRateLimiter interface {
|
||||||
|
SetModelRateLimit(ctx context.Context, id int64, modelKey string, resetAt time.Time) error
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 8. 并发安全
|
||||||
|
|
||||||
|
#### 共享数据保护
|
||||||
|
访问可能被并发修改的数据时,确保线程安全:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 如果 Account.Extra 可能被并发修改
|
||||||
|
// 需要使用互斥锁或原子操作保护读取
|
||||||
|
func (a *Account) GetRateLimitRemainingTime(model string) time.Duration {
|
||||||
|
a.mu.RLock()
|
||||||
|
defer a.mu.RUnlock()
|
||||||
|
// 读取 Extra 字段...
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 9. 命名规范
|
||||||
|
|
||||||
|
#### 一致的命名风格
|
||||||
|
- 常量使用 camelCase:`rateLimitThreshold`
|
||||||
|
- 类型使用 PascalCase:`AntigravityQuotaScope`
|
||||||
|
- 同一概念使用统一命名:`Threshold` 或 `Limit`,不要混用
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ❌ 不推荐:命名不一致
|
||||||
|
antigravitySmartRetryMinWait // 使用 Min
|
||||||
|
antigravityRateLimitThreshold // 使用 Threshold
|
||||||
|
|
||||||
|
// ✅ 推荐:统一风格
|
||||||
|
antigravityMinRetryWait
|
||||||
|
antigravityRateLimitThreshold
|
||||||
|
```
|
||||||
|
|
||||||
|
### 10. 代码审查清单
|
||||||
|
|
||||||
|
在提交代码前,检查以下项目:
|
||||||
|
|
||||||
|
- [ ] 函数是否超过 30 行?(不可拆分的逻辑除外,需注释说明)
|
||||||
|
- [ ] 嵌套是否超过 3 层?
|
||||||
|
- [ ] 是否有重复代码可以提取?
|
||||||
|
- [ ] 是否使用了魔法数字?
|
||||||
|
- [ ] Mock 函数签名是否与实际函数一致?
|
||||||
|
- [ ] 测试是否覆盖了新增逻辑?
|
||||||
|
- [ ] 日志是否包含足够的上下文信息?
|
||||||
|
- [ ] 是否考虑了并发安全?
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## CI 检查与发布门禁
|
||||||
|
|
||||||
|
### GitHub Actions 检查项
|
||||||
|
|
||||||
|
本项目有 4 个 CI 任务,**任何代码推送或发布前都必须全部通过**:
|
||||||
|
|
||||||
|
| Workflow | Job | 说明 | 本地验证命令 |
|
||||||
|
|----------|-----|------|-------------|
|
||||||
|
| CI | `test` | 单元测试 + 集成测试 | `cd backend && make test-unit && make test-integration` |
|
||||||
|
| CI | `golangci-lint` | Go 代码静态检查(golangci-lint v2.7) | `cd backend && golangci-lint run --timeout=5m` |
|
||||||
|
| Security Scan | `backend-security` | govulncheck + gosec 安全扫描 | `cd backend && govulncheck ./... && gosec -severity high -confidence high ./...` |
|
||||||
|
| Security Scan | `frontend-security` | pnpm audit 前端依赖安全检查 | `cd frontend && pnpm audit --prod --audit-level=high` |
|
||||||
|
|
||||||
|
### 向上游提交 PR
|
||||||
|
|
||||||
|
PR 目标是上游官方仓库,**只包含通用功能改动**(bug fix、新功能、性能优化等)。
|
||||||
|
|
||||||
|
**以下文件禁止出现在 PR 中**(属于我们 fork 的定制化内容):
|
||||||
|
- `CLAUDE.md`、`AGENTS.md` — 我们的开发文档
|
||||||
|
- `backend/cmd/server/VERSION` — 我们的版本号文件
|
||||||
|
- UI 定制改动(GitHub 链接移除、微信客服按钮、首页定制等)
|
||||||
|
- 部署配置(`deploy/` 目录下的定制修改)
|
||||||
|
|
||||||
|
**PR 流程**:
|
||||||
|
1. 从 `develop` 创建功能分支,只包含要提交给上游的改动
|
||||||
|
2. 推送分支后,**等待 4 个 CI job 全部通过**
|
||||||
|
3. 确认通过后再创建 PR
|
||||||
|
4. 使用 `gh run list --repo touwaeriol/sub2api --branch <branch>` 检查状态
|
||||||
|
|
||||||
|
### 自有分支推送(develop / main)
|
||||||
|
|
||||||
|
推送到我们自己的 `develop` 或 `main` 分支时,包含所有改动(定制化 + 通用功能)。
|
||||||
|
|
||||||
|
**推送流程**:
|
||||||
|
1. 本地运行 `cd backend && make test-unit` 确保单元测试通过
|
||||||
|
2. 本地运行 `cd backend && gofmt -l ./...` 确保格式正确
|
||||||
|
3. 推送后确认 CI 和 Security Scan 两个 workflow 的 4 个 job 全部绿色 ✅
|
||||||
|
4. 任何 job 失败必须立即修复,**禁止在 CI 未通过的状态下继续后续操作**
|
||||||
|
|
||||||
|
### 发布版本
|
||||||
|
|
||||||
|
1. 确保 `main` 分支最新提交的 4 个 CI job 全部通过
|
||||||
|
2. 递增 `backend/cmd/server/VERSION`,提交并推送
|
||||||
|
3. 打 tag 推送后,确认 tag 触发的 3 个 workflow(CI、Security Scan、Release)全部通过
|
||||||
|
4. **Release workflow 失败时禁止部署** — 必须先修复问题,删除旧 tag,重新打 tag
|
||||||
|
5. 使用 `gh run list --repo touwaeriol/sub2api --limit 10` 确认状态
|
||||||
|
|
||||||
|
### 常见 CI 失败原因及修复
|
||||||
|
- **gofmt**:struct 字段对齐不一致 → 运行 `gofmt -w <file>` 修复
|
||||||
|
- **golangci-lint**:未使用的变量/导入 → 删除或使用 `_` 忽略
|
||||||
|
- **test 失败**:mock 函数签名不一致 → 同步更新 mock
|
||||||
|
- **gosec**:安全漏洞 → 根据提示修复或添加例外
|
||||||
@@ -1 +1 @@
|
|||||||
0.1.74.7
|
0.1.75.7
|
||||||
@@ -154,7 +154,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
identityService := service.NewIdentityService(identityCache)
|
identityService := service.NewIdentityService(identityCache)
|
||||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||||
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
|
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
|
||||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache)
|
digestSessionStore := service.NewDigestSessionStore()
|
||||||
|
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, digestSessionStore)
|
||||||
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
||||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||||
|
|||||||
@@ -103,6 +103,7 @@ require (
|
|||||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||||
github.com/opencontainers/image-spec v1.1.1 // indirect
|
github.com/opencontainers/image-spec v1.1.1 // indirect
|
||||||
|
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||||
github.com/pkg/errors v0.9.1 // indirect
|
github.com/pkg/errors v0.9.1 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||||
|
|||||||
@@ -207,6 +207,8 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8
|
|||||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||||
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
|
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
|
||||||
|
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
|
||||||
|
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
|
|||||||
@@ -2,11 +2,6 @@ package dto
|
|||||||
|
|
||||||
import "time"
|
import "time"
|
||||||
|
|
||||||
type ScopeRateLimitInfo struct {
|
|
||||||
ResetAt time.Time `json:"reset_at"`
|
|
||||||
RemainingSec int64 `json:"remaining_sec"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
ID int64 `json:"id"`
|
ID int64 `json:"id"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
@@ -126,9 +121,6 @@ type Account struct {
|
|||||||
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
|
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
|
||||||
OverloadUntil *time.Time `json:"overload_until"`
|
OverloadUntil *time.Time `json:"overload_until"`
|
||||||
|
|
||||||
// Antigravity scope 级限流状态(从 extra 提取)
|
|
||||||
ScopeRateLimits map[string]ScopeRateLimitInfo `json:"scope_rate_limits,omitempty"`
|
|
||||||
|
|
||||||
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until"`
|
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until"`
|
||||||
TempUnschedulableReason string `json:"temp_unschedulable_reason"`
|
TempUnschedulableReason string `json:"temp_unschedulable_reason"`
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||||
@@ -114,7 +115,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
|
|
||||||
setOpsRequestContext(c, "", false, body)
|
setOpsRequestContext(c, "", false, body)
|
||||||
|
|
||||||
parsedReq, err := service.ParseGatewayRequest(body)
|
parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||||
return
|
return
|
||||||
@@ -203,6 +204,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 计算粘性会话hash
|
// 计算粘性会话hash
|
||||||
|
parsedReq.SessionContext = &service.SessionContext{
|
||||||
|
ClientIP: ip.GetClientIP(c),
|
||||||
|
UserAgent: c.GetHeader("User-Agent"),
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
}
|
||||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||||
|
|
||||||
// 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则使用分组平台
|
// 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则使用分组平台
|
||||||
@@ -335,7 +341,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
failedAccountIDs[account.ID] = struct{}{}
|
failedAccountIDs[account.ID] = struct{}{}
|
||||||
lastFailoverErr = failoverErr
|
lastFailoverErr = failoverErr
|
||||||
if failoverErr.ForceCacheBilling {
|
if needForceCacheBilling(hasBoundSession, failoverErr) {
|
||||||
forceCacheBilling = true
|
forceCacheBilling = true
|
||||||
}
|
}
|
||||||
if switchCount >= maxAccountSwitches {
|
if switchCount >= maxAccountSwitches {
|
||||||
@@ -344,6 +350,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
switchCount++
|
switchCount++
|
||||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||||
|
if account.Platform == service.PlatformAntigravity {
|
||||||
|
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// 错误响应已在Forward中处理,这里只记录日志
|
// 错误响应已在Forward中处理,这里只记录日志
|
||||||
@@ -530,7 +541,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
failedAccountIDs[account.ID] = struct{}{}
|
failedAccountIDs[account.ID] = struct{}{}
|
||||||
lastFailoverErr = failoverErr
|
lastFailoverErr = failoverErr
|
||||||
if failoverErr.ForceCacheBilling {
|
if needForceCacheBilling(hasBoundSession, failoverErr) {
|
||||||
forceCacheBilling = true
|
forceCacheBilling = true
|
||||||
}
|
}
|
||||||
if switchCount >= maxAccountSwitches {
|
if switchCount >= maxAccountSwitches {
|
||||||
@@ -539,6 +550,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
switchCount++
|
switchCount++
|
||||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||||
|
if account.Platform == service.PlatformAntigravity {
|
||||||
|
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// 错误响应已在Forward中处理,这里只记录日志
|
// 错误响应已在Forward中处理,这里只记录日志
|
||||||
@@ -801,6 +817,27 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
|
|||||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// needForceCacheBilling 判断 failover 时是否需要强制缓存计费
|
||||||
|
// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费
|
||||||
|
func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool {
|
||||||
|
return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sleepFailoverDelay 账号切换线性递增延时:第1次0s、第2次1s、第3次2s…
|
||||||
|
// 返回 false 表示 context 已取消。
|
||||||
|
func sleepFailoverDelay(ctx context.Context, switchCount int) bool {
|
||||||
|
delay := time.Duration(switchCount-1) * time.Second
|
||||||
|
if delay <= 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return false
|
||||||
|
case <-time.After(delay):
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
|
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
|
||||||
statusCode := failoverErr.StatusCode
|
statusCode := failoverErr.StatusCode
|
||||||
responseBody := failoverErr.ResponseBody
|
responseBody := failoverErr.ResponseBody
|
||||||
@@ -934,7 +971,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
|||||||
|
|
||||||
setOpsRequestContext(c, "", false, body)
|
setOpsRequestContext(c, "", false, body)
|
||||||
|
|
||||||
parsedReq, err := service.ParseGatewayRequest(body)
|
parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||||
return
|
return
|
||||||
@@ -962,6 +999,11 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 计算粘性会话 hash
|
// 计算粘性会话 hash
|
||||||
|
parsedReq.SessionContext = &service.SessionContext{
|
||||||
|
ClientIP: ip.GetClientIP(c),
|
||||||
|
UserAgent: c.GetHeader("User-Agent"),
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
}
|
||||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||||
|
|
||||||
// 选择支持该模型的账号
|
// 选择支持该模型的账号
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
|
||||||
@@ -30,13 +31,6 @@ import (
|
|||||||
// 匹配格式: /Users/xxx/.gemini/tmp/[64位十六进制哈希]
|
// 匹配格式: /Users/xxx/.gemini/tmp/[64位十六进制哈希]
|
||||||
var geminiCLITmpDirRegex = regexp.MustCompile(`/\.gemini/tmp/([A-Fa-f0-9]{64})`)
|
var geminiCLITmpDirRegex = regexp.MustCompile(`/\.gemini/tmp/([A-Fa-f0-9]{64})`)
|
||||||
|
|
||||||
func isGeminiCLIRequest(c *gin.Context, body []byte) bool {
|
|
||||||
if strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id")) != "" {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return geminiCLITmpDirRegex.Match(body)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GeminiV1BetaListModels proxies:
|
// GeminiV1BetaListModels proxies:
|
||||||
// GET /v1beta/models
|
// GET /v1beta/models
|
||||||
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
|
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
|
||||||
@@ -239,7 +233,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
sessionHash := extractGeminiCLISessionHash(c, body)
|
sessionHash := extractGeminiCLISessionHash(c, body)
|
||||||
if sessionHash == "" {
|
if sessionHash == "" {
|
||||||
// Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端)
|
// Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端)
|
||||||
parsedReq, _ := service.ParseGatewayRequest(body)
|
parsedReq, _ := service.ParseGatewayRequest(body, domain.PlatformGemini)
|
||||||
|
if parsedReq != nil {
|
||||||
|
parsedReq.SessionContext = &service.SessionContext{
|
||||||
|
ClientIP: ip.GetClientIP(c),
|
||||||
|
UserAgent: c.GetHeader("User-Agent"),
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
}
|
||||||
|
}
|
||||||
sessionHash = h.gatewayService.GenerateSessionHash(parsedReq)
|
sessionHash = h.gatewayService.GenerateSessionHash(parsedReq)
|
||||||
}
|
}
|
||||||
sessionKey := sessionHash
|
sessionKey := sessionHash
|
||||||
@@ -258,6 +259,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
var geminiDigestChain string
|
var geminiDigestChain string
|
||||||
var geminiPrefixHash string
|
var geminiPrefixHash string
|
||||||
var geminiSessionUUID string
|
var geminiSessionUUID string
|
||||||
|
var matchedDigestChain string
|
||||||
useDigestFallback := sessionBoundAccountID == 0
|
useDigestFallback := sessionBoundAccountID == 0
|
||||||
|
|
||||||
if useDigestFallback {
|
if useDigestFallback {
|
||||||
@@ -284,13 +286,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
// 查找会话
|
// 查找会话
|
||||||
foundUUID, foundAccountID, found := h.gatewayService.FindGeminiSession(
|
foundUUID, foundAccountID, foundMatchedChain, found := h.gatewayService.FindGeminiSession(
|
||||||
c.Request.Context(),
|
c.Request.Context(),
|
||||||
derefGroupID(apiKey.GroupID),
|
derefGroupID(apiKey.GroupID),
|
||||||
geminiPrefixHash,
|
geminiPrefixHash,
|
||||||
geminiDigestChain,
|
geminiDigestChain,
|
||||||
)
|
)
|
||||||
if found {
|
if found {
|
||||||
|
matchedDigestChain = foundMatchedChain
|
||||||
sessionBoundAccountID = foundAccountID
|
sessionBoundAccountID = foundAccountID
|
||||||
geminiSessionUUID = foundUUID
|
geminiSessionUUID = foundUUID
|
||||||
log.Printf("[Gemini] Digest fallback matched: uuid=%s, accountID=%d, chain=%s",
|
log.Printf("[Gemini] Digest fallback matched: uuid=%s, accountID=%d, chain=%s",
|
||||||
@@ -316,7 +319,6 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
|
|
||||||
// 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号
|
// 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号
|
||||||
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
|
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
|
||||||
isCLI := isGeminiCLIRequest(c, body)
|
|
||||||
cleanedForUnknownBinding := false
|
cleanedForUnknownBinding := false
|
||||||
|
|
||||||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
maxAccountSwitches := h.maxAccountSwitchesGemini
|
||||||
@@ -344,10 +346,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID)
|
log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID)
|
||||||
body = service.CleanGeminiNativeThoughtSignatures(body)
|
body = service.CleanGeminiNativeThoughtSignatures(body)
|
||||||
sessionBoundAccountID = account.ID
|
sessionBoundAccountID = account.ID
|
||||||
} else if sessionKey != "" && sessionBoundAccountID == 0 && isCLI && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
|
} else if sessionKey != "" && sessionBoundAccountID == 0 && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
|
||||||
// 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,CLI 继续携带旧签名。
|
// 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,客户端继续携带旧签名。
|
||||||
// 为避免第一次转发就 400,这里做一次确定性清理,让新账号重新生成签名链路。
|
// 为避免第一次转发就 400,这里做一次确定性清理,让新账号重新生成签名链路。
|
||||||
log.Printf("[Gemini] Sticky session binding missing for CLI request, cleaning thoughtSignature proactively")
|
log.Printf("[Gemini] Sticky session binding missing, cleaning thoughtSignature proactively")
|
||||||
body = service.CleanGeminiNativeThoughtSignatures(body)
|
body = service.CleanGeminiNativeThoughtSignatures(body)
|
||||||
cleanedForUnknownBinding = true
|
cleanedForUnknownBinding = true
|
||||||
sessionBoundAccountID = account.ID
|
sessionBoundAccountID = account.ID
|
||||||
@@ -422,7 +424,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
var failoverErr *service.UpstreamFailoverError
|
var failoverErr *service.UpstreamFailoverError
|
||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
failedAccountIDs[account.ID] = struct{}{}
|
failedAccountIDs[account.ID] = struct{}{}
|
||||||
if failoverErr.ForceCacheBilling {
|
if needForceCacheBilling(hasBoundSession, failoverErr) {
|
||||||
forceCacheBilling = true
|
forceCacheBilling = true
|
||||||
}
|
}
|
||||||
if switchCount >= maxAccountSwitches {
|
if switchCount >= maxAccountSwitches {
|
||||||
@@ -433,6 +435,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
lastFailoverErr = failoverErr
|
lastFailoverErr = failoverErr
|
||||||
switchCount++
|
switchCount++
|
||||||
log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||||
|
if account.Platform == service.PlatformAntigravity {
|
||||||
|
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// ForwardNative already wrote the response
|
// ForwardNative already wrote the response
|
||||||
@@ -453,6 +460,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
geminiDigestChain,
|
geminiDigestChain,
|
||||||
geminiSessionUUID,
|
geminiSessionUUID,
|
||||||
account.ID,
|
account.ID,
|
||||||
|
matchedDigestChain,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
log.Printf("[Gemini] Failed to save digest session: %v", err)
|
log.Printf("[Gemini] Failed to save digest session: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -798,53 +798,6 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
|
|
||||||
now := time.Now().UTC()
|
|
||||||
payload := map[string]string{
|
|
||||||
"rate_limited_at": now.Format(time.RFC3339),
|
|
||||||
"rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339),
|
|
||||||
}
|
|
||||||
raw, err := json.Marshal(payload)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
scopeKey := string(scope)
|
|
||||||
client := clientFromContext(ctx, r.client)
|
|
||||||
result, err := client.ExecContext(
|
|
||||||
ctx,
|
|
||||||
`UPDATE accounts SET
|
|
||||||
extra = jsonb_set(
|
|
||||||
jsonb_set(COALESCE(extra, '{}'::jsonb), '{antigravity_quota_scopes}'::text[], COALESCE(extra->'antigravity_quota_scopes', '{}'::jsonb), true),
|
|
||||||
ARRAY['antigravity_quota_scopes', $1]::text[],
|
|
||||||
$2::jsonb,
|
|
||||||
true
|
|
||||||
),
|
|
||||||
updated_at = NOW(),
|
|
||||||
last_used_at = NOW()
|
|
||||||
WHERE id = $3 AND deleted_at IS NULL`,
|
|
||||||
scopeKey,
|
|
||||||
raw,
|
|
||||||
id,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
affected, err := result.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if affected == 0 {
|
|
||||||
return service.ErrAccountNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
|
||||||
log.Printf("[SchedulerOutbox] enqueue quota scope failed: account=%d err=%v", id, err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||||
if scope == "" {
|
if scope == "" {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -11,63 +11,6 @@ import (
|
|||||||
|
|
||||||
const stickySessionPrefix = "sticky_session:"
|
const stickySessionPrefix = "sticky_session:"
|
||||||
|
|
||||||
// Gemini Trie Lua 脚本
|
|
||||||
const (
|
|
||||||
// geminiTrieFindScript 查找最长前缀匹配的 Lua 脚本
|
|
||||||
// KEYS[1] = trie key
|
|
||||||
// ARGV[1] = digestChain (如 "u:a-m:b-u:c-m:d")
|
|
||||||
// ARGV[2] = TTL seconds (用于刷新)
|
|
||||||
// 返回: 最长匹配的 value (uuid:accountID) 或 nil
|
|
||||||
// 查找成功时自动刷新 TTL,防止活跃会话意外过期
|
|
||||||
geminiTrieFindScript = `
|
|
||||||
local chain = ARGV[1]
|
|
||||||
local ttl = tonumber(ARGV[2])
|
|
||||||
local lastMatch = nil
|
|
||||||
local path = ""
|
|
||||||
|
|
||||||
for part in string.gmatch(chain, "[^-]+") do
|
|
||||||
path = path == "" and part or path .. "-" .. part
|
|
||||||
local val = redis.call('HGET', KEYS[1], path)
|
|
||||||
if val and val ~= "" then
|
|
||||||
lastMatch = val
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
if lastMatch then
|
|
||||||
redis.call('EXPIRE', KEYS[1], ttl)
|
|
||||||
end
|
|
||||||
|
|
||||||
return lastMatch
|
|
||||||
`
|
|
||||||
|
|
||||||
// geminiTrieSaveScript 保存会话到 Trie 的 Lua 脚本
|
|
||||||
// KEYS[1] = trie key
|
|
||||||
// ARGV[1] = digestChain
|
|
||||||
// ARGV[2] = value (uuid:accountID)
|
|
||||||
// ARGV[3] = TTL seconds
|
|
||||||
geminiTrieSaveScript = `
|
|
||||||
local chain = ARGV[1]
|
|
||||||
local value = ARGV[2]
|
|
||||||
local ttl = tonumber(ARGV[3])
|
|
||||||
local path = ""
|
|
||||||
|
|
||||||
for part in string.gmatch(chain, "[^-]+") do
|
|
||||||
path = path == "" and part or path .. "-" .. part
|
|
||||||
end
|
|
||||||
redis.call('HSET', KEYS[1], path, value)
|
|
||||||
redis.call('EXPIRE', KEYS[1], ttl)
|
|
||||||
return "OK"
|
|
||||||
`
|
|
||||||
)
|
|
||||||
|
|
||||||
// 模型负载统计相关常量
|
|
||||||
const (
|
|
||||||
modelLoadKeyPrefix = "ag:model_load:" // 模型调用次数 key 前缀
|
|
||||||
modelLastUsedKeyPrefix = "ag:model_last_used:" // 模型最后调度时间 key 前缀
|
|
||||||
modelLoadTTL = 24 * time.Hour // 调用次数 TTL(24 小时无调用后清零)
|
|
||||||
modelLastUsedTTL = 24 * time.Hour // 最后调度时间 TTL
|
|
||||||
)
|
|
||||||
|
|
||||||
type gatewayCache struct {
|
type gatewayCache struct {
|
||||||
rdb *redis.Client
|
rdb *redis.Client
|
||||||
}
|
}
|
||||||
@@ -108,171 +51,3 @@ func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64
|
|||||||
key := buildSessionKey(groupID, sessionHash)
|
key := buildSessionKey(groupID, sessionHash)
|
||||||
return c.rdb.Del(ctx, key).Err()
|
return c.rdb.Del(ctx, key).Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============ Antigravity 模型负载统计方法 ============
|
|
||||||
|
|
||||||
// modelLoadKey 构建模型调用次数 key
|
|
||||||
// 格式: ag:model_load:{accountID}:{model}
|
|
||||||
func modelLoadKey(accountID int64, model string) string {
|
|
||||||
return fmt.Sprintf("%s%d:%s", modelLoadKeyPrefix, accountID, model)
|
|
||||||
}
|
|
||||||
|
|
||||||
// modelLastUsedKey 构建模型最后调度时间 key
|
|
||||||
// 格式: ag:model_last_used:{accountID}:{model}
|
|
||||||
func modelLastUsedKey(accountID int64, model string) string {
|
|
||||||
return fmt.Sprintf("%s%d:%s", modelLastUsedKeyPrefix, accountID, model)
|
|
||||||
}
|
|
||||||
|
|
||||||
// IncrModelCallCount 增加模型调用次数并更新最后调度时间
|
|
||||||
// 返回更新后的调用次数
|
|
||||||
func (c *gatewayCache) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
|
|
||||||
loadKey := modelLoadKey(accountID, model)
|
|
||||||
lastUsedKey := modelLastUsedKey(accountID, model)
|
|
||||||
|
|
||||||
pipe := c.rdb.Pipeline()
|
|
||||||
incrCmd := pipe.Incr(ctx, loadKey)
|
|
||||||
pipe.Expire(ctx, loadKey, modelLoadTTL) // 每次调用刷新 TTL
|
|
||||||
pipe.Set(ctx, lastUsedKey, time.Now().Unix(), modelLastUsedTTL)
|
|
||||||
if _, err := pipe.Exec(ctx); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return incrCmd.Val(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetModelLoadBatch 批量获取账号的模型负载信息
|
|
||||||
func (c *gatewayCache) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*service.ModelLoadInfo, error) {
|
|
||||||
if len(accountIDs) == 0 {
|
|
||||||
return make(map[int64]*service.ModelLoadInfo), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
loadCmds, lastUsedCmds := c.pipelineModelLoadGet(ctx, accountIDs, model)
|
|
||||||
return c.parseModelLoadResults(accountIDs, loadCmds, lastUsedCmds), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// pipelineModelLoadGet 批量获取模型负载的 Pipeline 操作
|
|
||||||
func (c *gatewayCache) pipelineModelLoadGet(
|
|
||||||
ctx context.Context,
|
|
||||||
accountIDs []int64,
|
|
||||||
model string,
|
|
||||||
) (map[int64]*redis.StringCmd, map[int64]*redis.StringCmd) {
|
|
||||||
pipe := c.rdb.Pipeline()
|
|
||||||
loadCmds := make(map[int64]*redis.StringCmd, len(accountIDs))
|
|
||||||
lastUsedCmds := make(map[int64]*redis.StringCmd, len(accountIDs))
|
|
||||||
|
|
||||||
for _, id := range accountIDs {
|
|
||||||
loadCmds[id] = pipe.Get(ctx, modelLoadKey(id, model))
|
|
||||||
lastUsedCmds[id] = pipe.Get(ctx, modelLastUsedKey(id, model))
|
|
||||||
}
|
|
||||||
_, _ = pipe.Exec(ctx) // 忽略错误,key 不存在是正常的
|
|
||||||
return loadCmds, lastUsedCmds
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseModelLoadResults 解析 Pipeline 结果
|
|
||||||
func (c *gatewayCache) parseModelLoadResults(
|
|
||||||
accountIDs []int64,
|
|
||||||
loadCmds map[int64]*redis.StringCmd,
|
|
||||||
lastUsedCmds map[int64]*redis.StringCmd,
|
|
||||||
) map[int64]*service.ModelLoadInfo {
|
|
||||||
result := make(map[int64]*service.ModelLoadInfo, len(accountIDs))
|
|
||||||
for _, id := range accountIDs {
|
|
||||||
result[id] = &service.ModelLoadInfo{
|
|
||||||
CallCount: getInt64OrZero(loadCmds[id]),
|
|
||||||
LastUsedAt: getTimeOrZero(lastUsedCmds[id]),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// getInt64OrZero 从 StringCmd 获取 int64 值,失败返回 0
|
|
||||||
func getInt64OrZero(cmd *redis.StringCmd) int64 {
|
|
||||||
val, _ := cmd.Int64()
|
|
||||||
return val
|
|
||||||
}
|
|
||||||
|
|
||||||
// getTimeOrZero 从 StringCmd 获取 time.Time,失败返回零值
|
|
||||||
func getTimeOrZero(cmd *redis.StringCmd) time.Time {
|
|
||||||
val, err := cmd.Int64()
|
|
||||||
if err != nil {
|
|
||||||
return time.Time{}
|
|
||||||
}
|
|
||||||
return time.Unix(val, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============ Gemini 会话 Fallback 方法 (Trie 实现) ============
|
|
||||||
|
|
||||||
// FindGeminiSession 查找 Gemini 会话(使用 Trie + Lua 脚本实现 O(L) 查询)
|
|
||||||
// 返回最长匹配的会话信息,匹配成功时自动刷新 TTL
|
|
||||||
func (c *gatewayCache) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
|
||||||
if digestChain == "" {
|
|
||||||
return "", 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
trieKey := service.BuildGeminiTrieKey(groupID, prefixHash)
|
|
||||||
ttlSeconds := int(service.GeminiSessionTTL().Seconds())
|
|
||||||
|
|
||||||
// 使用 Lua 脚本在 Redis 端执行 Trie 查找,O(L) 次 HGET,1 次网络往返
|
|
||||||
// 查找成功时自动刷新 TTL,防止活跃会话意外过期
|
|
||||||
result, err := c.rdb.Eval(ctx, geminiTrieFindScript, []string{trieKey}, digestChain, ttlSeconds).Result()
|
|
||||||
if err != nil || result == nil {
|
|
||||||
return "", 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
value, ok := result.(string)
|
|
||||||
if !ok || value == "" {
|
|
||||||
return "", 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
uuid, accountID, ok = service.ParseGeminiSessionValue(value)
|
|
||||||
return uuid, accountID, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// SaveGeminiSession 保存 Gemini 会话(使用 Trie + Lua 脚本)
|
|
||||||
func (c *gatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
|
||||||
if digestChain == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
trieKey := service.BuildGeminiTrieKey(groupID, prefixHash)
|
|
||||||
value := service.FormatGeminiSessionValue(uuid, accountID)
|
|
||||||
ttlSeconds := int(service.GeminiSessionTTL().Seconds())
|
|
||||||
|
|
||||||
return c.rdb.Eval(ctx, geminiTrieSaveScript, []string{trieKey}, digestChain, value, ttlSeconds).Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============ Anthropic 会话 Fallback 方法 (复用 Trie 实现) ============
|
|
||||||
|
|
||||||
// FindAnthropicSession 查找 Anthropic 会话(复用 Gemini Trie Lua 脚本)
|
|
||||||
func (c *gatewayCache) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
|
||||||
if digestChain == "" {
|
|
||||||
return "", 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
trieKey := service.BuildAnthropicTrieKey(groupID, prefixHash)
|
|
||||||
ttlSeconds := int(service.AnthropicSessionTTL().Seconds())
|
|
||||||
|
|
||||||
result, err := c.rdb.Eval(ctx, geminiTrieFindScript, []string{trieKey}, digestChain, ttlSeconds).Result()
|
|
||||||
if err != nil || result == nil {
|
|
||||||
return "", 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
value, ok := result.(string)
|
|
||||||
if !ok || value == "" {
|
|
||||||
return "", 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
uuid, accountID, ok = service.ParseGeminiSessionValue(value)
|
|
||||||
return uuid, accountID, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// SaveAnthropicSession 保存 Anthropic 会话(复用 Gemini Trie Lua 脚本)
|
|
||||||
func (c *gatewayCache) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
|
||||||
if digestChain == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
trieKey := service.BuildAnthropicTrieKey(groupID, prefixHash)
|
|
||||||
value := service.FormatGeminiSessionValue(uuid, accountID)
|
|
||||||
ttlSeconds := int(service.AnthropicSessionTTL().Seconds())
|
|
||||||
|
|
||||||
return c.rdb.Eval(ctx, geminiTrieSaveScript, []string{trieKey}, digestChain, value, ttlSeconds).Err()
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -104,157 +104,6 @@ func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
|
|||||||
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
|
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============ Gemini Trie 会话测试 ============
|
|
||||||
|
|
||||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_SaveAndFind() {
|
|
||||||
groupID := int64(1)
|
|
||||||
prefixHash := "testprefix"
|
|
||||||
digestChain := "u:hash1-m:hash2-u:hash3"
|
|
||||||
uuid := "test-uuid-123"
|
|
||||||
accountID := int64(42)
|
|
||||||
|
|
||||||
// 保存会话
|
|
||||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, digestChain, uuid, accountID)
|
|
||||||
require.NoError(s.T(), err, "SaveGeminiSession")
|
|
||||||
|
|
||||||
// 精确匹配查找
|
|
||||||
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, digestChain)
|
|
||||||
require.True(s.T(), found, "should find exact match")
|
|
||||||
require.Equal(s.T(), uuid, foundUUID)
|
|
||||||
require.Equal(s.T(), accountID, foundAccountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_PrefixMatch() {
|
|
||||||
groupID := int64(1)
|
|
||||||
prefixHash := "prefixmatch"
|
|
||||||
shortChain := "u:a-m:b"
|
|
||||||
longChain := "u:a-m:b-u:c-m:d"
|
|
||||||
uuid := "uuid-prefix"
|
|
||||||
accountID := int64(100)
|
|
||||||
|
|
||||||
// 保存短链
|
|
||||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, shortChain, uuid, accountID)
|
|
||||||
require.NoError(s.T(), err)
|
|
||||||
|
|
||||||
// 用长链查找,应该匹配到短链(前缀匹配)
|
|
||||||
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, longChain)
|
|
||||||
require.True(s.T(), found, "should find prefix match")
|
|
||||||
require.Equal(s.T(), uuid, foundUUID)
|
|
||||||
require.Equal(s.T(), accountID, foundAccountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_LongestPrefixMatch() {
|
|
||||||
groupID := int64(1)
|
|
||||||
prefixHash := "longestmatch"
|
|
||||||
|
|
||||||
// 保存多个不同长度的链
|
|
||||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a", "uuid-short", 1)
|
|
||||||
require.NoError(s.T(), err)
|
|
||||||
err = s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b", "uuid-medium", 2)
|
|
||||||
require.NoError(s.T(), err)
|
|
||||||
err = s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:c", "uuid-long", 3)
|
|
||||||
require.NoError(s.T(), err)
|
|
||||||
|
|
||||||
// 查找更长的链,应该匹配到最长的前缀
|
|
||||||
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:c-m:d-u:e")
|
|
||||||
require.True(s.T(), found, "should find longest prefix match")
|
|
||||||
require.Equal(s.T(), "uuid-long", foundUUID)
|
|
||||||
require.Equal(s.T(), int64(3), foundAccountID)
|
|
||||||
|
|
||||||
// 查找中等长度的链
|
|
||||||
foundUUID, foundAccountID, found = s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:x")
|
|
||||||
require.True(s.T(), found)
|
|
||||||
require.Equal(s.T(), "uuid-medium", foundUUID)
|
|
||||||
require.Equal(s.T(), int64(2), foundAccountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_NoMatch() {
|
|
||||||
groupID := int64(1)
|
|
||||||
prefixHash := "nomatch"
|
|
||||||
digestChain := "u:a-m:b"
|
|
||||||
|
|
||||||
// 保存一个会话
|
|
||||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, digestChain, "uuid", 1)
|
|
||||||
require.NoError(s.T(), err)
|
|
||||||
|
|
||||||
// 用不同的链查找,应该找不到
|
|
||||||
_, _, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:x-m:y")
|
|
||||||
require.False(s.T(), found, "should not find non-matching chain")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_DifferentPrefixHash() {
|
|
||||||
groupID := int64(1)
|
|
||||||
digestChain := "u:a-m:b"
|
|
||||||
|
|
||||||
// 保存到 prefixHash1
|
|
||||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, "prefix1", digestChain, "uuid1", 1)
|
|
||||||
require.NoError(s.T(), err)
|
|
||||||
|
|
||||||
// 用 prefixHash2 查找,应该找不到(不同用户/客户端隔离)
|
|
||||||
_, _, found := s.cache.FindGeminiSession(s.ctx, groupID, "prefix2", digestChain)
|
|
||||||
require.False(s.T(), found, "different prefixHash should be isolated")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_DifferentGroupID() {
|
|
||||||
prefixHash := "sameprefix"
|
|
||||||
digestChain := "u:a-m:b"
|
|
||||||
|
|
||||||
// 保存到 groupID 1
|
|
||||||
err := s.cache.SaveGeminiSession(s.ctx, 1, prefixHash, digestChain, "uuid1", 1)
|
|
||||||
require.NoError(s.T(), err)
|
|
||||||
|
|
||||||
// 用 groupID 2 查找,应该找不到(分组隔离)
|
|
||||||
_, _, found := s.cache.FindGeminiSession(s.ctx, 2, prefixHash, digestChain)
|
|
||||||
require.False(s.T(), found, "different groupID should be isolated")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_EmptyDigestChain() {
|
|
||||||
groupID := int64(1)
|
|
||||||
prefixHash := "emptytest"
|
|
||||||
|
|
||||||
// 空链不应该保存
|
|
||||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "", "uuid", 1)
|
|
||||||
require.NoError(s.T(), err, "empty chain should not error")
|
|
||||||
|
|
||||||
// 空链查找应该返回 false
|
|
||||||
_, _, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "")
|
|
||||||
require.False(s.T(), found, "empty chain should not match")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_MultipleSessions() {
|
|
||||||
groupID := int64(1)
|
|
||||||
prefixHash := "multisession"
|
|
||||||
|
|
||||||
// 保存多个不同会话(模拟 1000 个并发会话的场景)
|
|
||||||
sessions := []struct {
|
|
||||||
chain string
|
|
||||||
uuid string
|
|
||||||
accountID int64
|
|
||||||
}{
|
|
||||||
{"u:session1", "uuid-1", 1},
|
|
||||||
{"u:session2-m:reply2", "uuid-2", 2},
|
|
||||||
{"u:session3-m:reply3-u:msg3", "uuid-3", 3},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, sess := range sessions {
|
|
||||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, sess.chain, sess.uuid, sess.accountID)
|
|
||||||
require.NoError(s.T(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 验证每个会话都能正确查找
|
|
||||||
for _, sess := range sessions {
|
|
||||||
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, sess.chain)
|
|
||||||
require.True(s.T(), found, "should find session: %s", sess.chain)
|
|
||||||
require.Equal(s.T(), sess.uuid, foundUUID)
|
|
||||||
require.Equal(s.T(), sess.accountID, foundAccountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 验证继续对话的场景
|
|
||||||
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:session2-m:reply2-u:newmsg")
|
|
||||||
require.True(s.T(), found)
|
|
||||||
require.Equal(s.T(), "uuid-2", foundUUID)
|
|
||||||
require.Equal(s.T(), int64(2), foundAccountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGatewayCacheSuite(t *testing.T) {
|
func TestGatewayCacheSuite(t *testing.T) {
|
||||||
suite.Run(t, new(GatewayCacheSuite))
|
suite.Run(t, new(GatewayCacheSuite))
|
||||||
|
|||||||
@@ -1,234 +0,0 @@
|
|||||||
//go:build integration
|
|
||||||
|
|
||||||
package repository
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"github.com/stretchr/testify/suite"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ============ Gateway Cache 模型负载统计集成测试 ============
|
|
||||||
|
|
||||||
type GatewayCacheModelLoadSuite struct {
|
|
||||||
suite.Suite
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGatewayCacheModelLoadSuite(t *testing.T) {
|
|
||||||
suite.Run(t, new(GatewayCacheModelLoadSuite))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_Basic() {
|
|
||||||
t := s.T()
|
|
||||||
rdb := testRedis(t)
|
|
||||||
cache := &gatewayCache{rdb: rdb}
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
accountID := int64(123)
|
|
||||||
model := "claude-sonnet-4-20250514"
|
|
||||||
|
|
||||||
// 首次调用应返回 1
|
|
||||||
count1, err := cache.IncrModelCallCount(ctx, accountID, model)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, int64(1), count1)
|
|
||||||
|
|
||||||
// 第二次调用应返回 2
|
|
||||||
count2, err := cache.IncrModelCallCount(ctx, accountID, model)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, int64(2), count2)
|
|
||||||
|
|
||||||
// 第三次调用应返回 3
|
|
||||||
count3, err := cache.IncrModelCallCount(ctx, accountID, model)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, int64(3), count3)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_DifferentModels() {
|
|
||||||
t := s.T()
|
|
||||||
rdb := testRedis(t)
|
|
||||||
cache := &gatewayCache{rdb: rdb}
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
accountID := int64(456)
|
|
||||||
model1 := "claude-sonnet-4-20250514"
|
|
||||||
model2 := "claude-opus-4-5-20251101"
|
|
||||||
|
|
||||||
// 不同模型应该独立计数
|
|
||||||
count1, err := cache.IncrModelCallCount(ctx, accountID, model1)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, int64(1), count1)
|
|
||||||
|
|
||||||
count2, err := cache.IncrModelCallCount(ctx, accountID, model2)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, int64(1), count2)
|
|
||||||
|
|
||||||
count1Again, err := cache.IncrModelCallCount(ctx, accountID, model1)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, int64(2), count1Again)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_DifferentAccounts() {
|
|
||||||
t := s.T()
|
|
||||||
rdb := testRedis(t)
|
|
||||||
cache := &gatewayCache{rdb: rdb}
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
account1 := int64(111)
|
|
||||||
account2 := int64(222)
|
|
||||||
model := "gemini-2.5-pro"
|
|
||||||
|
|
||||||
// 不同账号应该独立计数
|
|
||||||
count1, err := cache.IncrModelCallCount(ctx, account1, model)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, int64(1), count1)
|
|
||||||
|
|
||||||
count2, err := cache.IncrModelCallCount(ctx, account2, model)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, int64(1), count2)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_Empty() {
|
|
||||||
t := s.T()
|
|
||||||
rdb := testRedis(t)
|
|
||||||
cache := &gatewayCache{rdb: rdb}
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
result, err := cache.GetModelLoadBatch(ctx, []int64{}, "any-model")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, result)
|
|
||||||
require.Empty(t, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_NonExistent() {
|
|
||||||
t := s.T()
|
|
||||||
rdb := testRedis(t)
|
|
||||||
cache := &gatewayCache{rdb: rdb}
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
// 查询不存在的账号应返回零值
|
|
||||||
result, err := cache.GetModelLoadBatch(ctx, []int64{9999, 9998}, "claude-sonnet-4-20250514")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, result, 2)
|
|
||||||
|
|
||||||
require.Equal(t, int64(0), result[9999].CallCount)
|
|
||||||
require.True(t, result[9999].LastUsedAt.IsZero())
|
|
||||||
require.Equal(t, int64(0), result[9998].CallCount)
|
|
||||||
require.True(t, result[9998].LastUsedAt.IsZero())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_AfterIncrement() {
|
|
||||||
t := s.T()
|
|
||||||
rdb := testRedis(t)
|
|
||||||
cache := &gatewayCache{rdb: rdb}
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
accountID := int64(789)
|
|
||||||
model := "claude-sonnet-4-20250514"
|
|
||||||
|
|
||||||
// 先增加调用次数
|
|
||||||
beforeIncr := time.Now()
|
|
||||||
_, err := cache.IncrModelCallCount(ctx, accountID, model)
|
|
||||||
require.NoError(t, err)
|
|
||||||
_, err = cache.IncrModelCallCount(ctx, accountID, model)
|
|
||||||
require.NoError(t, err)
|
|
||||||
_, err = cache.IncrModelCallCount(ctx, accountID, model)
|
|
||||||
require.NoError(t, err)
|
|
||||||
afterIncr := time.Now()
|
|
||||||
|
|
||||||
// 获取负载信息
|
|
||||||
result, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, result, 1)
|
|
||||||
|
|
||||||
loadInfo := result[accountID]
|
|
||||||
require.NotNil(t, loadInfo)
|
|
||||||
require.Equal(t, int64(3), loadInfo.CallCount)
|
|
||||||
require.False(t, loadInfo.LastUsedAt.IsZero())
|
|
||||||
// LastUsedAt 应该在 beforeIncr 和 afterIncr 之间
|
|
||||||
require.True(t, loadInfo.LastUsedAt.After(beforeIncr.Add(-time.Second)) || loadInfo.LastUsedAt.Equal(beforeIncr))
|
|
||||||
require.True(t, loadInfo.LastUsedAt.Before(afterIncr.Add(time.Second)) || loadInfo.LastUsedAt.Equal(afterIncr))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_MultipleAccounts() {
|
|
||||||
t := s.T()
|
|
||||||
rdb := testRedis(t)
|
|
||||||
cache := &gatewayCache{rdb: rdb}
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
model := "claude-opus-4-5-20251101"
|
|
||||||
account1 := int64(1001)
|
|
||||||
account2 := int64(1002)
|
|
||||||
account3 := int64(1003) // 不调用
|
|
||||||
|
|
||||||
// account1 调用 2 次
|
|
||||||
_, err := cache.IncrModelCallCount(ctx, account1, model)
|
|
||||||
require.NoError(t, err)
|
|
||||||
_, err = cache.IncrModelCallCount(ctx, account1, model)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// account2 调用 5 次
|
|
||||||
for i := 0; i < 5; i++ {
|
|
||||||
_, err = cache.IncrModelCallCount(ctx, account2, model)
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 批量获取
|
|
||||||
result, err := cache.GetModelLoadBatch(ctx, []int64{account1, account2, account3}, model)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, result, 3)
|
|
||||||
|
|
||||||
require.Equal(t, int64(2), result[account1].CallCount)
|
|
||||||
require.False(t, result[account1].LastUsedAt.IsZero())
|
|
||||||
|
|
||||||
require.Equal(t, int64(5), result[account2].CallCount)
|
|
||||||
require.False(t, result[account2].LastUsedAt.IsZero())
|
|
||||||
|
|
||||||
require.Equal(t, int64(0), result[account3].CallCount)
|
|
||||||
require.True(t, result[account3].LastUsedAt.IsZero())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_ModelIsolation() {
|
|
||||||
t := s.T()
|
|
||||||
rdb := testRedis(t)
|
|
||||||
cache := &gatewayCache{rdb: rdb}
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
accountID := int64(2001)
|
|
||||||
model1 := "claude-sonnet-4-20250514"
|
|
||||||
model2 := "gemini-2.5-pro"
|
|
||||||
|
|
||||||
// 对 model1 调用 3 次
|
|
||||||
for i := 0; i < 3; i++ {
|
|
||||||
_, err := cache.IncrModelCallCount(ctx, accountID, model1)
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取 model1 的负载
|
|
||||||
result1, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model1)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, int64(3), result1[accountID].CallCount)
|
|
||||||
|
|
||||||
// 获取 model2 的负载(应该为 0)
|
|
||||||
result2, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model2)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, int64(0), result2[accountID].CallCount)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============ 辅助函数测试 ============
|
|
||||||
|
|
||||||
func (s *GatewayCacheModelLoadSuite) TestModelLoadKey_Format() {
|
|
||||||
t := s.T()
|
|
||||||
|
|
||||||
key := modelLoadKey(123, "claude-sonnet-4")
|
|
||||||
require.Equal(t, "ag:model_load:123:claude-sonnet-4", key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *GatewayCacheModelLoadSuite) TestModelLastUsedKey_Format() {
|
|
||||||
t := s.T()
|
|
||||||
|
|
||||||
key := modelLastUsedKey(456, "gemini-2.5-pro")
|
|
||||||
require.Equal(t, "ag:model_last_used:456:gemini-2.5-pro", key)
|
|
||||||
}
|
|
||||||
@@ -1004,10 +1004,6 @@ func (s *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt
|
|||||||
return errors.New("not implemented")
|
return errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stubAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
|
|
||||||
return errors.New("not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
func (s *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||||
return errors.New("not implemented")
|
return errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -50,7 +50,6 @@ type AccountRepository interface {
|
|||||||
ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
|
ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
|
||||||
|
|
||||||
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
|
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
|
||||||
SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error
|
|
||||||
SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error
|
SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error
|
||||||
SetOverloaded(ctx context.Context, id int64, until time.Time) error
|
SetOverloaded(ctx context.Context, id int64, until time.Time) error
|
||||||
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
|
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
|
||||||
|
|||||||
@@ -143,10 +143,6 @@ func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt
|
|||||||
panic("unexpected SetRateLimited call")
|
panic("unexpected SetRateLimited call")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountRepoStub) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
|
|
||||||
panic("unexpected SetAntigravityQuotaScopeLimit call")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *accountRepoStub) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
func (s *accountRepoStub) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||||
panic("unexpected SetModelRateLimit call")
|
panic("unexpected SetModelRateLimit call")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -12,9 +11,6 @@ const (
|
|||||||
// anthropicSessionTTLSeconds Anthropic 会话缓存 TTL(5 分钟)
|
// anthropicSessionTTLSeconds Anthropic 会话缓存 TTL(5 分钟)
|
||||||
anthropicSessionTTLSeconds = 300
|
anthropicSessionTTLSeconds = 300
|
||||||
|
|
||||||
// anthropicTrieKeyPrefix Anthropic Trie 会话 key 前缀
|
|
||||||
anthropicTrieKeyPrefix = "anthropic:trie:"
|
|
||||||
|
|
||||||
// anthropicDigestSessionKeyPrefix Anthropic 摘要 fallback 会话 key 前缀
|
// anthropicDigestSessionKeyPrefix Anthropic 摘要 fallback 会话 key 前缀
|
||||||
anthropicDigestSessionKeyPrefix = "anthropic:digest:"
|
anthropicDigestSessionKeyPrefix = "anthropic:digest:"
|
||||||
)
|
)
|
||||||
@@ -68,12 +64,6 @@ func rolePrefix(role string) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildAnthropicTrieKey 构建 Anthropic Trie Redis key
|
|
||||||
// 格式: anthropic:trie:{groupID}:{prefixHash}
|
|
||||||
func BuildAnthropicTrieKey(groupID int64, prefixHash string) string {
|
|
||||||
return anthropicTrieKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash
|
|
||||||
}
|
|
||||||
|
|
||||||
// GenerateAnthropicDigestSessionKey 生成 Anthropic 摘要 fallback 的 sessionKey
|
// GenerateAnthropicDigestSessionKey 生成 Anthropic 摘要 fallback 的 sessionKey
|
||||||
// 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey
|
// 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey
|
||||||
func GenerateAnthropicDigestSessionKey(prefixHash, uuid string) string {
|
func GenerateAnthropicDigestSessionKey(prefixHash, uuid string) string {
|
||||||
|
|||||||
@@ -236,43 +236,6 @@ func TestBuildAnthropicDigestChain_Deterministic(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBuildAnthropicTrieKey(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
groupID int64
|
|
||||||
prefixHash string
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "normal",
|
|
||||||
groupID: 123,
|
|
||||||
prefixHash: "abcdef12",
|
|
||||||
want: "anthropic:trie:123:abcdef12",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "zero group",
|
|
||||||
groupID: 0,
|
|
||||||
prefixHash: "xyz",
|
|
||||||
want: "anthropic:trie:0:xyz",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty prefix",
|
|
||||||
groupID: 1,
|
|
||||||
prefixHash: "",
|
|
||||||
want: "anthropic:trie:1:",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got := BuildAnthropicTrieKey(tt.groupID, tt.prefixHash)
|
|
||||||
if got != tt.want {
|
|
||||||
t.Errorf("BuildAnthropicTrieKey(%d, %q) = %q, want %q", tt.groupID, tt.prefixHash, got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGenerateAnthropicDigestSessionKey(t *testing.T) {
|
func TestGenerateAnthropicDigestSessionKey(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -4,17 +4,42 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// antigravityFailingWriter 模拟客户端断开连接的 gin.ResponseWriter
|
||||||
|
type antigravityFailingWriter struct {
|
||||||
|
gin.ResponseWriter
|
||||||
|
failAfter int // 允许成功写入的次数,之后所有写入返回错误
|
||||||
|
writes int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *antigravityFailingWriter) Write(p []byte) (int, error) {
|
||||||
|
if w.writes >= w.failAfter {
|
||||||
|
return 0, errors.New("write failed: client disconnected")
|
||||||
|
}
|
||||||
|
w.writes++
|
||||||
|
return w.ResponseWriter.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// newAntigravityTestService 创建用于流式测试的 AntigravityGatewayService
|
||||||
|
func newAntigravityTestService(cfg *config.Config) *AntigravityGatewayService {
|
||||||
|
return &AntigravityGatewayService{
|
||||||
|
settingService: &SettingService{cfg: cfg},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStripSignatureSensitiveBlocksFromClaudeRequest(t *testing.T) {
|
func TestStripSignatureSensitiveBlocksFromClaudeRequest(t *testing.T) {
|
||||||
req := &antigravity.ClaudeRequest{
|
req := &antigravity.ClaudeRequest{
|
||||||
Model: "claude-sonnet-4-5",
|
Model: "claude-sonnet-4-5",
|
||||||
@@ -337,8 +362,8 @@ func TestAntigravityGatewayService_Forward_StickySessionForceCacheBilling(t *tes
|
|||||||
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
|
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling
|
// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling verifies
|
||||||
// 验证:ForwardGemini 粘性会话切换时,UpstreamFailoverError.ForceCacheBilling 应为 true
|
// that ForwardGemini sets ForceCacheBilling=true for sticky session switch.
|
||||||
func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(t *testing.T) {
|
func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
writer := httptest.NewRecorder()
|
writer := httptest.NewRecorder()
|
||||||
@@ -391,3 +416,438 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(
|
|||||||
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
|
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
|
||||||
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
|
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- 流式 happy path 测试 ---
|
||||||
|
|
||||||
|
// TestStreamUpstreamResponse_NormalComplete
|
||||||
|
// 验证:正常流式转发完成时,数据正确透传、usage 正确收集、clientDisconnect=false
|
||||||
|
func TestStreamUpstreamResponse_NormalComplete(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := newAntigravityTestService(&config.Config{
|
||||||
|
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||||
|
})
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() { _ = pw.Close() }()
|
||||||
|
fmt.Fprintln(pw, `event: message_start`)
|
||||||
|
fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":10}}}`)
|
||||||
|
fmt.Fprintln(pw, "")
|
||||||
|
fmt.Fprintln(pw, `event: content_block_delta`)
|
||||||
|
fmt.Fprintln(pw, `data: {"type":"content_block_delta","delta":{"text":"hello"}}`)
|
||||||
|
fmt.Fprintln(pw, "")
|
||||||
|
fmt.Fprintln(pw, `event: message_delta`)
|
||||||
|
fmt.Fprintln(pw, `data: {"type":"message_delta","usage":{"output_tokens":5}}`)
|
||||||
|
fmt.Fprintln(pw, "")
|
||||||
|
}()
|
||||||
|
|
||||||
|
result := svc.streamUpstreamResponse(c, resp, time.Now())
|
||||||
|
_ = pr.Close()
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect")
|
||||||
|
require.NotNil(t, result.usage)
|
||||||
|
require.Equal(t, 5, result.usage.OutputTokens, "should collect output_tokens from message_delta")
|
||||||
|
require.NotNil(t, result.firstTokenMs, "should record first token time")
|
||||||
|
|
||||||
|
// 验证数据被透传到客户端
|
||||||
|
body := rec.Body.String()
|
||||||
|
require.Contains(t, body, "event: message_start")
|
||||||
|
require.Contains(t, body, "content_block_delta")
|
||||||
|
require.Contains(t, body, "message_delta")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleGeminiStreamingResponse_NormalComplete
|
||||||
|
// 验证:正常 Gemini 流式转发,数据正确透传、usage 正确收集
|
||||||
|
func TestHandleGeminiStreamingResponse_NormalComplete(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := newAntigravityTestService(&config.Config{
|
||||||
|
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||||
|
})
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() { _ = pw.Close() }()
|
||||||
|
// 第一个 chunk(部分内容)
|
||||||
|
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":3}}`)
|
||||||
|
fmt.Fprintln(pw, "")
|
||||||
|
// 第二个 chunk(最终内容+完整 usage)
|
||||||
|
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":" world"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":8,"cachedContentTokenCount":2}}`)
|
||||||
|
fmt.Fprintln(pw, "")
|
||||||
|
}()
|
||||||
|
|
||||||
|
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
|
||||||
|
_ = pr.Close()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect")
|
||||||
|
require.NotNil(t, result.usage)
|
||||||
|
// Gemini usage: promptTokenCount=10, candidatesTokenCount=8, cachedContentTokenCount=2
|
||||||
|
// → InputTokens=10-2=8, OutputTokens=8, CacheReadInputTokens=2
|
||||||
|
require.Equal(t, 8, result.usage.InputTokens)
|
||||||
|
require.Equal(t, 8, result.usage.OutputTokens)
|
||||||
|
require.Equal(t, 2, result.usage.CacheReadInputTokens)
|
||||||
|
require.NotNil(t, result.firstTokenMs, "should record first token time")
|
||||||
|
|
||||||
|
// 验证数据被透传到客户端
|
||||||
|
body := rec.Body.String()
|
||||||
|
require.Contains(t, body, "Hello")
|
||||||
|
require.Contains(t, body, "world")
|
||||||
|
// 不应包含错误事件
|
||||||
|
require.NotContains(t, body, "event: error")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleClaudeStreamingResponse_NormalComplete
|
||||||
|
// 验证:正常 Claude 流式转发(Gemini→Claude 转换),数据正确转换并输出
|
||||||
|
func TestHandleClaudeStreamingResponse_NormalComplete(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := newAntigravityTestService(&config.Config{
|
||||||
|
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||||
|
})
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() { _ = pw.Close() }()
|
||||||
|
// v1internal 包装格式:Gemini 数据嵌套在 "response" 字段下
|
||||||
|
// ProcessLine 先尝试反序列化为 V1InternalResponse,裸格式会导致 Response.UsageMetadata 为空
|
||||||
|
fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"Hi there"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":3}}}`)
|
||||||
|
fmt.Fprintln(pw, "")
|
||||||
|
}()
|
||||||
|
|
||||||
|
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
||||||
|
_ = pr.Close()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect")
|
||||||
|
require.NotNil(t, result.usage)
|
||||||
|
// Gemini→Claude 转换的 usage:promptTokenCount=5→InputTokens=5, candidatesTokenCount=3→OutputTokens=3
|
||||||
|
require.Equal(t, 5, result.usage.InputTokens)
|
||||||
|
require.Equal(t, 3, result.usage.OutputTokens)
|
||||||
|
require.NotNil(t, result.firstTokenMs, "should record first token time")
|
||||||
|
|
||||||
|
// 验证输出是 Claude SSE 格式(processor 会转换)
|
||||||
|
body := rec.Body.String()
|
||||||
|
require.Contains(t, body, "event: message_start", "should contain Claude message_start event")
|
||||||
|
require.Contains(t, body, "event: message_stop", "should contain Claude message_stop event")
|
||||||
|
// 不应包含错误事件
|
||||||
|
require.NotContains(t, body, "event: error")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- 流式客户端断开检测测试 ---
|
||||||
|
|
||||||
|
// TestStreamUpstreamResponse_ClientDisconnectDrainsUsage
|
||||||
|
// 验证:客户端写入失败后,streamUpstreamResponse 继续读取上游以收集 usage
|
||||||
|
func TestStreamUpstreamResponse_ClientDisconnectDrainsUsage(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := newAntigravityTestService(&config.Config{
|
||||||
|
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||||
|
})
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() { _ = pw.Close() }()
|
||||||
|
fmt.Fprintln(pw, `event: message_start`)
|
||||||
|
fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":10}}}`)
|
||||||
|
fmt.Fprintln(pw, "")
|
||||||
|
fmt.Fprintln(pw, `event: message_delta`)
|
||||||
|
fmt.Fprintln(pw, `data: {"type":"message_delta","usage":{"output_tokens":20}}`)
|
||||||
|
fmt.Fprintln(pw, "")
|
||||||
|
}()
|
||||||
|
|
||||||
|
result := svc.streamUpstreamResponse(c, resp, time.Now())
|
||||||
|
_ = pr.Close()
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.True(t, result.clientDisconnect)
|
||||||
|
require.NotNil(t, result.usage)
|
||||||
|
require.Equal(t, 20, result.usage.OutputTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStreamUpstreamResponse_ContextCanceled
|
||||||
|
// 验证:context 取消时返回 usage 且标记 clientDisconnect
|
||||||
|
func TestStreamUpstreamResponse_ContextCanceled(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := newAntigravityTestService(&config.Config{
|
||||||
|
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||||
|
})
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
|
||||||
|
|
||||||
|
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
|
||||||
|
|
||||||
|
result := svc.streamUpstreamResponse(c, resp, time.Now())
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.True(t, result.clientDisconnect)
|
||||||
|
require.NotContains(t, rec.Body.String(), "event: error")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStreamUpstreamResponse_Timeout
|
||||||
|
// 验证:上游超时时返回已收集的 usage
|
||||||
|
func TestStreamUpstreamResponse_Timeout(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := newAntigravityTestService(&config.Config{
|
||||||
|
Gateway: config.GatewayConfig{StreamDataIntervalTimeout: 1, MaxLineSize: defaultMaxLineSize},
|
||||||
|
})
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||||
|
|
||||||
|
result := svc.streamUpstreamResponse(c, resp, time.Now())
|
||||||
|
_ = pw.Close()
|
||||||
|
_ = pr.Close()
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.False(t, result.clientDisconnect)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStreamUpstreamResponse_TimeoutAfterClientDisconnect
|
||||||
|
// 验证:客户端断开后上游超时,返回 usage 并标记 clientDisconnect
|
||||||
|
func TestStreamUpstreamResponse_TimeoutAfterClientDisconnect(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := newAntigravityTestService(&config.Config{
|
||||||
|
Gateway: config.GatewayConfig{StreamDataIntervalTimeout: 1, MaxLineSize: defaultMaxLineSize},
|
||||||
|
})
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":5}}}`)
|
||||||
|
fmt.Fprintln(pw, "")
|
||||||
|
// 不关闭 pw → 等待超时
|
||||||
|
}()
|
||||||
|
|
||||||
|
result := svc.streamUpstreamResponse(c, resp, time.Now())
|
||||||
|
_ = pw.Close()
|
||||||
|
_ = pr.Close()
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.True(t, result.clientDisconnect)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleGeminiStreamingResponse_ClientDisconnect
|
||||||
|
// 验证:Gemini 流式转发中客户端断开后继续 drain 上游
|
||||||
|
func TestHandleGeminiStreamingResponse_ClientDisconnect(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := newAntigravityTestService(&config.Config{
|
||||||
|
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||||
|
})
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() { _ = pw.Close() }()
|
||||||
|
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"hi"}]}}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":10}}`)
|
||||||
|
fmt.Fprintln(pw, "")
|
||||||
|
}()
|
||||||
|
|
||||||
|
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
|
||||||
|
_ = pr.Close()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.True(t, result.clientDisconnect)
|
||||||
|
require.NotContains(t, rec.Body.String(), "write_failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleGeminiStreamingResponse_ContextCanceled
|
||||||
|
// 验证:context 取消时不注入错误事件
|
||||||
|
func TestHandleGeminiStreamingResponse_ContextCanceled(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := newAntigravityTestService(&config.Config{
|
||||||
|
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||||
|
})
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
|
||||||
|
|
||||||
|
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
|
||||||
|
|
||||||
|
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.True(t, result.clientDisconnect)
|
||||||
|
require.NotContains(t, rec.Body.String(), "event: error")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleClaudeStreamingResponse_ClientDisconnect
|
||||||
|
// 验证:Claude 流式转发中客户端断开后继续 drain 上游
|
||||||
|
func TestHandleClaudeStreamingResponse_ClientDisconnect(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := newAntigravityTestService(&config.Config{
|
||||||
|
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||||
|
})
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() { _ = pw.Close() }()
|
||||||
|
// v1internal 包装格式
|
||||||
|
fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"hello"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":8,"candidatesTokenCount":15}}}`)
|
||||||
|
fmt.Fprintln(pw, "")
|
||||||
|
}()
|
||||||
|
|
||||||
|
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
||||||
|
_ = pr.Close()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.True(t, result.clientDisconnect)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleClaudeStreamingResponse_ContextCanceled
|
||||||
|
// 验证:context 取消时不注入错误事件
|
||||||
|
func TestHandleClaudeStreamingResponse_ContextCanceled(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := newAntigravityTestService(&config.Config{
|
||||||
|
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||||
|
})
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
|
||||||
|
|
||||||
|
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
|
||||||
|
|
||||||
|
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.True(t, result.clientDisconnect)
|
||||||
|
require.NotContains(t, rec.Body.String(), "event: error")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestExtractSSEUsage 验证 extractSSEUsage 从 SSE data 行正确提取 usage
|
||||||
|
func TestExtractSSEUsage(t *testing.T) {
|
||||||
|
svc := &AntigravityGatewayService{}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
line string
|
||||||
|
expected ClaudeUsage
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "message_delta with output_tokens",
|
||||||
|
line: `data: {"type":"message_delta","usage":{"output_tokens":42}}`,
|
||||||
|
expected: ClaudeUsage{OutputTokens: 42},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-data line ignored",
|
||||||
|
line: `event: message_start`,
|
||||||
|
expected: ClaudeUsage{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "top-level usage with all fields",
|
||||||
|
line: `data: {"usage":{"input_tokens":10,"output_tokens":20,"cache_read_input_tokens":5,"cache_creation_input_tokens":3}}`,
|
||||||
|
expected: ClaudeUsage{InputTokens: 10, OutputTokens: 20, CacheReadInputTokens: 5, CacheCreationInputTokens: 3},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
usage := &ClaudeUsage{}
|
||||||
|
svc.extractSSEUsage(tt.line, usage)
|
||||||
|
require.Equal(t, tt.expected, *usage)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAntigravityClientWriter 验证 antigravityClientWriter 的断开检测
|
||||||
|
func TestAntigravityClientWriter(t *testing.T) {
|
||||||
|
t.Run("normal write succeeds", func(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
flusher, _ := c.Writer.(http.Flusher)
|
||||||
|
cw := newAntigravityClientWriter(c.Writer, flusher, "test")
|
||||||
|
|
||||||
|
ok := cw.Write([]byte("hello"))
|
||||||
|
require.True(t, ok)
|
||||||
|
require.False(t, cw.Disconnected())
|
||||||
|
require.Contains(t, rec.Body.String(), "hello")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("write failure marks disconnected", func(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
fw := &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||||
|
flusher, _ := c.Writer.(http.Flusher)
|
||||||
|
cw := newAntigravityClientWriter(fw, flusher, "test")
|
||||||
|
|
||||||
|
ok := cw.Write([]byte("hello"))
|
||||||
|
require.False(t, ok)
|
||||||
|
require.True(t, cw.Disconnected())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("subsequent writes are no-op", func(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
fw := &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||||
|
flusher, _ := c.Writer.(http.Flusher)
|
||||||
|
cw := newAntigravityClientWriter(fw, flusher, "test")
|
||||||
|
|
||||||
|
cw.Write([]byte("first"))
|
||||||
|
ok := cw.Fprintf("second %d", 2)
|
||||||
|
require.False(t, ok)
|
||||||
|
require.True(t, cw.Disconnected())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,63 +2,23 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"slices"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const antigravityQuotaScopesKey = "antigravity_quota_scopes"
|
|
||||||
|
|
||||||
// AntigravityQuotaScope 表示 Antigravity 的配额域
|
|
||||||
type AntigravityQuotaScope string
|
|
||||||
|
|
||||||
const (
|
|
||||||
AntigravityQuotaScopeClaude AntigravityQuotaScope = "claude"
|
|
||||||
AntigravityQuotaScopeGeminiText AntigravityQuotaScope = "gemini_text"
|
|
||||||
AntigravityQuotaScopeGeminiImage AntigravityQuotaScope = "gemini_image"
|
|
||||||
)
|
|
||||||
|
|
||||||
// IsScopeSupported 检查给定的 scope 是否在分组支持的 scope 列表中
|
|
||||||
func IsScopeSupported(supportedScopes []string, scope AntigravityQuotaScope) bool {
|
|
||||||
if len(supportedScopes) == 0 {
|
|
||||||
// 未配置时默认全部支持
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
supported := slices.Contains(supportedScopes, string(scope))
|
|
||||||
return supported
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResolveAntigravityQuotaScope 根据模型名称解析配额域(导出版本)
|
|
||||||
func ResolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) {
|
|
||||||
return resolveAntigravityQuotaScope(requestedModel)
|
|
||||||
}
|
|
||||||
|
|
||||||
// resolveAntigravityQuotaScope 根据模型名称解析配额域
|
|
||||||
func resolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) {
|
|
||||||
model := normalizeAntigravityModelName(requestedModel)
|
|
||||||
if model == "" {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
switch {
|
|
||||||
case strings.HasPrefix(model, "claude-"):
|
|
||||||
return AntigravityQuotaScopeClaude, true
|
|
||||||
case strings.HasPrefix(model, "gemini-"):
|
|
||||||
if isImageGenerationModel(model) {
|
|
||||||
return AntigravityQuotaScopeGeminiImage, true
|
|
||||||
}
|
|
||||||
return AntigravityQuotaScopeGeminiText, true
|
|
||||||
default:
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func normalizeAntigravityModelName(model string) string {
|
func normalizeAntigravityModelName(model string) string {
|
||||||
normalized := strings.ToLower(strings.TrimSpace(model))
|
normalized := strings.ToLower(strings.TrimSpace(model))
|
||||||
normalized = strings.TrimPrefix(normalized, "models/")
|
normalized = strings.TrimPrefix(normalized, "models/")
|
||||||
return normalized
|
return normalized
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsSchedulableForModel 结合 Antigravity 配额域限流判断是否可调度。
|
// resolveAntigravityModelKey 根据请求的模型名解析限流 key
|
||||||
|
// 返回空字符串表示无法解析
|
||||||
|
func resolveAntigravityModelKey(requestedModel string) string {
|
||||||
|
return normalizeAntigravityModelName(requestedModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsSchedulableForModel 结合模型级限流判断是否可调度。
|
||||||
// 保持旧签名以兼容既有调用方;默认使用 context.Background()。
|
// 保持旧签名以兼容既有调用方;默认使用 context.Background()。
|
||||||
func (a *Account) IsSchedulableForModel(requestedModel string) bool {
|
func (a *Account) IsSchedulableForModel(requestedModel string) bool {
|
||||||
return a.IsSchedulableForModelWithContext(context.Background(), requestedModel)
|
return a.IsSchedulableForModelWithContext(context.Background(), requestedModel)
|
||||||
@@ -74,107 +34,20 @@ func (a *Account) IsSchedulableForModelWithContext(ctx context.Context, requeste
|
|||||||
if a.isModelRateLimitedWithContext(ctx, requestedModel) {
|
if a.isModelRateLimitedWithContext(ctx, requestedModel) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if a.Platform != PlatformAntigravity {
|
return true
|
||||||
return true
|
|
||||||
}
|
|
||||||
scope, ok := resolveAntigravityQuotaScope(requestedModel)
|
|
||||||
if !ok {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
resetAt := a.antigravityQuotaScopeResetAt(scope)
|
|
||||||
if resetAt == nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
now := time.Now()
|
|
||||||
return !now.Before(*resetAt)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Account) antigravityQuotaScopeResetAt(scope AntigravityQuotaScope) *time.Time {
|
// GetRateLimitRemainingTime 获取限流剩余时间(模型级限流)
|
||||||
if a == nil || a.Extra == nil || scope == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
rawScopes, ok := a.Extra[antigravityQuotaScopesKey].(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
rawScope, ok := rawScopes[string(scope)].(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
resetAtRaw, ok := rawScope["rate_limit_reset_at"].(string)
|
|
||||||
if !ok || strings.TrimSpace(resetAtRaw) == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
resetAt, err := time.Parse(time.RFC3339, resetAtRaw)
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return &resetAt
|
|
||||||
}
|
|
||||||
|
|
||||||
var antigravityAllScopes = []AntigravityQuotaScope{
|
|
||||||
AntigravityQuotaScopeClaude,
|
|
||||||
AntigravityQuotaScopeGeminiText,
|
|
||||||
AntigravityQuotaScopeGeminiImage,
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Account) GetAntigravityScopeRateLimits() map[string]int64 {
|
|
||||||
if a == nil || a.Platform != PlatformAntigravity {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
now := time.Now()
|
|
||||||
result := make(map[string]int64)
|
|
||||||
for _, scope := range antigravityAllScopes {
|
|
||||||
resetAt := a.antigravityQuotaScopeResetAt(scope)
|
|
||||||
if resetAt != nil && now.Before(*resetAt) {
|
|
||||||
remainingSec := int64(time.Until(*resetAt).Seconds())
|
|
||||||
if remainingSec > 0 {
|
|
||||||
result[string(scope)] = remainingSec
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(result) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetQuotaScopeRateLimitRemainingTime 获取模型域限流剩余时间
|
|
||||||
// 返回 0 表示未限流或已过期
|
|
||||||
func (a *Account) GetQuotaScopeRateLimitRemainingTime(requestedModel string) time.Duration {
|
|
||||||
if a == nil || a.Platform != PlatformAntigravity {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
scope, ok := resolveAntigravityQuotaScope(requestedModel)
|
|
||||||
if !ok {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
resetAt := a.antigravityQuotaScopeResetAt(scope)
|
|
||||||
if resetAt == nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
if remaining := time.Until(*resetAt); remaining > 0 {
|
|
||||||
return remaining
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetRateLimitRemainingTime 获取限流剩余时间(模型限流和模型域限流取最大值)
|
|
||||||
// 返回 0 表示未限流或已过期
|
// 返回 0 表示未限流或已过期
|
||||||
func (a *Account) GetRateLimitRemainingTime(requestedModel string) time.Duration {
|
func (a *Account) GetRateLimitRemainingTime(requestedModel string) time.Duration {
|
||||||
return a.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel)
|
return a.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRateLimitRemainingTimeWithContext 获取限流剩余时间(模型限流和模型域限流取最大值)
|
// GetRateLimitRemainingTimeWithContext 获取限流剩余时间(模型级限流)
|
||||||
// 返回 0 表示未限流或已过期
|
// 返回 0 表示未限流或已过期
|
||||||
func (a *Account) GetRateLimitRemainingTimeWithContext(ctx context.Context, requestedModel string) time.Duration {
|
func (a *Account) GetRateLimitRemainingTimeWithContext(ctx context.Context, requestedModel string) time.Duration {
|
||||||
if a == nil {
|
if a == nil {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
modelRemaining := a.GetModelRateLimitRemainingTimeWithContext(ctx, requestedModel)
|
return a.GetModelRateLimitRemainingTimeWithContext(ctx, requestedModel)
|
||||||
scopeRemaining := a.GetQuotaScopeRateLimitRemainingTime(requestedModel)
|
|
||||||
if modelRemaining > scopeRemaining {
|
|
||||||
return modelRemaining
|
|
||||||
}
|
|
||||||
return scopeRemaining
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -59,12 +59,6 @@ func (s *stubAntigravityUpstream) DoWithTLS(req *http.Request, proxyURL string,
|
|||||||
return s.Do(req, proxyURL, accountID, accountConcurrency)
|
return s.Do(req, proxyURL, accountID, accountConcurrency)
|
||||||
}
|
}
|
||||||
|
|
||||||
type scopeLimitCall struct {
|
|
||||||
accountID int64
|
|
||||||
scope AntigravityQuotaScope
|
|
||||||
resetAt time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
type rateLimitCall struct {
|
type rateLimitCall struct {
|
||||||
accountID int64
|
accountID int64
|
||||||
resetAt time.Time
|
resetAt time.Time
|
||||||
@@ -78,16 +72,10 @@ type modelRateLimitCall struct {
|
|||||||
|
|
||||||
type stubAntigravityAccountRepo struct {
|
type stubAntigravityAccountRepo struct {
|
||||||
AccountRepository
|
AccountRepository
|
||||||
scopeCalls []scopeLimitCall
|
|
||||||
rateCalls []rateLimitCall
|
rateCalls []rateLimitCall
|
||||||
modelRateLimitCalls []modelRateLimitCall
|
modelRateLimitCalls []modelRateLimitCall
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stubAntigravityAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
|
|
||||||
s.scopeCalls = append(s.scopeCalls, scopeLimitCall{accountID: id, scope: scope, resetAt: resetAt})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||||
s.rateCalls = append(s.rateCalls, rateLimitCall{accountID: id, resetAt: resetAt})
|
s.rateCalls = append(s.rateCalls, rateLimitCall{accountID: id, resetAt: resetAt})
|
||||||
return nil
|
return nil
|
||||||
@@ -131,10 +119,9 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
|
|||||||
accessToken: "token",
|
accessToken: "token",
|
||||||
action: "generateContent",
|
action: "generateContent",
|
||||||
body: []byte(`{"input":"test"}`),
|
body: []byte(`{"input":"test"}`),
|
||||||
quotaScope: AntigravityQuotaScopeClaude,
|
|
||||||
httpUpstream: upstream,
|
httpUpstream: upstream,
|
||||||
requestedModel: "claude-sonnet-4-5",
|
requestedModel: "claude-sonnet-4-5",
|
||||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
handleErrorCalled = true
|
handleErrorCalled = true
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
@@ -155,23 +142,6 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
|
|||||||
require.Equal(t, base2, available[0])
|
require.Equal(t, base2, available[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAntigravityHandleUpstreamError_UsesScopeLimit(t *testing.T) {
|
|
||||||
// 分区限流始终开启,不再支持通过环境变量关闭
|
|
||||||
repo := &stubAntigravityAccountRepo{}
|
|
||||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
|
||||||
account := &Account{ID: 9, Name: "acc-9", Platform: PlatformAntigravity}
|
|
||||||
|
|
||||||
body := buildGeminiRateLimitBody("3s")
|
|
||||||
svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false)
|
|
||||||
|
|
||||||
require.Len(t, repo.scopeCalls, 1)
|
|
||||||
require.Empty(t, repo.rateCalls)
|
|
||||||
call := repo.scopeCalls[0]
|
|
||||||
require.Equal(t, account.ID, call.accountID)
|
|
||||||
require.Equal(t, AntigravityQuotaScopeClaude, call.scope)
|
|
||||||
require.WithinDuration(t, time.Now().Add(3*time.Second), call.resetAt, 2*time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestHandleUpstreamError_429_ModelRateLimit 测试 429 模型限流场景
|
// TestHandleUpstreamError_429_ModelRateLimit 测试 429 模型限流场景
|
||||||
func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) {
|
func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) {
|
||||||
repo := &stubAntigravityAccountRepo{}
|
repo := &stubAntigravityAccountRepo{}
|
||||||
@@ -189,7 +159,7 @@ func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}`)
|
}`)
|
||||||
|
|
||||||
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false)
|
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, "claude-sonnet-4-5", 0, "", false)
|
||||||
|
|
||||||
// 应该触发模型限流
|
// 应该触发模型限流
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
@@ -200,22 +170,22 @@ func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) {
|
|||||||
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
|
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestHandleUpstreamError_429_NonModelRateLimit 测试 429 非模型限流场景(走 scope 限流)
|
// TestHandleUpstreamError_429_NonModelRateLimit 测试 429 非模型限流场景(走模型级限流兜底)
|
||||||
func TestHandleUpstreamError_429_NonModelRateLimit(t *testing.T) {
|
func TestHandleUpstreamError_429_NonModelRateLimit(t *testing.T) {
|
||||||
repo := &stubAntigravityAccountRepo{}
|
repo := &stubAntigravityAccountRepo{}
|
||||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||||
account := &Account{ID: 2, Name: "acc-2", Platform: PlatformAntigravity}
|
account := &Account{ID: 2, Name: "acc-2", Platform: PlatformAntigravity}
|
||||||
|
|
||||||
// 429 + 普通限流响应(无 RATE_LIMIT_EXCEEDED reason)→ scope 限流
|
// 429 + 普通限流响应(无 RATE_LIMIT_EXCEEDED reason)→ 走模型级限流兜底
|
||||||
body := buildGeminiRateLimitBody("5s")
|
body := buildGeminiRateLimitBody("5s")
|
||||||
|
|
||||||
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false)
|
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, "claude-sonnet-4-5", 0, "", false)
|
||||||
|
|
||||||
// 不应该触发模型限流,应该走 scope 限流
|
// handleModelRateLimit 不会处理(因为没有 RATE_LIMIT_EXCEEDED),
|
||||||
|
// 但 429 兜底逻辑会使用 requestedModel 设置模型级限流
|
||||||
require.Nil(t, result)
|
require.Nil(t, result)
|
||||||
require.Empty(t, repo.modelRateLimitCalls)
|
require.Len(t, repo.modelRateLimitCalls, 1)
|
||||||
require.Len(t, repo.scopeCalls, 1)
|
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
|
||||||
require.Equal(t, AntigravityQuotaScopeClaude, repo.scopeCalls[0].scope)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestHandleUpstreamError_503_ModelRateLimit 测试 503 模型限流场景
|
// TestHandleUpstreamError_503_ModelRateLimit 测试 503 模型限流场景
|
||||||
@@ -235,7 +205,7 @@ func TestHandleUpstreamError_503_ModelRateLimit(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}`)
|
}`)
|
||||||
|
|
||||||
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false)
|
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false)
|
||||||
|
|
||||||
// 应该触发模型限流
|
// 应该触发模型限流
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
@@ -263,12 +233,11 @@ func TestHandleUpstreamError_503_NonModelRateLimit(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}`)
|
}`)
|
||||||
|
|
||||||
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false)
|
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false)
|
||||||
|
|
||||||
// 503 非模型限流不应该做任何处理
|
// 503 非模型限流不应该做任何处理
|
||||||
require.Nil(t, result)
|
require.Nil(t, result)
|
||||||
require.Empty(t, repo.modelRateLimitCalls, "503 non-model rate limit should not trigger model rate limit")
|
require.Empty(t, repo.modelRateLimitCalls, "503 non-model rate limit should not trigger model rate limit")
|
||||||
require.Empty(t, repo.scopeCalls, "503 non-model rate limit should not trigger scope rate limit")
|
|
||||||
require.Empty(t, repo.rateCalls, "503 non-model rate limit should not trigger account rate limit")
|
require.Empty(t, repo.rateCalls, "503 non-model rate limit should not trigger account rate limit")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -281,12 +250,11 @@ func TestHandleUpstreamError_503_EmptyBody(t *testing.T) {
|
|||||||
// 503 + 空响应体 → 不做任何处理
|
// 503 + 空响应体 → 不做任何处理
|
||||||
body := []byte(`{}`)
|
body := []byte(`{}`)
|
||||||
|
|
||||||
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false)
|
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false)
|
||||||
|
|
||||||
// 503 空响应不应该做任何处理
|
// 503 空响应不应该做任何处理
|
||||||
require.Nil(t, result)
|
require.Nil(t, result)
|
||||||
require.Empty(t, repo.modelRateLimitCalls)
|
require.Empty(t, repo.modelRateLimitCalls)
|
||||||
require.Empty(t, repo.scopeCalls)
|
|
||||||
require.Empty(t, repo.rateCalls)
|
require.Empty(t, repo.rateCalls)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -307,15 +275,7 @@ func TestAccountIsSchedulableForModel_AntigravityRateLimits(t *testing.T) {
|
|||||||
require.False(t, account.IsSchedulableForModel("gemini-3-flash"))
|
require.False(t, account.IsSchedulableForModel("gemini-3-flash"))
|
||||||
|
|
||||||
account.RateLimitResetAt = nil
|
account.RateLimitResetAt = nil
|
||||||
account.Extra = map[string]any{
|
require.True(t, account.IsSchedulableForModel("claude-sonnet-4-5"))
|
||||||
antigravityQuotaScopesKey: map[string]any{
|
|
||||||
"claude": map[string]any{
|
|
||||||
"rate_limit_reset_at": future.Format(time.RFC3339),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
require.False(t, account.IsSchedulableForModel("claude-sonnet-4-5"))
|
|
||||||
require.True(t, account.IsSchedulableForModel("gemini-3-flash"))
|
require.True(t, account.IsSchedulableForModel("gemini-3-flash"))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -635,6 +595,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
|
|||||||
}`,
|
}`,
|
||||||
expectedShouldRetry: false,
|
expectedShouldRetry: false,
|
||||||
expectedShouldRateLimit: true,
|
expectedShouldRateLimit: true,
|
||||||
|
minWait: 7 * time.Second,
|
||||||
modelName: "gemini-pro",
|
modelName: "gemini-pro",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -652,6 +613,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
|
|||||||
}`,
|
}`,
|
||||||
expectedShouldRetry: false,
|
expectedShouldRetry: false,
|
||||||
expectedShouldRateLimit: true,
|
expectedShouldRateLimit: true,
|
||||||
|
minWait: 39 * time.Second,
|
||||||
modelName: "gemini-3-pro-high",
|
modelName: "gemini-3-pro-high",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -669,6 +631,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
|
|||||||
}`,
|
}`,
|
||||||
expectedShouldRetry: false,
|
expectedShouldRetry: false,
|
||||||
expectedShouldRateLimit: true,
|
expectedShouldRateLimit: true,
|
||||||
|
minWait: 30 * time.Second,
|
||||||
modelName: "gemini-2.5-flash",
|
modelName: "gemini-2.5-flash",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -686,6 +649,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
|
|||||||
}`,
|
}`,
|
||||||
expectedShouldRetry: false,
|
expectedShouldRetry: false,
|
||||||
expectedShouldRateLimit: true,
|
expectedShouldRateLimit: true,
|
||||||
|
minWait: 30 * time.Second,
|
||||||
modelName: "claude-sonnet-4-5",
|
modelName: "claude-sonnet-4-5",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -704,6 +668,11 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
|
|||||||
t.Errorf("wait = %v, want >= %v", wait, tt.minWait)
|
t.Errorf("wait = %v, want >= %v", wait, tt.minWait)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if shouldRateLimit && tt.minWait > 0 {
|
||||||
|
if wait < tt.minWait {
|
||||||
|
t.Errorf("rate limit wait = %v, want >= %v", wait, tt.minWait)
|
||||||
|
}
|
||||||
|
}
|
||||||
if (shouldRetry || shouldRateLimit) && model != tt.modelName {
|
if (shouldRetry || shouldRateLimit) && model != tt.modelName {
|
||||||
t.Errorf("modelName = %q, want %q", model, tt.modelName)
|
t.Errorf("modelName = %q, want %q", model, tt.modelName)
|
||||||
}
|
}
|
||||||
@@ -832,7 +801,7 @@ func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRateLimited(t *testing.T) {
|
|||||||
requestedModel: "claude-sonnet-4-5",
|
requestedModel: "claude-sonnet-4-5",
|
||||||
httpUpstream: upstream,
|
httpUpstream: upstream,
|
||||||
isStickySession: true,
|
isStickySession: true,
|
||||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -875,7 +844,7 @@ func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemainingLong(t *testing.T) {
|
|||||||
requestedModel: "claude-sonnet-4-5",
|
requestedModel: "claude-sonnet-4-5",
|
||||||
httpUpstream: upstream,
|
httpUpstream: upstream,
|
||||||
isStickySession: true,
|
isStickySession: true,
|
||||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ func TestHandleSmartRetry_URLLevelRateLimit(t *testing.T) {
|
|||||||
accessToken: "token",
|
accessToken: "token",
|
||||||
action: "generateContent",
|
action: "generateContent",
|
||||||
body: []byte(`{"input":"test"}`),
|
body: []byte(`{"input":"test"}`),
|
||||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -127,7 +127,7 @@ func TestHandleSmartRetry_LongDelay_ReturnsSwitchError(t *testing.T) {
|
|||||||
body: []byte(`{"input":"test"}`),
|
body: []byte(`{"input":"test"}`),
|
||||||
accountRepo: repo,
|
accountRepo: repo,
|
||||||
isStickySession: true,
|
isStickySession: true,
|
||||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -194,7 +194,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetrySuccess(t *testing.T) {
|
|||||||
action: "generateContent",
|
action: "generateContent",
|
||||||
body: []byte(`{"input":"test"}`),
|
body: []byte(`{"input":"test"}`),
|
||||||
httpUpstream: upstream,
|
httpUpstream: upstream,
|
||||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -269,7 +269,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test
|
|||||||
httpUpstream: upstream,
|
httpUpstream: upstream,
|
||||||
accountRepo: repo,
|
accountRepo: repo,
|
||||||
isStickySession: false,
|
isStickySession: false,
|
||||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -331,7 +331,7 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi
|
|||||||
body: []byte(`{"input":"test"}`),
|
body: []byte(`{"input":"test"}`),
|
||||||
accountRepo: repo,
|
accountRepo: repo,
|
||||||
isStickySession: true,
|
isStickySession: true,
|
||||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -387,7 +387,7 @@ func TestHandleSmartRetry_NonAntigravityAccount_ContinuesDefaultLogic(t *testing
|
|||||||
accessToken: "token",
|
accessToken: "token",
|
||||||
action: "generateContent",
|
action: "generateContent",
|
||||||
body: []byte(`{"input":"test"}`),
|
body: []byte(`{"input":"test"}`),
|
||||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -436,7 +436,7 @@ func TestHandleSmartRetry_NonModelRateLimit_ContinuesDefaultLogic(t *testing.T)
|
|||||||
accessToken: "token",
|
accessToken: "token",
|
||||||
action: "generateContent",
|
action: "generateContent",
|
||||||
body: []byte(`{"input":"test"}`),
|
body: []byte(`{"input":"test"}`),
|
||||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -487,7 +487,7 @@ func TestHandleSmartRetry_ExactlyAtThreshold_ReturnsSwitchError(t *testing.T) {
|
|||||||
action: "generateContent",
|
action: "generateContent",
|
||||||
body: []byte(`{"input":"test"}`),
|
body: []byte(`{"input":"test"}`),
|
||||||
accountRepo: repo,
|
accountRepo: repo,
|
||||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -548,7 +548,7 @@ func TestAntigravityRetryLoop_HandleSmartRetry_SwitchError_Propagates(t *testing
|
|||||||
httpUpstream: upstream,
|
httpUpstream: upstream,
|
||||||
accountRepo: repo,
|
accountRepo: repo,
|
||||||
isStickySession: true,
|
isStickySession: true,
|
||||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -604,7 +604,7 @@ func TestHandleSmartRetry_NetworkError_ExhaustsRetry(t *testing.T) {
|
|||||||
body: []byte(`{"input":"test"}`),
|
body: []byte(`{"input":"test"}`),
|
||||||
httpUpstream: upstream,
|
httpUpstream: upstream,
|
||||||
accountRepo: repo,
|
accountRepo: repo,
|
||||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -662,7 +662,7 @@ func TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit(t *testing.T) {
|
|||||||
body: []byte(`{"input":"test"}`),
|
body: []byte(`{"input":"test"}`),
|
||||||
accountRepo: repo,
|
accountRepo: repo,
|
||||||
isStickySession: true,
|
isStickySession: true,
|
||||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -754,7 +754,7 @@ func TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_ClearsSession(t *
|
|||||||
isStickySession: true,
|
isStickySession: true,
|
||||||
groupID: 42,
|
groupID: 42,
|
||||||
sessionHash: "sticky-hash-abc",
|
sessionHash: "sticky-hash-abc",
|
||||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -842,7 +842,7 @@ func TestHandleSmartRetry_ShortDelay_NonStickySession_FailedRetry_NoDeleteSessio
|
|||||||
isStickySession: false,
|
isStickySession: false,
|
||||||
groupID: 42,
|
groupID: 42,
|
||||||
sessionHash: "", // 非粘性会话,sessionHash 为空
|
sessionHash: "", // 非粘性会话,sessionHash 为空
|
||||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -918,7 +918,7 @@ func TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_NilCache_NoPanic(
|
|||||||
isStickySession: true,
|
isStickySession: true,
|
||||||
groupID: 42,
|
groupID: 42,
|
||||||
sessionHash: "sticky-hash-nil-cache",
|
sessionHash: "sticky-hash-nil-cache",
|
||||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -983,7 +983,7 @@ func TestHandleSmartRetry_ShortDelay_StickySession_SuccessRetry_NoDeleteSession(
|
|||||||
isStickySession: true,
|
isStickySession: true,
|
||||||
groupID: 42,
|
groupID: 42,
|
||||||
sessionHash: "sticky-hash-success",
|
sessionHash: "sticky-hash-success",
|
||||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -1043,7 +1043,7 @@ func TestHandleSmartRetry_LongDelay_StickySession_NoDeleteInHandleSmartRetry(t *
|
|||||||
isStickySession: true,
|
isStickySession: true,
|
||||||
groupID: 42,
|
groupID: 42,
|
||||||
sessionHash: "sticky-hash-long-delay",
|
sessionHash: "sticky-hash-long-delay",
|
||||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -1108,7 +1108,7 @@ func TestHandleSmartRetry_ShortDelay_NetworkError_StickySession_ClearsSession(t
|
|||||||
isStickySession: true,
|
isStickySession: true,
|
||||||
groupID: 99,
|
groupID: 99,
|
||||||
sessionHash: "sticky-net-error",
|
sessionHash: "sticky-net-error",
|
||||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -1188,7 +1188,7 @@ func TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession
|
|||||||
isStickySession: true,
|
isStickySession: true,
|
||||||
groupID: 77,
|
groupID: 77,
|
||||||
sessionHash: "sticky-503-short",
|
sessionHash: "sticky-503-short",
|
||||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -1278,7 +1278,7 @@ func TestAntigravityRetryLoop_SmartRetryFailed_StickySession_SwitchErrorPropagat
|
|||||||
isStickySession: true,
|
isStickySession: true,
|
||||||
groupID: 55,
|
groupID: 55,
|
||||||
sessionHash: "sticky-loop-test",
|
sessionHash: "sticky-loop-test",
|
||||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -1296,4 +1296,4 @@ func TestAntigravityRetryLoop_SmartRetryFailed_StickySession_SwitchErrorPropagat
|
|||||||
require.Len(t, cache.deleteCalls, 1, "should clear sticky session in handleSmartRetry")
|
require.Len(t, cache.deleteCalls, 1, "should clear sticky session in handleSmartRetry")
|
||||||
require.Equal(t, int64(55), cache.deleteCalls[0].groupID)
|
require.Equal(t, int64(55), cache.deleteCalls[0].groupID)
|
||||||
require.Equal(t, "sticky-loop-test", cache.deleteCalls[0].sessionHash)
|
require.Equal(t, "sticky-loop-test", cache.deleteCalls[0].sessionHash)
|
||||||
}
|
}
|
||||||
|
|||||||
69
backend/internal/service/digest_session_store.go
Normal file
69
backend/internal/service/digest_session_store.go
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
gocache "github.com/patrickmn/go-cache"
|
||||||
|
)
|
||||||
|
|
||||||
|
// digestSessionTTL 摘要会话默认 TTL
|
||||||
|
const digestSessionTTL = 5 * time.Minute
|
||||||
|
|
||||||
|
// sessionEntry flat cache 条目
|
||||||
|
type sessionEntry struct {
|
||||||
|
uuid string
|
||||||
|
accountID int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// DigestSessionStore 内存摘要会话存储(flat cache 实现)
|
||||||
|
// key: "{groupID}:{prefixHash}|{digestChain}" → *sessionEntry
|
||||||
|
type DigestSessionStore struct {
|
||||||
|
cache *gocache.Cache
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDigestSessionStore 创建内存摘要会话存储
|
||||||
|
func NewDigestSessionStore() *DigestSessionStore {
|
||||||
|
return &DigestSessionStore{
|
||||||
|
cache: gocache.New(digestSessionTTL, time.Minute),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save 保存摘要会话。oldDigestChain 为 Find 返回的 matchedChain,用于删旧 key。
|
||||||
|
func (s *DigestSessionStore) Save(groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) {
|
||||||
|
if digestChain == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ns := buildNS(groupID, prefixHash)
|
||||||
|
s.cache.Set(ns+digestChain, &sessionEntry{uuid: uuid, accountID: accountID}, gocache.DefaultExpiration)
|
||||||
|
if oldDigestChain != "" && oldDigestChain != digestChain {
|
||||||
|
s.cache.Delete(ns + oldDigestChain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find 查找摘要会话,从完整 chain 逐段截断,返回最长匹配及对应 matchedChain。
|
||||||
|
func (s *DigestSessionStore) Find(groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) {
|
||||||
|
if digestChain == "" {
|
||||||
|
return "", 0, "", false
|
||||||
|
}
|
||||||
|
ns := buildNS(groupID, prefixHash)
|
||||||
|
chain := digestChain
|
||||||
|
for {
|
||||||
|
if val, ok := s.cache.Get(ns + chain); ok {
|
||||||
|
if e, ok := val.(*sessionEntry); ok {
|
||||||
|
return e.uuid, e.accountID, chain, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
i := strings.LastIndex(chain, "-")
|
||||||
|
if i < 0 {
|
||||||
|
return "", 0, "", false
|
||||||
|
}
|
||||||
|
chain = chain[:i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildNS 构建 namespace 前缀
|
||||||
|
func buildNS(groupID int64, prefixHash string) string {
|
||||||
|
return strconv.FormatInt(groupID, 10) + ":" + prefixHash + "|"
|
||||||
|
}
|
||||||
312
backend/internal/service/digest_session_store_test.go
Normal file
312
backend/internal/service/digest_session_store_test.go
Normal file
@@ -0,0 +1,312 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
gocache "github.com/patrickmn/go-cache"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDigestSessionStore_SaveAndFind(t *testing.T) {
|
||||||
|
store := NewDigestSessionStore()
|
||||||
|
|
||||||
|
store.Save(1, "prefix", "s:a1-u:b2-m:c3", "uuid-1", 100, "")
|
||||||
|
|
||||||
|
uuid, accountID, _, found := store.Find(1, "prefix", "s:a1-u:b2-m:c3")
|
||||||
|
require.True(t, found)
|
||||||
|
assert.Equal(t, "uuid-1", uuid)
|
||||||
|
assert.Equal(t, int64(100), accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDigestSessionStore_PrefixMatch(t *testing.T) {
|
||||||
|
store := NewDigestSessionStore()
|
||||||
|
|
||||||
|
// 保存短链
|
||||||
|
store.Save(1, "prefix", "u:a-m:b", "uuid-short", 10, "")
|
||||||
|
|
||||||
|
// 用长链查找,应前缀匹配到短链
|
||||||
|
uuid, accountID, matchedChain, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d")
|
||||||
|
require.True(t, found)
|
||||||
|
assert.Equal(t, "uuid-short", uuid)
|
||||||
|
assert.Equal(t, int64(10), accountID)
|
||||||
|
assert.Equal(t, "u:a-m:b", matchedChain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDigestSessionStore_LongestPrefixMatch(t *testing.T) {
|
||||||
|
store := NewDigestSessionStore()
|
||||||
|
|
||||||
|
store.Save(1, "prefix", "u:a", "uuid-1", 1, "")
|
||||||
|
store.Save(1, "prefix", "u:a-m:b", "uuid-2", 2, "")
|
||||||
|
store.Save(1, "prefix", "u:a-m:b-u:c", "uuid-3", 3, "")
|
||||||
|
|
||||||
|
// 应匹配最深的 "u:a-m:b-u:c"(从完整 chain 逐段截断,先命中最长的)
|
||||||
|
uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d-u:e")
|
||||||
|
require.True(t, found)
|
||||||
|
assert.Equal(t, "uuid-3", uuid)
|
||||||
|
assert.Equal(t, int64(3), accountID)
|
||||||
|
|
||||||
|
// 查找中等长度,应匹配到 "u:a-m:b"
|
||||||
|
uuid, accountID, _, found = store.Find(1, "prefix", "u:a-m:b-u:x")
|
||||||
|
require.True(t, found)
|
||||||
|
assert.Equal(t, "uuid-2", uuid)
|
||||||
|
assert.Equal(t, int64(2), accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDigestSessionStore_SaveDeletesOldChain(t *testing.T) {
|
||||||
|
store := NewDigestSessionStore()
|
||||||
|
|
||||||
|
// 第一轮:保存 "u:a-m:b"
|
||||||
|
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
|
||||||
|
|
||||||
|
// 第二轮:同一 uuid 保存更长的链,传入旧 chain
|
||||||
|
store.Save(1, "prefix", "u:a-m:b-u:c-m:d", "uuid-1", 100, "u:a-m:b")
|
||||||
|
|
||||||
|
// 旧链 "u:a-m:b" 应已被删除
|
||||||
|
_, _, _, found := store.Find(1, "prefix", "u:a-m:b")
|
||||||
|
assert.False(t, found, "old chain should be deleted")
|
||||||
|
|
||||||
|
// 新链应能找到
|
||||||
|
uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d")
|
||||||
|
require.True(t, found)
|
||||||
|
assert.Equal(t, "uuid-1", uuid)
|
||||||
|
assert.Equal(t, int64(100), accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDigestSessionStore_DifferentSessionsNoInterference(t *testing.T) {
|
||||||
|
store := NewDigestSessionStore()
|
||||||
|
|
||||||
|
// 相同系统提示词,不同用户提示词
|
||||||
|
store.Save(1, "prefix", "s:sys-u:user1", "uuid-1", 100, "")
|
||||||
|
store.Save(1, "prefix", "s:sys-u:user2", "uuid-2", 200, "")
|
||||||
|
|
||||||
|
uuid, accountID, _, found := store.Find(1, "prefix", "s:sys-u:user1-m:reply1")
|
||||||
|
require.True(t, found)
|
||||||
|
assert.Equal(t, "uuid-1", uuid)
|
||||||
|
assert.Equal(t, int64(100), accountID)
|
||||||
|
|
||||||
|
uuid, accountID, _, found = store.Find(1, "prefix", "s:sys-u:user2-m:reply2")
|
||||||
|
require.True(t, found)
|
||||||
|
assert.Equal(t, "uuid-2", uuid)
|
||||||
|
assert.Equal(t, int64(200), accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDigestSessionStore_NoMatch(t *testing.T) {
|
||||||
|
store := NewDigestSessionStore()
|
||||||
|
|
||||||
|
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
|
||||||
|
|
||||||
|
// 完全不同的 chain
|
||||||
|
_, _, _, found := store.Find(1, "prefix", "u:x-m:y")
|
||||||
|
assert.False(t, found)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDigestSessionStore_DifferentPrefixHash(t *testing.T) {
|
||||||
|
store := NewDigestSessionStore()
|
||||||
|
|
||||||
|
store.Save(1, "prefix1", "u:a-m:b", "uuid-1", 100, "")
|
||||||
|
|
||||||
|
// 不同 prefixHash 应隔离
|
||||||
|
_, _, _, found := store.Find(1, "prefix2", "u:a-m:b")
|
||||||
|
assert.False(t, found)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDigestSessionStore_DifferentGroupID(t *testing.T) {
|
||||||
|
store := NewDigestSessionStore()
|
||||||
|
|
||||||
|
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
|
||||||
|
|
||||||
|
// 不同 groupID 应隔离
|
||||||
|
_, _, _, found := store.Find(2, "prefix", "u:a-m:b")
|
||||||
|
assert.False(t, found)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDigestSessionStore_EmptyDigestChain(t *testing.T) {
|
||||||
|
store := NewDigestSessionStore()
|
||||||
|
|
||||||
|
// 空链不应保存
|
||||||
|
store.Save(1, "prefix", "", "uuid-1", 100, "")
|
||||||
|
_, _, _, found := store.Find(1, "prefix", "")
|
||||||
|
assert.False(t, found)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDigestSessionStore_TTLExpiration(t *testing.T) {
|
||||||
|
store := &DigestSessionStore{
|
||||||
|
cache: gocache.New(100*time.Millisecond, 50*time.Millisecond),
|
||||||
|
}
|
||||||
|
|
||||||
|
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
|
||||||
|
|
||||||
|
// 立即应该能找到
|
||||||
|
_, _, _, found := store.Find(1, "prefix", "u:a-m:b")
|
||||||
|
require.True(t, found)
|
||||||
|
|
||||||
|
// 等待过期 + 清理周期
|
||||||
|
time.Sleep(300 * time.Millisecond)
|
||||||
|
|
||||||
|
// 过期后应找不到
|
||||||
|
_, _, _, found = store.Find(1, "prefix", "u:a-m:b")
|
||||||
|
assert.False(t, found)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDigestSessionStore_ConcurrentSafety(t *testing.T) {
|
||||||
|
store := NewDigestSessionStore()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
const goroutines = 50
|
||||||
|
const operations = 100
|
||||||
|
|
||||||
|
wg.Add(goroutines)
|
||||||
|
for g := 0; g < goroutines; g++ {
|
||||||
|
go func(id int) {
|
||||||
|
defer wg.Done()
|
||||||
|
prefix := fmt.Sprintf("prefix-%d", id%5)
|
||||||
|
for i := 0; i < operations; i++ {
|
||||||
|
chain := fmt.Sprintf("u:%d-m:%d", id, i)
|
||||||
|
uuid := fmt.Sprintf("uuid-%d-%d", id, i)
|
||||||
|
store.Save(1, prefix, chain, uuid, int64(id), "")
|
||||||
|
store.Find(1, prefix, chain)
|
||||||
|
}
|
||||||
|
}(g)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDigestSessionStore_MultipleSessions(t *testing.T) {
|
||||||
|
store := NewDigestSessionStore()
|
||||||
|
|
||||||
|
sessions := []struct {
|
||||||
|
chain string
|
||||||
|
uuid string
|
||||||
|
accountID int64
|
||||||
|
}{
|
||||||
|
{"u:session1", "uuid-1", 1},
|
||||||
|
{"u:session2-m:reply2", "uuid-2", 2},
|
||||||
|
{"u:session3-m:reply3-u:msg3", "uuid-3", 3},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, sess := range sessions {
|
||||||
|
store.Save(1, "prefix", sess.chain, sess.uuid, sess.accountID, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证每个会话都能正确查找
|
||||||
|
for _, sess := range sessions {
|
||||||
|
uuid, accountID, _, found := store.Find(1, "prefix", sess.chain)
|
||||||
|
require.True(t, found, "should find session: %s", sess.chain)
|
||||||
|
assert.Equal(t, sess.uuid, uuid)
|
||||||
|
assert.Equal(t, sess.accountID, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证继续对话的场景
|
||||||
|
uuid, accountID, _, found := store.Find(1, "prefix", "u:session2-m:reply2-u:newmsg")
|
||||||
|
require.True(t, found)
|
||||||
|
assert.Equal(t, "uuid-2", uuid)
|
||||||
|
assert.Equal(t, int64(2), accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDigestSessionStore_Performance1000Sessions(t *testing.T) {
|
||||||
|
store := NewDigestSessionStore()
|
||||||
|
|
||||||
|
// 插入 1000 个会话
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
chain := fmt.Sprintf("s:sys-u:user%d-m:reply%d", i, i)
|
||||||
|
store.Save(1, "prefix", chain, fmt.Sprintf("uuid-%d", i), int64(i), "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 查找性能测试
|
||||||
|
start := time.Now()
|
||||||
|
const lookups = 10000
|
||||||
|
for i := 0; i < lookups; i++ {
|
||||||
|
idx := i % 1000
|
||||||
|
chain := fmt.Sprintf("s:sys-u:user%d-m:reply%d-u:newmsg", idx, idx)
|
||||||
|
_, _, _, found := store.Find(1, "prefix", chain)
|
||||||
|
assert.True(t, found)
|
||||||
|
}
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
t.Logf("%d lookups in %v (%.0f ns/op)", lookups, elapsed, float64(elapsed.Nanoseconds())/lookups)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDigestSessionStore_FindReturnsMatchedChain(t *testing.T) {
|
||||||
|
store := NewDigestSessionStore()
|
||||||
|
|
||||||
|
store.Save(1, "prefix", "u:a-m:b-u:c", "uuid-1", 100, "")
|
||||||
|
|
||||||
|
// 精确匹配
|
||||||
|
_, _, matchedChain, found := store.Find(1, "prefix", "u:a-m:b-u:c")
|
||||||
|
require.True(t, found)
|
||||||
|
assert.Equal(t, "u:a-m:b-u:c", matchedChain)
|
||||||
|
|
||||||
|
// 前缀匹配(截断后命中)
|
||||||
|
_, _, matchedChain, found = store.Find(1, "prefix", "u:a-m:b-u:c-m:d-u:e")
|
||||||
|
require.True(t, found)
|
||||||
|
assert.Equal(t, "u:a-m:b-u:c", matchedChain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDigestSessionStore_CacheItemCountStable(t *testing.T) {
|
||||||
|
store := NewDigestSessionStore()
|
||||||
|
|
||||||
|
// 模拟 100 个独立会话,每个进行 10 轮对话
|
||||||
|
// 正确传递 oldDigestChain 时,每个会话始终只保留 1 个 key
|
||||||
|
for conv := 0; conv < 100; conv++ {
|
||||||
|
var prevMatchedChain string
|
||||||
|
for round := 0; round < 10; round++ {
|
||||||
|
chain := fmt.Sprintf("s:sys-u:user%d", conv)
|
||||||
|
for r := 0; r < round; r++ {
|
||||||
|
chain += fmt.Sprintf("-m:a%d-u:q%d", r, r+1)
|
||||||
|
}
|
||||||
|
uuid := fmt.Sprintf("uuid-conv%d", conv)
|
||||||
|
|
||||||
|
_, _, matched, _ := store.Find(1, "prefix", chain)
|
||||||
|
store.Save(1, "prefix", chain, uuid, int64(conv), matched)
|
||||||
|
prevMatchedChain = matched
|
||||||
|
_ = prevMatchedChain
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 100 个会话 × 1 key/会话 = 应该 ≤ 100 个 key
|
||||||
|
// 允许少量并发残留,但绝不能接近 100×10=1000
|
||||||
|
itemCount := store.cache.ItemCount()
|
||||||
|
assert.LessOrEqual(t, itemCount, 100, "cache should have at most 100 items (1 per conversation), got %d", itemCount)
|
||||||
|
t.Logf("Cache item count after 100 conversations × 10 rounds: %d", itemCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDigestSessionStore_TTLPreventsUnboundedGrowth(t *testing.T) {
|
||||||
|
// 使用极短 TTL 验证大量写入后 cache 能被清理
|
||||||
|
store := &DigestSessionStore{
|
||||||
|
cache: gocache.New(100*time.Millisecond, 50*time.Millisecond),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 插入 500 个不同的 key(无 oldDigestChain,模拟最坏场景:全是新会话首轮)
|
||||||
|
for i := 0; i < 500; i++ {
|
||||||
|
chain := fmt.Sprintf("u:user%d", i)
|
||||||
|
store.Save(1, "prefix", chain, fmt.Sprintf("uuid-%d", i), int64(i), "")
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, 500, store.cache.ItemCount())
|
||||||
|
|
||||||
|
// 等待 TTL + 清理周期
|
||||||
|
time.Sleep(300 * time.Millisecond)
|
||||||
|
|
||||||
|
assert.Equal(t, 0, store.cache.ItemCount(), "all items should be expired and cleaned up")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDigestSessionStore_SaveSameChainNoDelete(t *testing.T) {
|
||||||
|
store := NewDigestSessionStore()
|
||||||
|
|
||||||
|
// 保存 chain
|
||||||
|
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
|
||||||
|
|
||||||
|
// 用户重发相同消息:oldDigestChain == digestChain,不应删掉刚设置的 key
|
||||||
|
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "u:a-m:b")
|
||||||
|
|
||||||
|
// 仍然能找到
|
||||||
|
uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b")
|
||||||
|
require.True(t, found)
|
||||||
|
assert.Equal(t, "uuid-1", uuid)
|
||||||
|
assert.Equal(t, int64(100), accountID)
|
||||||
|
}
|
||||||
366
backend/internal/service/error_policy_integration_test.go
Normal file
366
backend/internal/service/error_policy_integration_test.go
Normal file
@@ -0,0 +1,366 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Mocks (scoped to this file by naming convention)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// epFixedUpstream returns a fixed response for every request.
|
||||||
|
type epFixedUpstream struct {
|
||||||
|
statusCode int
|
||||||
|
body string
|
||||||
|
calls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *epFixedUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
|
||||||
|
u.calls++
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: u.statusCode,
|
||||||
|
Header: http.Header{},
|
||||||
|
Body: io.NopCloser(strings.NewReader(u.body)),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *epFixedUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
|
||||||
|
return u.Do(req, proxyURL, accountID, accountConcurrency)
|
||||||
|
}
|
||||||
|
|
||||||
|
// epAccountRepo records SetTempUnschedulable / SetError calls.
|
||||||
|
type epAccountRepo struct {
|
||||||
|
mockAccountRepoForGemini
|
||||||
|
tempCalls int
|
||||||
|
setErrCalls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *epAccountRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error {
|
||||||
|
r.tempCalls++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *epAccountRepo) SetError(_ context.Context, _ int64, _ string) error {
|
||||||
|
r.setErrCalls++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Helpers
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func saveAndSetBaseURLs(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
|
||||||
|
oldAvail := antigravity.DefaultURLAvailability
|
||||||
|
antigravity.BaseURLs = []string{"https://ep-test.example"}
|
||||||
|
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
antigravity.BaseURLs = oldBaseURLs
|
||||||
|
antigravity.DefaultURLAvailability = oldAvail
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRetryParams(account *Account, upstream HTTPUpstream, handleError func(context.Context, string, *Account, int, http.Header, []byte, string, int64, string, bool) *handleModelRateLimitResult) antigravityRetryLoopParams {
|
||||||
|
return antigravityRetryLoopParams{
|
||||||
|
ctx: context.Background(),
|
||||||
|
prefix: "[ep-test]",
|
||||||
|
account: account,
|
||||||
|
accessToken: "token",
|
||||||
|
action: "generateContent",
|
||||||
|
body: []byte(`{"input":"test"}`),
|
||||||
|
httpUpstream: upstream,
|
||||||
|
requestedModel: "claude-sonnet-4-5",
|
||||||
|
handleError: handleError,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// TestRetryLoop_ErrorPolicy_CustomErrorCodes
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestRetryLoop_ErrorPolicy_CustomErrorCodes(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
upstreamStatus int
|
||||||
|
upstreamBody string
|
||||||
|
customCodes []any
|
||||||
|
expectHandleError int
|
||||||
|
expectUpstream int
|
||||||
|
expectStatusCode int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "429_in_custom_codes_matched",
|
||||||
|
upstreamStatus: 429,
|
||||||
|
upstreamBody: `{"error":"rate limited"}`,
|
||||||
|
customCodes: []any{float64(429)},
|
||||||
|
expectHandleError: 1,
|
||||||
|
expectUpstream: 1,
|
||||||
|
expectStatusCode: 429,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "429_not_in_custom_codes_skipped",
|
||||||
|
upstreamStatus: 429,
|
||||||
|
upstreamBody: `{"error":"rate limited"}`,
|
||||||
|
customCodes: []any{float64(500)},
|
||||||
|
expectHandleError: 0,
|
||||||
|
expectUpstream: 1,
|
||||||
|
expectStatusCode: 429,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "500_in_custom_codes_matched",
|
||||||
|
upstreamStatus: 500,
|
||||||
|
upstreamBody: `{"error":"internal"}`,
|
||||||
|
customCodes: []any{float64(500)},
|
||||||
|
expectHandleError: 1,
|
||||||
|
expectUpstream: 1,
|
||||||
|
expectStatusCode: 500,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "500_not_in_custom_codes_skipped",
|
||||||
|
upstreamStatus: 500,
|
||||||
|
upstreamBody: `{"error":"internal"}`,
|
||||||
|
customCodes: []any{float64(429)},
|
||||||
|
expectHandleError: 0,
|
||||||
|
expectUpstream: 1,
|
||||||
|
expectStatusCode: 500,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
saveAndSetBaseURLs(t)
|
||||||
|
|
||||||
|
upstream := &epFixedUpstream{statusCode: tt.upstreamStatus, body: tt.upstreamBody}
|
||||||
|
repo := &epAccountRepo{}
|
||||||
|
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 100,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Schedulable: true,
|
||||||
|
Status: StatusActive,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"custom_error_codes_enabled": true,
|
||||||
|
"custom_error_codes": tt.customCodes,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
|
||||||
|
|
||||||
|
var handleErrorCount int
|
||||||
|
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
||||||
|
handleErrorCount++
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
result, err := svc.antigravityRetryLoop(p)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.NotNil(t, result.resp)
|
||||||
|
defer func() { _ = result.resp.Body.Close() }()
|
||||||
|
|
||||||
|
require.Equal(t, tt.expectStatusCode, result.resp.StatusCode)
|
||||||
|
require.Equal(t, tt.expectHandleError, handleErrorCount, "handleError call count")
|
||||||
|
require.Equal(t, tt.expectUpstream, upstream.calls, "upstream call count")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// TestRetryLoop_ErrorPolicy_TempUnschedulable
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestRetryLoop_ErrorPolicy_TempUnschedulable(t *testing.T) {
|
||||||
|
tempRulesAccount := func(rules []any) *Account {
|
||||||
|
return &Account{
|
||||||
|
ID: 200,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Schedulable: true,
|
||||||
|
Status: StatusActive,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"temp_unschedulable_enabled": true,
|
||||||
|
"temp_unschedulable_rules": rules,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
overloadedRule := map[string]any{
|
||||||
|
"error_code": float64(503),
|
||||||
|
"keywords": []any{"overloaded"},
|
||||||
|
"duration_minutes": float64(10),
|
||||||
|
}
|
||||||
|
|
||||||
|
rateLimitRule := map[string]any{
|
||||||
|
"error_code": float64(429),
|
||||||
|
"keywords": []any{"rate limited keyword"},
|
||||||
|
"duration_minutes": float64(5),
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("503_overloaded_matches_rule", func(t *testing.T) {
|
||||||
|
saveAndSetBaseURLs(t)
|
||||||
|
|
||||||
|
upstream := &epFixedUpstream{statusCode: 503, body: `overloaded`}
|
||||||
|
repo := &epAccountRepo{}
|
||||||
|
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||||
|
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
|
||||||
|
|
||||||
|
account := tempRulesAccount([]any{overloadedRule})
|
||||||
|
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
||||||
|
t.Error("handleError should not be called for temp unschedulable")
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
result, err := svc.antigravityRetryLoop(p)
|
||||||
|
|
||||||
|
require.Nil(t, result)
|
||||||
|
var switchErr *AntigravityAccountSwitchError
|
||||||
|
require.ErrorAs(t, err, &switchErr)
|
||||||
|
require.Equal(t, account.ID, switchErr.OriginalAccountID)
|
||||||
|
require.Equal(t, 1, upstream.calls, "should not retry")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("429_rate_limited_keyword_matches_rule", func(t *testing.T) {
|
||||||
|
saveAndSetBaseURLs(t)
|
||||||
|
|
||||||
|
upstream := &epFixedUpstream{statusCode: 429, body: `rate limited keyword`}
|
||||||
|
repo := &epAccountRepo{}
|
||||||
|
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||||
|
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
|
||||||
|
|
||||||
|
account := tempRulesAccount([]any{rateLimitRule})
|
||||||
|
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
||||||
|
t.Error("handleError should not be called for temp unschedulable")
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
result, err := svc.antigravityRetryLoop(p)
|
||||||
|
|
||||||
|
require.Nil(t, result)
|
||||||
|
var switchErr *AntigravityAccountSwitchError
|
||||||
|
require.ErrorAs(t, err, &switchErr)
|
||||||
|
require.Equal(t, account.ID, switchErr.OriginalAccountID)
|
||||||
|
require.Equal(t, 1, upstream.calls, "should not retry")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("503_body_no_match_continues_default_retry", func(t *testing.T) {
|
||||||
|
saveAndSetBaseURLs(t)
|
||||||
|
|
||||||
|
upstream := &epFixedUpstream{statusCode: 503, body: `random`}
|
||||||
|
repo := &epAccountRepo{}
|
||||||
|
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||||
|
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
|
||||||
|
|
||||||
|
account := tempRulesAccount([]any{overloadedRule})
|
||||||
|
|
||||||
|
// Use a short-lived context: the backoff sleep (~1s) will be
|
||||||
|
// interrupted, proving the code entered the default retry path
|
||||||
|
// instead of breaking early via error policy.
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
p.ctx = ctx
|
||||||
|
|
||||||
|
result, err := svc.antigravityRetryLoop(p)
|
||||||
|
|
||||||
|
// Context cancellation during backoff proves default retry was entered
|
||||||
|
require.Nil(t, result)
|
||||||
|
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||||
|
require.GreaterOrEqual(t, upstream.calls, 1, "should have called upstream at least once")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// TestRetryLoop_ErrorPolicy_NilRateLimitService
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestRetryLoop_ErrorPolicy_NilRateLimitService(t *testing.T) {
|
||||||
|
saveAndSetBaseURLs(t)
|
||||||
|
|
||||||
|
upstream := &epFixedUpstream{statusCode: 429, body: `{"error":"rate limited"}`}
|
||||||
|
// rateLimitService is nil — must not panic
|
||||||
|
svc := &AntigravityGatewayService{rateLimitService: nil}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 300,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Schedulable: true,
|
||||||
|
Status: StatusActive,
|
||||||
|
Concurrency: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
p.ctx = ctx
|
||||||
|
|
||||||
|
// Should not panic; enters the default retry path (eventually times out)
|
||||||
|
result, err := svc.antigravityRetryLoop(p)
|
||||||
|
|
||||||
|
require.Nil(t, result)
|
||||||
|
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||||
|
require.GreaterOrEqual(t, upstream.calls, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior(t *testing.T) {
|
||||||
|
saveAndSetBaseURLs(t)
|
||||||
|
|
||||||
|
upstream := &epFixedUpstream{statusCode: 429, body: `{"error":"rate limited"}`}
|
||||||
|
repo := &epAccountRepo{}
|
||||||
|
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||||
|
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
|
||||||
|
|
||||||
|
// Plain OAuth account with no error policy configured
|
||||||
|
account := &Account{
|
||||||
|
ID: 400,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Schedulable: true,
|
||||||
|
Status: StatusActive,
|
||||||
|
Concurrency: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
var handleErrorCount int
|
||||||
|
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
||||||
|
handleErrorCount++
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
result, err := svc.antigravityRetryLoop(p)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.NotNil(t, result.resp)
|
||||||
|
defer func() { _ = result.resp.Body.Close() }()
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusTooManyRequests, result.resp.StatusCode)
|
||||||
|
require.Equal(t, antigravityMaxRetries, upstream.calls, "should exhaust all retries")
|
||||||
|
require.Equal(t, 1, handleErrorCount, "handleError should be called once after retries exhausted")
|
||||||
|
}
|
||||||
289
backend/internal/service/error_policy_test.go
Normal file
289
backend/internal/service/error_policy_test.go
Normal file
@@ -0,0 +1,289 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// TestCheckErrorPolicy — 6 table-driven cases for the pure logic function
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestCheckErrorPolicy(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
account *Account
|
||||||
|
statusCode int
|
||||||
|
body []byte
|
||||||
|
expected ErrorPolicyResult
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no_policy_oauth_returns_none",
|
||||||
|
account: &Account{
|
||||||
|
ID: 1,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
// no custom error codes, no temp rules
|
||||||
|
},
|
||||||
|
statusCode: 500,
|
||||||
|
body: []byte(`"error"`),
|
||||||
|
expected: ErrorPolicyNone,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "custom_error_codes_hit_returns_matched",
|
||||||
|
account: &Account{
|
||||||
|
ID: 2,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"custom_error_codes_enabled": true,
|
||||||
|
"custom_error_codes": []any{float64(429), float64(500)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
statusCode: 500,
|
||||||
|
body: []byte(`"error"`),
|
||||||
|
expected: ErrorPolicyMatched,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "custom_error_codes_miss_returns_skipped",
|
||||||
|
account: &Account{
|
||||||
|
ID: 3,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"custom_error_codes_enabled": true,
|
||||||
|
"custom_error_codes": []any{float64(429), float64(500)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
statusCode: 503,
|
||||||
|
body: []byte(`"error"`),
|
||||||
|
expected: ErrorPolicySkipped,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "temp_unschedulable_hit_returns_temp_unscheduled",
|
||||||
|
account: &Account{
|
||||||
|
ID: 4,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"temp_unschedulable_enabled": true,
|
||||||
|
"temp_unschedulable_rules": []any{
|
||||||
|
map[string]any{
|
||||||
|
"error_code": float64(503),
|
||||||
|
"keywords": []any{"overloaded"},
|
||||||
|
"duration_minutes": float64(10),
|
||||||
|
"description": "overloaded rule",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
statusCode: 503,
|
||||||
|
body: []byte(`overloaded service`),
|
||||||
|
expected: ErrorPolicyTempUnscheduled,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "temp_unschedulable_body_miss_returns_none",
|
||||||
|
account: &Account{
|
||||||
|
ID: 5,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"temp_unschedulable_enabled": true,
|
||||||
|
"temp_unschedulable_rules": []any{
|
||||||
|
map[string]any{
|
||||||
|
"error_code": float64(503),
|
||||||
|
"keywords": []any{"overloaded"},
|
||||||
|
"duration_minutes": float64(10),
|
||||||
|
"description": "overloaded rule",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
statusCode: 503,
|
||||||
|
body: []byte(`random msg`),
|
||||||
|
expected: ErrorPolicyNone,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "custom_error_codes_override_temp_unschedulable",
|
||||||
|
account: &Account{
|
||||||
|
ID: 6,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"custom_error_codes_enabled": true,
|
||||||
|
"custom_error_codes": []any{float64(503)},
|
||||||
|
"temp_unschedulable_enabled": true,
|
||||||
|
"temp_unschedulable_rules": []any{
|
||||||
|
map[string]any{
|
||||||
|
"error_code": float64(503),
|
||||||
|
"keywords": []any{"overloaded"},
|
||||||
|
"duration_minutes": float64(10),
|
||||||
|
"description": "overloaded rule",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
statusCode: 503,
|
||||||
|
body: []byte(`overloaded`),
|
||||||
|
expected: ErrorPolicyMatched, // custom codes take precedence
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
repo := &errorPolicyRepoStub{}
|
||||||
|
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||||
|
|
||||||
|
result := svc.CheckErrorPolicy(context.Background(), tt.account, tt.statusCode, tt.body)
|
||||||
|
require.Equal(t, tt.expected, result, "unexpected ErrorPolicyResult")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// TestApplyErrorPolicy — 4 table-driven cases for the wrapper method
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestApplyErrorPolicy(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
account *Account
|
||||||
|
statusCode int
|
||||||
|
body []byte
|
||||||
|
expectedHandled bool
|
||||||
|
expectedSwitchErr bool // expect *AntigravityAccountSwitchError
|
||||||
|
handleErrorCalls int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "none_not_handled",
|
||||||
|
account: &Account{
|
||||||
|
ID: 10,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
},
|
||||||
|
statusCode: 500,
|
||||||
|
body: []byte(`"error"`),
|
||||||
|
expectedHandled: false,
|
||||||
|
handleErrorCalls: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "skipped_handled_no_handleError",
|
||||||
|
account: &Account{
|
||||||
|
ID: 11,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"custom_error_codes_enabled": true,
|
||||||
|
"custom_error_codes": []any{float64(429)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
statusCode: 500, // not in custom codes
|
||||||
|
body: []byte(`"error"`),
|
||||||
|
expectedHandled: true,
|
||||||
|
handleErrorCalls: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "matched_handled_calls_handleError",
|
||||||
|
account: &Account{
|
||||||
|
ID: 12,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"custom_error_codes_enabled": true,
|
||||||
|
"custom_error_codes": []any{float64(500)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
statusCode: 500,
|
||||||
|
body: []byte(`"error"`),
|
||||||
|
expectedHandled: true,
|
||||||
|
handleErrorCalls: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "temp_unscheduled_returns_switch_error",
|
||||||
|
account: &Account{
|
||||||
|
ID: 13,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"temp_unschedulable_enabled": true,
|
||||||
|
"temp_unschedulable_rules": []any{
|
||||||
|
map[string]any{
|
||||||
|
"error_code": float64(503),
|
||||||
|
"keywords": []any{"overloaded"},
|
||||||
|
"duration_minutes": float64(10),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
statusCode: 503,
|
||||||
|
body: []byte(`overloaded`),
|
||||||
|
expectedHandled: true,
|
||||||
|
expectedSwitchErr: true,
|
||||||
|
handleErrorCalls: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
repo := &errorPolicyRepoStub{}
|
||||||
|
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||||
|
svc := &AntigravityGatewayService{
|
||||||
|
rateLimitService: rlSvc,
|
||||||
|
}
|
||||||
|
|
||||||
|
var handleErrorCount int
|
||||||
|
p := antigravityRetryLoopParams{
|
||||||
|
ctx: context.Background(),
|
||||||
|
prefix: "[test]",
|
||||||
|
account: tt.account,
|
||||||
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
|
handleErrorCount++
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
isStickySession: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
handled, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body)
|
||||||
|
|
||||||
|
require.Equal(t, tt.expectedHandled, handled, "handled mismatch")
|
||||||
|
require.Equal(t, tt.handleErrorCalls, handleErrorCount, "handleError call count mismatch")
|
||||||
|
|
||||||
|
if tt.expectedSwitchErr {
|
||||||
|
var switchErr *AntigravityAccountSwitchError
|
||||||
|
require.ErrorAs(t, retErr, &switchErr)
|
||||||
|
require.Equal(t, tt.account.ID, switchErr.OriginalAccountID)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, retErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// errorPolicyRepoStub — minimal AccountRepository stub for error policy tests
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type errorPolicyRepoStub struct {
|
||||||
|
mockAccountRepoForGemini
|
||||||
|
tempCalls int
|
||||||
|
setErrCalls int
|
||||||
|
lastErrorMsg string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *errorPolicyRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||||
|
r.tempCalls++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *errorPolicyRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||||
|
r.setErrCalls++
|
||||||
|
r.lastErrorMsg = errorMsg
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -142,9 +142,6 @@ func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx co
|
|||||||
func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
func (m *mockAccountRepoForPlatform) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (m *mockAccountRepoForPlatform) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
func (m *mockAccountRepoForPlatform) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -216,29 +213,6 @@ func (m *mockGatewayCacheForPlatform) DeleteSessionAccountID(ctx context.Context
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockGatewayCacheForPlatform) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockGatewayCacheForPlatform) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockGatewayCacheForPlatform) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
|
||||||
return "", 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockGatewayCacheForPlatform) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockGatewayCacheForPlatform) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
|
||||||
return "", 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockGatewayCacheForPlatform) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type mockGroupRepoForGateway struct {
|
type mockGroupRepoForGateway struct {
|
||||||
groups map[int64]*Group
|
groups map[int64]*Group
|
||||||
|
|||||||
@@ -6,9 +6,19 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// SessionContext 粘性会话上下文,用于区分不同来源的请求。
|
||||||
|
// 仅在 GenerateSessionHash 第 3 级 fallback(消息内容 hash)时混入,
|
||||||
|
// 避免不同用户发送相同消息产生相同 hash 导致账号集中。
|
||||||
|
type SessionContext struct {
|
||||||
|
ClientIP string
|
||||||
|
UserAgent string
|
||||||
|
APIKeyID int64
|
||||||
|
}
|
||||||
|
|
||||||
// ParsedRequest 保存网关请求的预解析结果
|
// ParsedRequest 保存网关请求的预解析结果
|
||||||
//
|
//
|
||||||
// 性能优化说明:
|
// 性能优化说明:
|
||||||
@@ -22,20 +32,22 @@ import (
|
|||||||
// 2. 将解析结果 ParsedRequest 传递给 Service 层
|
// 2. 将解析结果 ParsedRequest 传递给 Service 层
|
||||||
// 3. 避免重复 json.Unmarshal,减少 CPU 和内存开销
|
// 3. 避免重复 json.Unmarshal,减少 CPU 和内存开销
|
||||||
type ParsedRequest struct {
|
type ParsedRequest struct {
|
||||||
Body []byte // 原始请求体(保留用于转发)
|
Body []byte // 原始请求体(保留用于转发)
|
||||||
Model string // 请求的模型名称
|
Model string // 请求的模型名称
|
||||||
Stream bool // 是否为流式请求
|
Stream bool // 是否为流式请求
|
||||||
MetadataUserID string // metadata.user_id(用于会话亲和)
|
MetadataUserID string // metadata.user_id(用于会话亲和)
|
||||||
System any // system 字段内容
|
System any // system 字段内容
|
||||||
Messages []any // messages 数组
|
Messages []any // messages 数组
|
||||||
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
|
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
|
||||||
ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名)
|
ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名)
|
||||||
MaxTokens int // max_tokens 值(用于探测请求拦截)
|
MaxTokens int // max_tokens 值(用于探测请求拦截)
|
||||||
|
SessionContext *SessionContext // 可选:请求上下文区分因子(nil 时行为不变)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseGatewayRequest 解析网关请求体并返回结构化结果
|
// ParseGatewayRequest 解析网关请求体并返回结构化结果。
|
||||||
// 性能优化:一次解析提取所有需要的字段,避免重复 Unmarshal
|
// protocol 指定请求协议格式(domain.PlatformAnthropic / domain.PlatformGemini),
|
||||||
func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
|
// 不同协议使用不同的 system/messages 字段名。
|
||||||
|
func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
|
||||||
var req map[string]any
|
var req map[string]any
|
||||||
if err := json.Unmarshal(body, &req); err != nil {
|
if err := json.Unmarshal(body, &req); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -64,14 +76,29 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
|
|||||||
parsed.MetadataUserID = userID
|
parsed.MetadataUserID = userID
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// system 字段只要存在就视为显式提供(即使为 null),
|
|
||||||
// 以避免客户端传 null 时被默认 system 误注入。
|
switch protocol {
|
||||||
if system, ok := req["system"]; ok {
|
case domain.PlatformGemini:
|
||||||
parsed.HasSystem = true
|
// Gemini 原生格式: systemInstruction.parts / contents
|
||||||
parsed.System = system
|
if sysInst, ok := req["systemInstruction"].(map[string]any); ok {
|
||||||
}
|
if parts, ok := sysInst["parts"].([]any); ok {
|
||||||
if messages, ok := req["messages"].([]any); ok {
|
parsed.System = parts
|
||||||
parsed.Messages = messages
|
}
|
||||||
|
}
|
||||||
|
if contents, ok := req["contents"].([]any); ok {
|
||||||
|
parsed.Messages = contents
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// Anthropic / OpenAI 格式: system / messages
|
||||||
|
// system 字段只要存在就视为显式提供(即使为 null),
|
||||||
|
// 以避免客户端传 null 时被默认 system 误注入。
|
||||||
|
if system, ok := req["system"]; ok {
|
||||||
|
parsed.HasSystem = true
|
||||||
|
parsed.System = system
|
||||||
|
}
|
||||||
|
if messages, ok := req["messages"].([]any); ok {
|
||||||
|
parsed.Messages = messages
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// thinking: {type: "enabled"}
|
// thinking: {type: "enabled"}
|
||||||
|
|||||||
@@ -4,12 +4,13 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParseGatewayRequest(t *testing.T) {
|
func TestParseGatewayRequest(t *testing.T) {
|
||||||
body := []byte(`{"model":"claude-3-7-sonnet","stream":true,"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"system":[{"type":"text","text":"hello","cache_control":{"type":"ephemeral"}}],"messages":[{"content":"hi"}]}`)
|
body := []byte(`{"model":"claude-3-7-sonnet","stream":true,"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"system":[{"type":"text","text":"hello","cache_control":{"type":"ephemeral"}}],"messages":[{"content":"hi"}]}`)
|
||||||
parsed, err := ParseGatewayRequest(body)
|
parsed, err := ParseGatewayRequest(body, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, "claude-3-7-sonnet", parsed.Model)
|
require.Equal(t, "claude-3-7-sonnet", parsed.Model)
|
||||||
require.True(t, parsed.Stream)
|
require.True(t, parsed.Stream)
|
||||||
@@ -22,7 +23,7 @@ func TestParseGatewayRequest(t *testing.T) {
|
|||||||
|
|
||||||
func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) {
|
func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) {
|
||||||
body := []byte(`{"model":"claude-sonnet-4-5","thinking":{"type":"enabled"},"messages":[{"content":"hi"}]}`)
|
body := []byte(`{"model":"claude-sonnet-4-5","thinking":{"type":"enabled"},"messages":[{"content":"hi"}]}`)
|
||||||
parsed, err := ParseGatewayRequest(body)
|
parsed, err := ParseGatewayRequest(body, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, "claude-sonnet-4-5", parsed.Model)
|
require.Equal(t, "claude-sonnet-4-5", parsed.Model)
|
||||||
require.True(t, parsed.ThinkingEnabled)
|
require.True(t, parsed.ThinkingEnabled)
|
||||||
@@ -30,21 +31,21 @@ func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) {
|
|||||||
|
|
||||||
func TestParseGatewayRequest_MaxTokens(t *testing.T) {
|
func TestParseGatewayRequest_MaxTokens(t *testing.T) {
|
||||||
body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1}`)
|
body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1}`)
|
||||||
parsed, err := ParseGatewayRequest(body)
|
parsed, err := ParseGatewayRequest(body, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, 1, parsed.MaxTokens)
|
require.Equal(t, 1, parsed.MaxTokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseGatewayRequest_MaxTokensNonIntegralIgnored(t *testing.T) {
|
func TestParseGatewayRequest_MaxTokensNonIntegralIgnored(t *testing.T) {
|
||||||
body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1.5}`)
|
body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1.5}`)
|
||||||
parsed, err := ParseGatewayRequest(body)
|
parsed, err := ParseGatewayRequest(body, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, 0, parsed.MaxTokens)
|
require.Equal(t, 0, parsed.MaxTokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseGatewayRequest_SystemNull(t *testing.T) {
|
func TestParseGatewayRequest_SystemNull(t *testing.T) {
|
||||||
body := []byte(`{"model":"claude-3","system":null}`)
|
body := []byte(`{"model":"claude-3","system":null}`)
|
||||||
parsed, err := ParseGatewayRequest(body)
|
parsed, err := ParseGatewayRequest(body, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// 显式传入 system:null 也应视为“字段已存在”,避免默认 system 被注入。
|
// 显式传入 system:null 也应视为“字段已存在”,避免默认 system 被注入。
|
||||||
require.True(t, parsed.HasSystem)
|
require.True(t, parsed.HasSystem)
|
||||||
@@ -53,16 +54,112 @@ func TestParseGatewayRequest_SystemNull(t *testing.T) {
|
|||||||
|
|
||||||
func TestParseGatewayRequest_InvalidModelType(t *testing.T) {
|
func TestParseGatewayRequest_InvalidModelType(t *testing.T) {
|
||||||
body := []byte(`{"model":123}`)
|
body := []byte(`{"model":123}`)
|
||||||
_, err := ParseGatewayRequest(body)
|
_, err := ParseGatewayRequest(body, "")
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseGatewayRequest_InvalidStreamType(t *testing.T) {
|
func TestParseGatewayRequest_InvalidStreamType(t *testing.T) {
|
||||||
body := []byte(`{"stream":"true"}`)
|
body := []byte(`{"stream":"true"}`)
|
||||||
_, err := ParseGatewayRequest(body)
|
_, err := ParseGatewayRequest(body, "")
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ============ Gemini 原生格式解析测试 ============
|
||||||
|
|
||||||
|
func TestParseGatewayRequest_GeminiContents(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"contents": [
|
||||||
|
{"role": "user", "parts": [{"text": "Hello"}]},
|
||||||
|
{"role": "model", "parts": [{"text": "Hi there"}]},
|
||||||
|
{"role": "user", "parts": [{"text": "How are you?"}]}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, parsed.Messages, 3, "should parse contents as Messages")
|
||||||
|
require.False(t, parsed.HasSystem, "Gemini format should not set HasSystem")
|
||||||
|
require.Nil(t, parsed.System, "no systemInstruction means nil System")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseGatewayRequest_GeminiSystemInstruction(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"systemInstruction": {
|
||||||
|
"parts": [{"text": "You are a helpful assistant."}]
|
||||||
|
},
|
||||||
|
"contents": [
|
||||||
|
{"role": "user", "parts": [{"text": "Hello"}]}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, parsed.System, "should parse systemInstruction.parts as System")
|
||||||
|
parts, ok := parsed.System.([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, parts, 1)
|
||||||
|
partMap, ok := parts[0].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "You are a helpful assistant.", partMap["text"])
|
||||||
|
require.Len(t, parsed.Messages, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseGatewayRequest_GeminiWithModel(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"model": "gemini-2.5-pro",
|
||||||
|
"contents": [{"role": "user", "parts": [{"text": "test"}]}]
|
||||||
|
}`)
|
||||||
|
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "gemini-2.5-pro", parsed.Model)
|
||||||
|
require.Len(t, parsed.Messages, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseGatewayRequest_GeminiIgnoresAnthropicFields(t *testing.T) {
|
||||||
|
// Gemini 格式下 system/messages 字段应被忽略
|
||||||
|
body := []byte(`{
|
||||||
|
"system": "should be ignored",
|
||||||
|
"messages": [{"role": "user", "content": "ignored"}],
|
||||||
|
"contents": [{"role": "user", "parts": [{"text": "real content"}]}]
|
||||||
|
}`)
|
||||||
|
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.False(t, parsed.HasSystem, "Gemini protocol should not parse Anthropic system field")
|
||||||
|
require.Nil(t, parsed.System, "no systemInstruction = nil System")
|
||||||
|
require.Len(t, parsed.Messages, 1, "should use contents, not messages")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseGatewayRequest_GeminiEmptyContents(t *testing.T) {
|
||||||
|
body := []byte(`{"contents": []}`)
|
||||||
|
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Empty(t, parsed.Messages)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseGatewayRequest_GeminiNoContents(t *testing.T) {
|
||||||
|
body := []byte(`{"model": "gemini-2.5-flash"}`)
|
||||||
|
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, parsed.Messages)
|
||||||
|
require.Equal(t, "gemini-2.5-flash", parsed.Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseGatewayRequest_AnthropicIgnoresGeminiFields(t *testing.T) {
|
||||||
|
// Anthropic 格式下 contents/systemInstruction 字段应被忽略
|
||||||
|
body := []byte(`{
|
||||||
|
"system": "real system",
|
||||||
|
"messages": [{"role": "user", "content": "real content"}],
|
||||||
|
"contents": [{"role": "user", "parts": [{"text": "ignored"}]}],
|
||||||
|
"systemInstruction": {"parts": [{"text": "ignored"}]}
|
||||||
|
}`)
|
||||||
|
parsed, err := ParseGatewayRequest(body, domain.PlatformAnthropic)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, parsed.HasSystem)
|
||||||
|
require.Equal(t, "real system", parsed.System)
|
||||||
|
require.Len(t, parsed.Messages, 1)
|
||||||
|
msg, ok := parsed.Messages[0].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "real content", msg["content"])
|
||||||
|
}
|
||||||
|
|
||||||
func TestFilterThinkingBlocks(t *testing.T) {
|
func TestFilterThinkingBlocks(t *testing.T) {
|
||||||
containsThinkingBlock := func(body []byte) bool {
|
containsThinkingBlock := func(body []byte) bool {
|
||||||
var req map[string]any
|
var req map[string]any
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -17,6 +16,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"regexp"
|
"regexp"
|
||||||
"sort"
|
"sort"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@@ -26,6 +26,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||||
|
"github.com/cespare/xxhash/v2"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
@@ -245,9 +246,6 @@ var (
|
|||||||
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
|
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
|
||||||
var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients")
|
var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients")
|
||||||
|
|
||||||
// ErrModelScopeNotSupported 表示请求的模型系列不在分组支持的范围内
|
|
||||||
var ErrModelScopeNotSupported = errors.New("model scope not supported by this group")
|
|
||||||
|
|
||||||
// allowedHeaders 白名单headers(参考CRS项目)
|
// allowedHeaders 白名单headers(参考CRS项目)
|
||||||
var allowedHeaders = map[string]bool{
|
var allowedHeaders = map[string]bool{
|
||||||
"accept": true,
|
"accept": true,
|
||||||
@@ -273,13 +271,6 @@ var allowedHeaders = map[string]bool{
|
|||||||
// GatewayCache 定义网关服务的缓存操作接口。
|
// GatewayCache 定义网关服务的缓存操作接口。
|
||||||
// 提供粘性会话(Sticky Session)的存储、查询、刷新和删除功能。
|
// 提供粘性会话(Sticky Session)的存储、查询、刷新和删除功能。
|
||||||
//
|
//
|
||||||
// ModelLoadInfo 模型负载信息(用于 Antigravity 调度)
|
|
||||||
// Model load info for Antigravity scheduling
|
|
||||||
type ModelLoadInfo struct {
|
|
||||||
CallCount int64 // 当前分钟调用次数 / Call count in current minute
|
|
||||||
LastUsedAt time.Time // 最后调度时间(零值表示未调度过)/ Last scheduling time (zero means never scheduled)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GatewayCache defines cache operations for gateway service.
|
// GatewayCache defines cache operations for gateway service.
|
||||||
// Provides sticky session storage, retrieval, refresh and deletion capabilities.
|
// Provides sticky session storage, retrieval, refresh and deletion capabilities.
|
||||||
type GatewayCache interface {
|
type GatewayCache interface {
|
||||||
@@ -295,32 +286,6 @@ type GatewayCache interface {
|
|||||||
// DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理
|
// DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理
|
||||||
// Delete sticky session binding, used to proactively clean up when account becomes unavailable
|
// Delete sticky session binding, used to proactively clean up when account becomes unavailable
|
||||||
DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error
|
DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error
|
||||||
|
|
||||||
// IncrModelCallCount 增加模型调用次数并更新最后调度时间(Antigravity 专用)
|
|
||||||
// Increment model call count and update last scheduling time (Antigravity only)
|
|
||||||
// 返回更新后的调用次数
|
|
||||||
IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error)
|
|
||||||
|
|
||||||
// GetModelLoadBatch 批量获取账号的模型负载信息(Antigravity 专用)
|
|
||||||
// Batch get model load info for accounts (Antigravity only)
|
|
||||||
GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error)
|
|
||||||
|
|
||||||
// FindGeminiSession 查找 Gemini 会话(MGET 倒序匹配)
|
|
||||||
// Find Gemini session using MGET reverse order matching
|
|
||||||
// 返回最长匹配的会话信息(uuid, accountID)
|
|
||||||
FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool)
|
|
||||||
|
|
||||||
// SaveGeminiSession 保存 Gemini 会话
|
|
||||||
// Save Gemini session binding
|
|
||||||
SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error
|
|
||||||
|
|
||||||
// FindAnthropicSession 查找 Anthropic 会话(Trie 匹配)
|
|
||||||
// Find Anthropic session using Trie matching
|
|
||||||
FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool)
|
|
||||||
|
|
||||||
// SaveAnthropicSession 保存 Anthropic 会话
|
|
||||||
// Save Anthropic session binding
|
|
||||||
SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
|
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
|
||||||
@@ -415,6 +380,7 @@ type GatewayService struct {
|
|||||||
userSubRepo UserSubscriptionRepository
|
userSubRepo UserSubscriptionRepository
|
||||||
userGroupRateRepo UserGroupRateRepository
|
userGroupRateRepo UserGroupRateRepository
|
||||||
cache GatewayCache
|
cache GatewayCache
|
||||||
|
digestStore *DigestSessionStore
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
schedulerSnapshot *SchedulerSnapshotService
|
schedulerSnapshot *SchedulerSnapshotService
|
||||||
billingService *BillingService
|
billingService *BillingService
|
||||||
@@ -448,6 +414,7 @@ func NewGatewayService(
|
|||||||
deferredService *DeferredService,
|
deferredService *DeferredService,
|
||||||
claudeTokenProvider *ClaudeTokenProvider,
|
claudeTokenProvider *ClaudeTokenProvider,
|
||||||
sessionLimitCache SessionLimitCache,
|
sessionLimitCache SessionLimitCache,
|
||||||
|
digestStore *DigestSessionStore,
|
||||||
) *GatewayService {
|
) *GatewayService {
|
||||||
return &GatewayService{
|
return &GatewayService{
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
@@ -457,6 +424,7 @@ func NewGatewayService(
|
|||||||
userSubRepo: userSubRepo,
|
userSubRepo: userSubRepo,
|
||||||
userGroupRateRepo: userGroupRateRepo,
|
userGroupRateRepo: userGroupRateRepo,
|
||||||
cache: cache,
|
cache: cache,
|
||||||
|
digestStore: digestStore,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
schedulerSnapshot: schedulerSnapshot,
|
schedulerSnapshot: schedulerSnapshot,
|
||||||
concurrencyService: concurrencyService,
|
concurrencyService: concurrencyService,
|
||||||
@@ -490,8 +458,17 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
|
|||||||
return s.hashContent(cacheableContent)
|
return s.hashContent(cacheableContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 最后 fallback: 使用 system + 所有消息的完整摘要串
|
// 3. 最后 fallback: 使用 session上下文 + system + 所有消息的完整摘要串
|
||||||
var combined strings.Builder
|
var combined strings.Builder
|
||||||
|
// 混入请求上下文区分因子,避免不同用户相同消息产生相同 hash
|
||||||
|
if parsed.SessionContext != nil {
|
||||||
|
_, _ = combined.WriteString(parsed.SessionContext.ClientIP)
|
||||||
|
_, _ = combined.WriteString(":")
|
||||||
|
_, _ = combined.WriteString(parsed.SessionContext.UserAgent)
|
||||||
|
_, _ = combined.WriteString(":")
|
||||||
|
_, _ = combined.WriteString(strconv.FormatInt(parsed.SessionContext.APIKeyID, 10))
|
||||||
|
_, _ = combined.WriteString("|")
|
||||||
|
}
|
||||||
if parsed.System != nil {
|
if parsed.System != nil {
|
||||||
systemText := s.extractTextFromSystem(parsed.System)
|
systemText := s.extractTextFromSystem(parsed.System)
|
||||||
if systemText != "" {
|
if systemText != "" {
|
||||||
@@ -500,9 +477,20 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
|
|||||||
}
|
}
|
||||||
for _, msg := range parsed.Messages {
|
for _, msg := range parsed.Messages {
|
||||||
if m, ok := msg.(map[string]any); ok {
|
if m, ok := msg.(map[string]any); ok {
|
||||||
msgText := s.extractTextFromContent(m["content"])
|
if content, exists := m["content"]; exists {
|
||||||
if msgText != "" {
|
// Anthropic: messages[].content
|
||||||
_, _ = combined.WriteString(msgText)
|
if msgText := s.extractTextFromContent(content); msgText != "" {
|
||||||
|
_, _ = combined.WriteString(msgText)
|
||||||
|
}
|
||||||
|
} else if parts, ok := m["parts"].([]any); ok {
|
||||||
|
// Gemini: contents[].parts[].text
|
||||||
|
for _, part := range parts {
|
||||||
|
if partMap, ok := part.(map[string]any); ok {
|
||||||
|
if text, ok := partMap["text"].(string); ok {
|
||||||
|
_, _ = combined.WriteString(text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -536,35 +524,37 @@ func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID
|
|||||||
|
|
||||||
// FindGeminiSession 查找 Gemini 会话(基于内容摘要链的 Fallback 匹配)
|
// FindGeminiSession 查找 Gemini 会话(基于内容摘要链的 Fallback 匹配)
|
||||||
// 返回最长匹配的会话信息(uuid, accountID)
|
// 返回最长匹配的会话信息(uuid, accountID)
|
||||||
func (s *GatewayService) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
func (s *GatewayService) FindGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) {
|
||||||
if digestChain == "" || s.cache == nil {
|
if digestChain == "" || s.digestStore == nil {
|
||||||
return "", 0, false
|
return "", 0, "", false
|
||||||
}
|
}
|
||||||
return s.cache.FindGeminiSession(ctx, groupID, prefixHash, digestChain)
|
return s.digestStore.Find(groupID, prefixHash, digestChain)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveGeminiSession 保存 Gemini 会话
|
// SaveGeminiSession 保存 Gemini 会话。oldDigestChain 为 Find 返回的 matchedChain,用于删旧 key。
|
||||||
func (s *GatewayService) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
func (s *GatewayService) SaveGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error {
|
||||||
if digestChain == "" || s.cache == nil {
|
if digestChain == "" || s.digestStore == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return s.cache.SaveGeminiSession(ctx, groupID, prefixHash, digestChain, uuid, accountID)
|
s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// FindAnthropicSession 查找 Anthropic 会话(基于内容摘要链的 Fallback 匹配)
|
// FindAnthropicSession 查找 Anthropic 会话(基于内容摘要链的 Fallback 匹配)
|
||||||
func (s *GatewayService) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
func (s *GatewayService) FindAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) {
|
||||||
if digestChain == "" || s.cache == nil {
|
if digestChain == "" || s.digestStore == nil {
|
||||||
return "", 0, false
|
return "", 0, "", false
|
||||||
}
|
}
|
||||||
return s.cache.FindAnthropicSession(ctx, groupID, prefixHash, digestChain)
|
return s.digestStore.Find(groupID, prefixHash, digestChain)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveAnthropicSession 保存 Anthropic 会话
|
// SaveAnthropicSession 保存 Anthropic 会话
|
||||||
func (s *GatewayService) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
func (s *GatewayService) SaveAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error {
|
||||||
if digestChain == "" || s.cache == nil {
|
if digestChain == "" || s.digestStore == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return s.cache.SaveAnthropicSession(ctx, groupID, prefixHash, digestChain, uuid, accountID)
|
s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
|
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
|
||||||
@@ -649,8 +639,8 @@ func (s *GatewayService) extractTextFromContent(content any) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) hashContent(content string) string {
|
func (s *GatewayService) hashContent(content string) string {
|
||||||
hash := sha256.Sum256([]byte(content))
|
h := xxhash.Sum64String(content)
|
||||||
return hex.EncodeToString(hash[:16]) // 32字符
|
return strconv.FormatUint(h, 36)
|
||||||
}
|
}
|
||||||
|
|
||||||
// replaceModelInBody 替换请求体中的model字段
|
// replaceModelInBody 替换请求体中的model字段
|
||||||
@@ -1009,13 +999,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
log.Printf("[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform)
|
log.Printf("[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Antigravity 模型系列检查(在账号选择前检查,确保所有代码路径都经过此检查)
|
|
||||||
if platform == PlatformAntigravity && groupID != nil && requestedModel != "" {
|
|
||||||
if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -1209,6 +1192,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
|
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
shuffleWithinSortGroups(routingAvailable)
|
||||||
|
|
||||||
// 4. 尝试获取槽位
|
// 4. 尝试获取槽位
|
||||||
for _, item := range routingAvailable {
|
for _, item := range routingAvailable {
|
||||||
@@ -1362,10 +1346,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Antigravity 平台:获取模型负载信息
|
|
||||||
var modelLoadMap map[int64]*ModelLoadInfo
|
|
||||||
isAntigravity := platform == PlatformAntigravity
|
|
||||||
|
|
||||||
var available []accountWithLoad
|
var available []accountWithLoad
|
||||||
for _, acc := range candidates {
|
for _, acc := range candidates {
|
||||||
loadInfo := loadMap[acc.ID]
|
loadInfo := loadMap[acc.ID]
|
||||||
@@ -1380,109 +1360,44 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Antigravity 平台:按账号实际映射后的模型名获取模型负载(与 Forward 的统计保持一致)
|
// 分层过滤选择:优先级 → 负载率 → LRU
|
||||||
if isAntigravity && requestedModel != "" && s.cache != nil && len(available) > 0 {
|
for len(available) > 0 {
|
||||||
modelLoadMap = make(map[int64]*ModelLoadInfo, len(available))
|
// 1. 取优先级最小的集合
|
||||||
modelToAccountIDs := make(map[string][]int64)
|
candidates := filterByMinPriority(available)
|
||||||
for _, item := range available {
|
// 2. 取负载率最低的集合
|
||||||
mappedModel := mapAntigravityModel(item.account, requestedModel)
|
candidates = filterByMinLoadRate(candidates)
|
||||||
if mappedModel == "" {
|
// 3. LRU 选择最久未用的账号
|
||||||
continue
|
selected := selectByLRU(candidates, preferOAuth)
|
||||||
}
|
if selected == nil {
|
||||||
modelToAccountIDs[mappedModel] = append(modelToAccountIDs[mappedModel], item.account.ID)
|
break
|
||||||
}
|
}
|
||||||
for model, ids := range modelToAccountIDs {
|
|
||||||
batch, err := s.cache.GetModelLoadBatch(ctx, ids, model)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for id, info := range batch {
|
|
||||||
modelLoadMap[id] = info
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(modelLoadMap) == 0 {
|
|
||||||
modelLoadMap = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Antigravity 平台:优先级硬过滤 →(同优先级内)按调用次数选择(最少优先,新账号用平均值)
|
result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
|
||||||
// 其他平台:分层过滤选择:优先级 → 负载率 → LRU
|
if err == nil && result.Acquired {
|
||||||
if isAntigravity {
|
// 会话数量限制检查
|
||||||
for len(available) > 0 {
|
if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) {
|
||||||
// 1. 取优先级最小的集合(硬过滤)
|
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
|
||||||
candidates := filterByMinPriority(available)
|
} else {
|
||||||
// 2. 同优先级内按调用次数选择(调用次数最少优先,新账号使用平均值)
|
if sessionHash != "" && s.cache != nil {
|
||||||
selected := selectByCallCount(candidates, modelLoadMap, preferOAuth)
|
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
|
||||||
if selected == nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
|
|
||||||
if err == nil && result.Acquired {
|
|
||||||
// 会话数量限制检查
|
|
||||||
if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) {
|
|
||||||
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
|
|
||||||
} else {
|
|
||||||
if sessionHash != "" && s.cache != nil {
|
|
||||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
|
|
||||||
}
|
|
||||||
return &AccountSelectionResult{
|
|
||||||
Account: selected.account,
|
|
||||||
Acquired: true,
|
|
||||||
ReleaseFunc: result.ReleaseFunc,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
return &AccountSelectionResult{
|
||||||
|
Account: selected.account,
|
||||||
|
Acquired: true,
|
||||||
|
ReleaseFunc: result.ReleaseFunc,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 移除已尝试的账号,重新选择
|
|
||||||
selectedID := selected.account.ID
|
|
||||||
newAvailable := make([]accountWithLoad, 0, len(available)-1)
|
|
||||||
for _, acc := range available {
|
|
||||||
if acc.account.ID != selectedID {
|
|
||||||
newAvailable = append(newAvailable, acc)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
available = newAvailable
|
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
for len(available) > 0 {
|
|
||||||
// 1. 取优先级最小的集合
|
|
||||||
candidates := filterByMinPriority(available)
|
|
||||||
// 2. 取负载率最低的集合
|
|
||||||
candidates = filterByMinLoadRate(candidates)
|
|
||||||
// 3. LRU 选择最久未用的账号
|
|
||||||
selected := selectByLRU(candidates, preferOAuth)
|
|
||||||
if selected == nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
|
// 移除已尝试的账号,重新进行分层过滤
|
||||||
if err == nil && result.Acquired {
|
selectedID := selected.account.ID
|
||||||
// 会话数量限制检查
|
newAvailable := make([]accountWithLoad, 0, len(available)-1)
|
||||||
if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) {
|
for _, acc := range available {
|
||||||
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
|
if acc.account.ID != selectedID {
|
||||||
} else {
|
newAvailable = append(newAvailable, acc)
|
||||||
if sessionHash != "" && s.cache != nil {
|
|
||||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
|
|
||||||
}
|
|
||||||
return &AccountSelectionResult{
|
|
||||||
Account: selected.account,
|
|
||||||
Acquired: true,
|
|
||||||
ReleaseFunc: result.ReleaseFunc,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 移除已尝试的账号,重新进行分层过滤
|
|
||||||
selectedID := selected.account.ID
|
|
||||||
newAvailable := make([]accountWithLoad, 0, len(available)-1)
|
|
||||||
for _, acc := range available {
|
|
||||||
if acc.account.ID != selectedID {
|
|
||||||
newAvailable = append(newAvailable, acc)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
available = newAvailable
|
|
||||||
}
|
}
|
||||||
|
available = newAvailable
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2018,87 +1933,79 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
|
|||||||
return a.LastUsedAt.Before(*b.LastUsedAt)
|
return a.LastUsedAt.Before(*b.LastUsedAt)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
shuffleWithinPriorityAndLastUsed(accounts)
|
||||||
}
|
}
|
||||||
|
|
||||||
// selectByCallCount 从候选账号中选择调用次数最少的账号(Antigravity 专用)
|
// shuffleWithinSortGroups 对排序后的 accountWithLoad 切片,按 (Priority, LoadRate, LastUsedAt) 分组后组内随机打乱。
|
||||||
// 新账号(CallCount=0)使用平均调用次数作为虚拟值,避免冷启动被猛调
|
// 防止并发请求读取同一快照时,确定性排序导致所有请求命中相同账号。
|
||||||
// 如果有多个账号具有相同的最小调用次数,则随机选择一个
|
func shuffleWithinSortGroups(accounts []accountWithLoad) {
|
||||||
func selectByCallCount(accounts []accountWithLoad, modelLoadMap map[int64]*ModelLoadInfo, preferOAuth bool) *accountWithLoad {
|
if len(accounts) <= 1 {
|
||||||
if len(accounts) == 0 {
|
return
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
if len(accounts) == 1 {
|
i := 0
|
||||||
return &accounts[0]
|
for i < len(accounts) {
|
||||||
}
|
j := i + 1
|
||||||
|
for j < len(accounts) && sameAccountWithLoadGroup(accounts[i], accounts[j]) {
|
||||||
// 如果没有负载信息,回退到 LRU
|
j++
|
||||||
if modelLoadMap == nil {
|
|
||||||
return selectByLRU(accounts, preferOAuth)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 1. 计算平均调用次数(用于新账号冷启动)
|
|
||||||
var totalCallCount int64
|
|
||||||
var countWithCalls int
|
|
||||||
for _, acc := range accounts {
|
|
||||||
if info := modelLoadMap[acc.account.ID]; info != nil && info.CallCount > 0 {
|
|
||||||
totalCallCount += info.CallCount
|
|
||||||
countWithCalls++
|
|
||||||
}
|
}
|
||||||
}
|
if j-i > 1 {
|
||||||
|
mathrand.Shuffle(j-i, func(a, b int) {
|
||||||
var avgCallCount int64
|
accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a]
|
||||||
if countWithCalls > 0 {
|
})
|
||||||
avgCallCount = totalCallCount / int64(countWithCalls)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. 获取每个账号的有效调用次数
|
|
||||||
getEffectiveCallCount := func(acc accountWithLoad) int64 {
|
|
||||||
if acc.account == nil {
|
|
||||||
return 0
|
|
||||||
}
|
}
|
||||||
info := modelLoadMap[acc.account.ID]
|
i = j
|
||||||
if info == nil || info.CallCount == 0 {
|
|
||||||
return avgCallCount // 新账号使用平均值
|
|
||||||
}
|
|
||||||
return info.CallCount
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 3. 找到最小调用次数
|
// sameAccountWithLoadGroup 判断两个 accountWithLoad 是否属于同一排序组
|
||||||
minCount := getEffectiveCallCount(accounts[0])
|
func sameAccountWithLoadGroup(a, b accountWithLoad) bool {
|
||||||
for _, acc := range accounts[1:] {
|
if a.account.Priority != b.account.Priority {
|
||||||
if c := getEffectiveCallCount(acc); c < minCount {
|
return false
|
||||||
minCount = c
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
|
||||||
// 4. 收集所有具有最小调用次数的账号
|
return false
|
||||||
var candidateIdxs []int
|
|
||||||
for i, acc := range accounts {
|
|
||||||
if getEffectiveCallCount(acc) == minCount {
|
|
||||||
candidateIdxs = append(candidateIdxs, i)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
return sameLastUsedAt(a.account.LastUsedAt, b.account.LastUsedAt)
|
||||||
|
}
|
||||||
|
|
||||||
// 5. 如果只有一个候选,直接返回
|
// shuffleWithinPriorityAndLastUsed 对排序后的 []*Account 切片,按 (Priority, LastUsedAt) 分组后组内随机打乱。
|
||||||
if len(candidateIdxs) == 1 {
|
func shuffleWithinPriorityAndLastUsed(accounts []*Account) {
|
||||||
return &accounts[candidateIdxs[0]]
|
if len(accounts) <= 1 {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
i := 0
|
||||||
// 6. preferOAuth 处理
|
for i < len(accounts) {
|
||||||
if preferOAuth {
|
j := i + 1
|
||||||
var oauthIdxs []int
|
for j < len(accounts) && sameAccountGroup(accounts[i], accounts[j]) {
|
||||||
for _, idx := range candidateIdxs {
|
j++
|
||||||
if accounts[idx].account.Type == AccountTypeOAuth {
|
|
||||||
oauthIdxs = append(oauthIdxs, idx)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if len(oauthIdxs) > 0 {
|
if j-i > 1 {
|
||||||
candidateIdxs = oauthIdxs
|
mathrand.Shuffle(j-i, func(a, b int) {
|
||||||
|
accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a]
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
i = j
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 7. 随机选择
|
// sameAccountGroup 判断两个 Account 是否属于同一排序组(Priority + LastUsedAt)
|
||||||
return &accounts[candidateIdxs[mathrand.Intn(len(candidateIdxs))]]
|
func sameAccountGroup(a, b *Account) bool {
|
||||||
|
if a.Priority != b.Priority {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return sameLastUsedAt(a.LastUsedAt, b.LastUsedAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sameLastUsedAt 判断两个 LastUsedAt 是否相同(精度到秒)
|
||||||
|
func sameLastUsedAt(a, b *time.Time) bool {
|
||||||
|
switch {
|
||||||
|
case a == nil && b == nil:
|
||||||
|
return true
|
||||||
|
case a == nil || b == nil:
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
return a.Unix() == b.Unix()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// sortCandidatesForFallback 根据配置选择排序策略
|
// sortCandidatesForFallback 根据配置选择排序策略
|
||||||
@@ -2153,13 +2060,6 @@ func shuffleWithinPriority(accounts []*Account) {
|
|||||||
|
|
||||||
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
|
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
|
||||||
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
|
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
|
||||||
// 对 Antigravity 平台,检查请求的模型系列是否在分组支持范围内
|
|
||||||
if platform == PlatformAntigravity && groupID != nil && requestedModel != "" {
|
|
||||||
if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
preferOAuth := platform == PlatformGemini
|
preferOAuth := platform == PlatformGemini
|
||||||
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
|
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
|
||||||
|
|
||||||
@@ -5171,27 +5071,6 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
|
|||||||
return normalized, nil
|
return normalized, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkAntigravityModelScope 检查 Antigravity 平台的模型系列是否在分组支持范围内
|
|
||||||
func (s *GatewayService) checkAntigravityModelScope(ctx context.Context, groupID int64, requestedModel string) error {
|
|
||||||
scope, ok := ResolveAntigravityQuotaScope(requestedModel)
|
|
||||||
if !ok {
|
|
||||||
return nil // 无法解析 scope,跳过检查
|
|
||||||
}
|
|
||||||
|
|
||||||
group, err := s.resolveGroupByID(ctx, groupID)
|
|
||||||
if err != nil {
|
|
||||||
return nil // 查询失败时放行
|
|
||||||
}
|
|
||||||
if group == nil {
|
|
||||||
return nil // 分组不存在时放行
|
|
||||||
}
|
|
||||||
|
|
||||||
if !IsScopeSupported(group.SupportedModelScopes, scope) {
|
|
||||||
return ErrModelScopeNotSupported
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAvailableModels returns the list of models available for a group
|
// GetAvailableModels returns the list of models available for a group
|
||||||
// It aggregates model_mapping keys from all schedulable accounts in the group
|
// It aggregates model_mapping keys from all schedulable accounts in the group
|
||||||
func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string {
|
func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string {
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ func BenchmarkGenerateSessionHash_Metadata(b *testing.B) {
|
|||||||
|
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
parsed, err := ParseGatewayRequest(body)
|
parsed, err := ParseGatewayRequest(body, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatalf("解析请求失败: %v", err)
|
b.Fatalf("解析请求失败: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
384
backend/internal/service/gemini_error_policy_test.go
Normal file
384
backend/internal/service/gemini_error_policy_test.go
Normal file
@@ -0,0 +1,384 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// TestShouldFailoverGeminiUpstreamError — verifies the failover decision
|
||||||
|
// for the ErrorPolicyNone path (original logic preserved).
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestShouldFailoverGeminiUpstreamError(t *testing.T) {
|
||||||
|
svc := &GeminiMessagesCompatService{}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
statusCode int
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"401_failover", 401, true},
|
||||||
|
{"403_failover", 403, true},
|
||||||
|
{"429_failover", 429, true},
|
||||||
|
{"529_failover", 529, true},
|
||||||
|
{"500_failover", 500, true},
|
||||||
|
{"502_failover", 502, true},
|
||||||
|
{"503_failover", 503, true},
|
||||||
|
{"400_no_failover", 400, false},
|
||||||
|
{"404_no_failover", 404, false},
|
||||||
|
{"422_no_failover", 422, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := svc.shouldFailoverGeminiUpstreamError(tt.statusCode)
|
||||||
|
require.Equal(t, tt.expected, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// TestCheckErrorPolicy_GeminiAccounts — verifies CheckErrorPolicy works
|
||||||
|
// correctly for Gemini platform accounts (API Key type).
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestCheckErrorPolicy_GeminiAccounts(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
account *Account
|
||||||
|
statusCode int
|
||||||
|
body []byte
|
||||||
|
expected ErrorPolicyResult
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "gemini_apikey_custom_codes_hit",
|
||||||
|
account: &Account{
|
||||||
|
ID: 100,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"custom_error_codes_enabled": true,
|
||||||
|
"custom_error_codes": []any{float64(429), float64(500)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
statusCode: 429,
|
||||||
|
body: []byte(`{"error":"rate limited"}`),
|
||||||
|
expected: ErrorPolicyMatched,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "gemini_apikey_custom_codes_miss",
|
||||||
|
account: &Account{
|
||||||
|
ID: 101,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"custom_error_codes_enabled": true,
|
||||||
|
"custom_error_codes": []any{float64(429)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
statusCode: 500,
|
||||||
|
body: []byte(`{"error":"internal"}`),
|
||||||
|
expected: ErrorPolicySkipped,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "gemini_apikey_no_custom_codes_returns_none",
|
||||||
|
account: &Account{
|
||||||
|
ID: 102,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
},
|
||||||
|
statusCode: 500,
|
||||||
|
body: []byte(`{"error":"internal"}`),
|
||||||
|
expected: ErrorPolicyNone,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "gemini_apikey_temp_unschedulable_hit",
|
||||||
|
account: &Account{
|
||||||
|
ID: 103,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"temp_unschedulable_enabled": true,
|
||||||
|
"temp_unschedulable_rules": []any{
|
||||||
|
map[string]any{
|
||||||
|
"error_code": float64(503),
|
||||||
|
"keywords": []any{"overloaded"},
|
||||||
|
"duration_minutes": float64(10),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
statusCode: 503,
|
||||||
|
body: []byte(`overloaded service`),
|
||||||
|
expected: ErrorPolicyTempUnscheduled,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "gemini_custom_codes_override_temp_unschedulable",
|
||||||
|
account: &Account{
|
||||||
|
ID: 104,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"custom_error_codes_enabled": true,
|
||||||
|
"custom_error_codes": []any{float64(503)},
|
||||||
|
"temp_unschedulable_enabled": true,
|
||||||
|
"temp_unschedulable_rules": []any{
|
||||||
|
map[string]any{
|
||||||
|
"error_code": float64(503),
|
||||||
|
"keywords": []any{"overloaded"},
|
||||||
|
"duration_minutes": float64(10),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
statusCode: 503,
|
||||||
|
body: []byte(`overloaded`),
|
||||||
|
expected: ErrorPolicyMatched, // custom codes take precedence
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
repo := &errorPolicyRepoStub{}
|
||||||
|
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||||
|
|
||||||
|
result := svc.CheckErrorPolicy(context.Background(), tt.account, tt.statusCode, tt.body)
|
||||||
|
require.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// TestGeminiErrorPolicyIntegration — verifies the Gemini error handling
|
||||||
|
// paths produce the correct behavior for each ErrorPolicyResult.
|
||||||
|
//
|
||||||
|
// These tests simulate the inline error policy switch in handleClaudeCompat
|
||||||
|
// and forwardNativeGemini by calling the same methods in the same order.
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestGeminiErrorPolicyIntegration(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
account *Account
|
||||||
|
statusCode int
|
||||||
|
respBody []byte
|
||||||
|
expectFailover bool // expect UpstreamFailoverError
|
||||||
|
expectHandleError bool // expect handleGeminiUpstreamError to be called
|
||||||
|
expectShouldFailover bool // for None path, whether shouldFailover triggers
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "custom_codes_matched_429_failover",
|
||||||
|
account: &Account{
|
||||||
|
ID: 200,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"custom_error_codes_enabled": true,
|
||||||
|
"custom_error_codes": []any{float64(429)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
statusCode: 429,
|
||||||
|
respBody: []byte(`{"error":"rate limited"}`),
|
||||||
|
expectFailover: true,
|
||||||
|
expectHandleError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "custom_codes_skipped_500_no_failover",
|
||||||
|
account: &Account{
|
||||||
|
ID: 201,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"custom_error_codes_enabled": true,
|
||||||
|
"custom_error_codes": []any{float64(429)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
statusCode: 500,
|
||||||
|
respBody: []byte(`{"error":"internal"}`),
|
||||||
|
expectFailover: false,
|
||||||
|
expectHandleError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "temp_unschedulable_matched_failover",
|
||||||
|
account: &Account{
|
||||||
|
ID: 202,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"temp_unschedulable_enabled": true,
|
||||||
|
"temp_unschedulable_rules": []any{
|
||||||
|
map[string]any{
|
||||||
|
"error_code": float64(503),
|
||||||
|
"keywords": []any{"overloaded"},
|
||||||
|
"duration_minutes": float64(10),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
statusCode: 503,
|
||||||
|
respBody: []byte(`overloaded`),
|
||||||
|
expectFailover: true,
|
||||||
|
expectHandleError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no_policy_429_failover_via_shouldFailover",
|
||||||
|
account: &Account{
|
||||||
|
ID: 203,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
},
|
||||||
|
statusCode: 429,
|
||||||
|
respBody: []byte(`{"error":"rate limited"}`),
|
||||||
|
expectFailover: true,
|
||||||
|
expectHandleError: true,
|
||||||
|
expectShouldFailover: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no_policy_400_no_failover",
|
||||||
|
account: &Account{
|
||||||
|
ID: 204,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
},
|
||||||
|
statusCode: 400,
|
||||||
|
respBody: []byte(`{"error":"bad request"}`),
|
||||||
|
expectFailover: false,
|
||||||
|
expectHandleError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
repo := &geminiErrorPolicyRepo{}
|
||||||
|
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||||
|
svc := &GeminiMessagesCompatService{
|
||||||
|
accountRepo: repo,
|
||||||
|
rateLimitService: rlSvc,
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(writer)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
|
||||||
|
// Simulate the Claude compat error handling path (same logic as native).
|
||||||
|
// This mirrors the inline switch in handleClaudeCompat.
|
||||||
|
var handleErrorCalled bool
|
||||||
|
var gotFailover bool
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
statusCode := tt.statusCode
|
||||||
|
respBody := tt.respBody
|
||||||
|
account := tt.account
|
||||||
|
headers := http.Header{}
|
||||||
|
|
||||||
|
if svc.rateLimitService != nil {
|
||||||
|
switch svc.rateLimitService.CheckErrorPolicy(ctx, account, statusCode, respBody) {
|
||||||
|
case ErrorPolicySkipped:
|
||||||
|
// Skipped → return error directly (no handleGeminiUpstreamError, no failover)
|
||||||
|
gotFailover = false
|
||||||
|
handleErrorCalled = false
|
||||||
|
goto verify
|
||||||
|
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
|
||||||
|
svc.handleGeminiUpstreamError(ctx, account, statusCode, headers, respBody)
|
||||||
|
handleErrorCalled = true
|
||||||
|
gotFailover = true
|
||||||
|
goto verify
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorPolicyNone → original logic
|
||||||
|
svc.handleGeminiUpstreamError(ctx, account, statusCode, headers, respBody)
|
||||||
|
handleErrorCalled = true
|
||||||
|
if svc.shouldFailoverGeminiUpstreamError(statusCode) {
|
||||||
|
gotFailover = true
|
||||||
|
}
|
||||||
|
|
||||||
|
verify:
|
||||||
|
require.Equal(t, tt.expectFailover, gotFailover, "failover mismatch")
|
||||||
|
require.Equal(t, tt.expectHandleError, handleErrorCalled, "handleGeminiUpstreamError call mismatch")
|
||||||
|
|
||||||
|
if tt.expectShouldFailover {
|
||||||
|
require.True(t, svc.shouldFailoverGeminiUpstreamError(statusCode),
|
||||||
|
"shouldFailoverGeminiUpstreamError should return true for status %d", statusCode)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// TestGeminiErrorPolicy_NilRateLimitService — verifies nil safety
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestGeminiErrorPolicy_NilRateLimitService(t *testing.T) {
|
||||||
|
svc := &GeminiMessagesCompatService{
|
||||||
|
rateLimitService: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
// When rateLimitService is nil, error policy is skipped → falls through to
|
||||||
|
// shouldFailoverGeminiUpstreamError (original logic).
|
||||||
|
// Verify this doesn't panic and follows expected behavior.
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
account := &Account{
|
||||||
|
ID: 300,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"custom_error_codes_enabled": true,
|
||||||
|
"custom_error_codes": []any{float64(429)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// The nil check should prevent CheckErrorPolicy from being called
|
||||||
|
if svc.rateLimitService != nil {
|
||||||
|
t.Fatal("rateLimitService should be nil for this test")
|
||||||
|
}
|
||||||
|
|
||||||
|
// shouldFailoverGeminiUpstreamError still works
|
||||||
|
require.True(t, svc.shouldFailoverGeminiUpstreamError(429))
|
||||||
|
require.False(t, svc.shouldFailoverGeminiUpstreamError(400))
|
||||||
|
|
||||||
|
// handleGeminiUpstreamError should not panic with nil rateLimitService
|
||||||
|
require.NotPanics(t, func() {
|
||||||
|
svc.handleGeminiUpstreamError(ctx, account, 500, http.Header{}, []byte(`error`))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// geminiErrorPolicyRepo — minimal AccountRepository stub for Gemini error
|
||||||
|
// policy tests. Embeds mockAccountRepoForGemini and adds tracking.
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type geminiErrorPolicyRepo struct {
|
||||||
|
mockAccountRepoForGemini
|
||||||
|
setErrorCalls int
|
||||||
|
setRateLimitedCalls int
|
||||||
|
setTempCalls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *geminiErrorPolicyRepo) SetError(_ context.Context, _ int64, _ string) error {
|
||||||
|
r.setErrorCalls++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *geminiErrorPolicyRepo) SetRateLimited(_ context.Context, _ int64, _ time.Time) error {
|
||||||
|
r.setRateLimitedCalls++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *geminiErrorPolicyRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error {
|
||||||
|
r.setTempCalls++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -831,38 +831,47 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
|||||||
|
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
tempMatched := false
|
// 统一错误策略:自定义错误码 + 临时不可调度
|
||||||
if s.rateLimitService != nil {
|
if s.rateLimitService != nil {
|
||||||
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
|
switch s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, respBody) {
|
||||||
}
|
case ErrorPolicySkipped:
|
||||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||||
if tempMatched {
|
if upstreamReqID == "" {
|
||||||
upstreamReqID := resp.Header.Get(requestIDHeader)
|
upstreamReqID = resp.Header.Get("x-goog-request-id")
|
||||||
if upstreamReqID == "" {
|
|
||||||
upstreamReqID = resp.Header.Get("x-goog-request-id")
|
|
||||||
}
|
|
||||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
|
||||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
|
||||||
upstreamDetail := ""
|
|
||||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
|
||||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
|
||||||
if maxBytes <= 0 {
|
|
||||||
maxBytes = 2048
|
|
||||||
}
|
}
|
||||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
return nil, s.writeGeminiMappedError(c, account, resp.StatusCode, upstreamReqID, respBody)
|
||||||
|
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
|
||||||
|
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||||
|
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||||
|
if upstreamReqID == "" {
|
||||||
|
upstreamReqID = resp.Header.Get("x-goog-request-id")
|
||||||
|
}
|
||||||
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
upstreamDetail := ""
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 2048
|
||||||
|
}
|
||||||
|
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||||
|
}
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: upstreamReqID,
|
||||||
|
Kind: "failover",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
|
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||||
}
|
}
|
||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
|
||||||
Platform: account.Platform,
|
|
||||||
AccountID: account.ID,
|
|
||||||
AccountName: account.Name,
|
|
||||||
UpstreamStatusCode: resp.StatusCode,
|
|
||||||
UpstreamRequestID: upstreamReqID,
|
|
||||||
Kind: "failover",
|
|
||||||
Message: upstreamMsg,
|
|
||||||
Detail: upstreamDetail,
|
|
||||||
})
|
|
||||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ErrorPolicyNone → 原有逻辑
|
||||||
|
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||||
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||||||
upstreamReqID := resp.Header.Get(requestIDHeader)
|
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||||
if upstreamReqID == "" {
|
if upstreamReqID == "" {
|
||||||
@@ -1249,14 +1258,9 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
|||||||
|
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
tempMatched := false
|
|
||||||
if s.rateLimitService != nil {
|
|
||||||
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
|
|
||||||
}
|
|
||||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
|
||||||
|
|
||||||
// Best-effort fallback for OAuth tokens missing AI Studio scopes when calling countTokens.
|
// Best-effort fallback for OAuth tokens missing AI Studio scopes when calling countTokens.
|
||||||
// This avoids Gemini SDKs failing hard during preflight token counting.
|
// This avoids Gemini SDKs failing hard during preflight token counting.
|
||||||
|
// Checked before error policy so it always works regardless of custom error codes.
|
||||||
if action == "countTokens" && isOAuth && isGeminiInsufficientScope(resp.Header, respBody) {
|
if action == "countTokens" && isOAuth && isGeminiInsufficientScope(resp.Header, respBody) {
|
||||||
estimated := estimateGeminiCountTokens(body)
|
estimated := estimateGeminiCountTokens(body)
|
||||||
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
|
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
|
||||||
@@ -1270,30 +1274,46 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if tempMatched {
|
// 统一错误策略:自定义错误码 + 临时不可调度
|
||||||
evBody := unwrapIfNeeded(isOAuth, respBody)
|
if s.rateLimitService != nil {
|
||||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))
|
switch s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, respBody) {
|
||||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
case ErrorPolicySkipped:
|
||||||
upstreamDetail := ""
|
respBody = unwrapIfNeeded(isOAuth, respBody)
|
||||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
contentType := resp.Header.Get("Content-Type")
|
||||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
if contentType == "" {
|
||||||
if maxBytes <= 0 {
|
contentType = "application/json"
|
||||||
maxBytes = 2048
|
|
||||||
}
|
}
|
||||||
upstreamDetail = truncateString(string(evBody), maxBytes)
|
c.Data(resp.StatusCode, contentType, respBody)
|
||||||
|
return nil, fmt.Errorf("gemini upstream error: %d (skipped by error policy)", resp.StatusCode)
|
||||||
|
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
|
||||||
|
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||||
|
evBody := unwrapIfNeeded(isOAuth, respBody)
|
||||||
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
upstreamDetail := ""
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 2048
|
||||||
|
}
|
||||||
|
upstreamDetail = truncateString(string(evBody), maxBytes)
|
||||||
|
}
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: requestID,
|
||||||
|
Kind: "failover",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
|
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||||
}
|
}
|
||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
|
||||||
Platform: account.Platform,
|
|
||||||
AccountID: account.ID,
|
|
||||||
AccountName: account.Name,
|
|
||||||
UpstreamStatusCode: resp.StatusCode,
|
|
||||||
UpstreamRequestID: requestID,
|
|
||||||
Kind: "failover",
|
|
||||||
Message: upstreamMsg,
|
|
||||||
Detail: upstreamDetail,
|
|
||||||
})
|
|
||||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ErrorPolicyNone → 原有逻辑
|
||||||
|
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||||
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||||||
evBody := unwrapIfNeeded(isOAuth, respBody)
|
evBody := unwrapIfNeeded(isOAuth, respBody)
|
||||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))
|
||||||
|
|||||||
@@ -133,9 +133,6 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx cont
|
|||||||
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
func (m *mockAccountRepoForGemini) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (m *mockAccountRepoForGemini) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
func (m *mockAccountRepoForGemini) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -265,29 +262,6 @@ func (m *mockGatewayCacheForGemini) DeleteSessionAccountID(ctx context.Context,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockGatewayCacheForGemini) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockGatewayCacheForGemini) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockGatewayCacheForGemini) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
|
||||||
return "", 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockGatewayCacheForGemini) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockGatewayCacheForGemini) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
|
||||||
return "", 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockGatewayCacheForGemini) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择
|
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择
|
||||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) {
|
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) {
|
||||||
|
|||||||
@@ -6,26 +6,11 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
"github.com/cespare/xxhash/v2"
|
"github.com/cespare/xxhash/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Gemini 会话 ID Fallback 相关常量
|
|
||||||
const (
|
|
||||||
// geminiSessionTTLSeconds Gemini 会话缓存 TTL(5 分钟)
|
|
||||||
geminiSessionTTLSeconds = 300
|
|
||||||
|
|
||||||
// geminiSessionKeyPrefix Gemini 会话 Redis key 前缀
|
|
||||||
geminiSessionKeyPrefix = "gemini:sess:"
|
|
||||||
)
|
|
||||||
|
|
||||||
// GeminiSessionTTL 返回 Gemini 会话缓存 TTL
|
|
||||||
func GeminiSessionTTL() time.Duration {
|
|
||||||
return geminiSessionTTLSeconds * time.Second
|
|
||||||
}
|
|
||||||
|
|
||||||
// shortHash 使用 XXHash64 + Base36 生成短 hash(16 字符)
|
// shortHash 使用 XXHash64 + Base36 生成短 hash(16 字符)
|
||||||
// XXHash64 比 SHA256 快约 10 倍,Base36 比 Hex 短约 20%
|
// XXHash64 比 SHA256 快约 10 倍,Base36 比 Hex 短约 20%
|
||||||
func shortHash(data []byte) string {
|
func shortHash(data []byte) string {
|
||||||
@@ -79,35 +64,6 @@ func GenerateGeminiPrefixHash(userID, apiKeyID int64, ip, userAgent, platform, m
|
|||||||
return base64.RawURLEncoding.EncodeToString(hash[:12])
|
return base64.RawURLEncoding.EncodeToString(hash[:12])
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildGeminiSessionKey 构建 Gemini 会话 Redis key
|
|
||||||
// 格式: gemini:sess:{groupID}:{prefixHash}:{digestChain}
|
|
||||||
func BuildGeminiSessionKey(groupID int64, prefixHash, digestChain string) string {
|
|
||||||
return geminiSessionKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash + ":" + digestChain
|
|
||||||
}
|
|
||||||
|
|
||||||
// GenerateDigestChainPrefixes 生成摘要链的所有前缀(从长到短)
|
|
||||||
// 用于 MGET 批量查询最长匹配
|
|
||||||
func GenerateDigestChainPrefixes(chain string) []string {
|
|
||||||
if chain == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var prefixes []string
|
|
||||||
c := chain
|
|
||||||
|
|
||||||
for c != "" {
|
|
||||||
prefixes = append(prefixes, c)
|
|
||||||
// 找到最后一个 "-" 的位置
|
|
||||||
if i := strings.LastIndex(c, "-"); i > 0 {
|
|
||||||
c = c[:i]
|
|
||||||
} else {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return prefixes
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseGeminiSessionValue 解析 Gemini 会话缓存值
|
// ParseGeminiSessionValue 解析 Gemini 会话缓存值
|
||||||
// 格式: {uuid}:{accountID}
|
// 格式: {uuid}:{accountID}
|
||||||
func ParseGeminiSessionValue(value string) (uuid string, accountID int64, ok bool) {
|
func ParseGeminiSessionValue(value string) (uuid string, accountID int64, ok bool) {
|
||||||
@@ -139,15 +95,6 @@ func FormatGeminiSessionValue(uuid string, accountID int64) string {
|
|||||||
// geminiDigestSessionKeyPrefix Gemini 摘要 fallback 会话 key 前缀
|
// geminiDigestSessionKeyPrefix Gemini 摘要 fallback 会话 key 前缀
|
||||||
const geminiDigestSessionKeyPrefix = "gemini:digest:"
|
const geminiDigestSessionKeyPrefix = "gemini:digest:"
|
||||||
|
|
||||||
// geminiTrieKeyPrefix Gemini Trie 会话 key 前缀
|
|
||||||
const geminiTrieKeyPrefix = "gemini:trie:"
|
|
||||||
|
|
||||||
// BuildGeminiTrieKey 构建 Gemini Trie Redis key
|
|
||||||
// 格式: gemini:trie:{groupID}:{prefixHash}
|
|
||||||
func BuildGeminiTrieKey(groupID int64, prefixHash string) string {
|
|
||||||
return geminiTrieKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash
|
|
||||||
}
|
|
||||||
|
|
||||||
// GenerateGeminiDigestSessionKey 生成 Gemini 摘要 fallback 的 sessionKey
|
// GenerateGeminiDigestSessionKey 生成 Gemini 摘要 fallback 的 sessionKey
|
||||||
// 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey
|
// 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey
|
||||||
// 用于在 SelectAccountWithLoadAwareness 中保持粘性会话
|
// 用于在 SelectAccountWithLoadAwareness 中保持粘性会话
|
||||||
|
|||||||
@@ -1,41 +1,14 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
)
|
)
|
||||||
|
|
||||||
// mockGeminiSessionCache 模拟 Redis 缓存
|
|
||||||
type mockGeminiSessionCache struct {
|
|
||||||
sessions map[string]string // key -> value
|
|
||||||
}
|
|
||||||
|
|
||||||
func newMockGeminiSessionCache() *mockGeminiSessionCache {
|
|
||||||
return &mockGeminiSessionCache{sessions: make(map[string]string)}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockGeminiSessionCache) Save(groupID int64, prefixHash, digestChain, uuid string, accountID int64) {
|
|
||||||
key := BuildGeminiSessionKey(groupID, prefixHash, digestChain)
|
|
||||||
value := FormatGeminiSessionValue(uuid, accountID)
|
|
||||||
m.sessions[key] = value
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockGeminiSessionCache) Find(groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
|
||||||
prefixes := GenerateDigestChainPrefixes(digestChain)
|
|
||||||
for _, p := range prefixes {
|
|
||||||
key := BuildGeminiSessionKey(groupID, prefixHash, p)
|
|
||||||
if val, ok := m.sessions[key]; ok {
|
|
||||||
return ParseGeminiSessionValue(val)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "", 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestGeminiSessionContinuousConversation 测试连续会话的摘要链匹配
|
// TestGeminiSessionContinuousConversation 测试连续会话的摘要链匹配
|
||||||
func TestGeminiSessionContinuousConversation(t *testing.T) {
|
func TestGeminiSessionContinuousConversation(t *testing.T) {
|
||||||
cache := newMockGeminiSessionCache()
|
store := NewDigestSessionStore()
|
||||||
groupID := int64(1)
|
groupID := int64(1)
|
||||||
prefixHash := "test_prefix_hash"
|
prefixHash := "test_prefix_hash"
|
||||||
sessionUUID := "session-uuid-12345"
|
sessionUUID := "session-uuid-12345"
|
||||||
@@ -54,13 +27,13 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
|
|||||||
t.Logf("Round 1 chain: %s", chain1)
|
t.Logf("Round 1 chain: %s", chain1)
|
||||||
|
|
||||||
// 第一轮:没有找到会话,创建新会话
|
// 第一轮:没有找到会话,创建新会话
|
||||||
_, _, found := cache.Find(groupID, prefixHash, chain1)
|
_, _, _, found := store.Find(groupID, prefixHash, chain1)
|
||||||
if found {
|
if found {
|
||||||
t.Error("Round 1: should not find existing session")
|
t.Error("Round 1: should not find existing session")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 保存第一轮会话
|
// 保存第一轮会话(首轮无旧 chain)
|
||||||
cache.Save(groupID, prefixHash, chain1, sessionUUID, accountID)
|
store.Save(groupID, prefixHash, chain1, sessionUUID, accountID, "")
|
||||||
|
|
||||||
// 模拟第二轮对话(用户继续对话)
|
// 模拟第二轮对话(用户继续对话)
|
||||||
req2 := &antigravity.GeminiRequest{
|
req2 := &antigravity.GeminiRequest{
|
||||||
@@ -77,7 +50,7 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
|
|||||||
t.Logf("Round 2 chain: %s", chain2)
|
t.Logf("Round 2 chain: %s", chain2)
|
||||||
|
|
||||||
// 第二轮:应该能找到会话(通过前缀匹配)
|
// 第二轮:应该能找到会话(通过前缀匹配)
|
||||||
foundUUID, foundAccID, found := cache.Find(groupID, prefixHash, chain2)
|
foundUUID, foundAccID, matchedChain, found := store.Find(groupID, prefixHash, chain2)
|
||||||
if !found {
|
if !found {
|
||||||
t.Error("Round 2: should find session via prefix matching")
|
t.Error("Round 2: should find session via prefix matching")
|
||||||
}
|
}
|
||||||
@@ -88,8 +61,8 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
|
|||||||
t.Errorf("Round 2: expected accountID %d, got %d", accountID, foundAccID)
|
t.Errorf("Round 2: expected accountID %d, got %d", accountID, foundAccID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 保存第二轮会话
|
// 保存第二轮会话,传入 Find 返回的 matchedChain 以删旧 key
|
||||||
cache.Save(groupID, prefixHash, chain2, sessionUUID, accountID)
|
store.Save(groupID, prefixHash, chain2, sessionUUID, accountID, matchedChain)
|
||||||
|
|
||||||
// 模拟第三轮对话
|
// 模拟第三轮对话
|
||||||
req3 := &antigravity.GeminiRequest{
|
req3 := &antigravity.GeminiRequest{
|
||||||
@@ -108,7 +81,7 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
|
|||||||
t.Logf("Round 3 chain: %s", chain3)
|
t.Logf("Round 3 chain: %s", chain3)
|
||||||
|
|
||||||
// 第三轮:应该能找到会话(通过第二轮的前缀匹配)
|
// 第三轮:应该能找到会话(通过第二轮的前缀匹配)
|
||||||
foundUUID, foundAccID, found = cache.Find(groupID, prefixHash, chain3)
|
foundUUID, foundAccID, _, found = store.Find(groupID, prefixHash, chain3)
|
||||||
if !found {
|
if !found {
|
||||||
t.Error("Round 3: should find session via prefix matching")
|
t.Error("Round 3: should find session via prefix matching")
|
||||||
}
|
}
|
||||||
@@ -118,13 +91,11 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
|
|||||||
if foundAccID != accountID {
|
if foundAccID != accountID {
|
||||||
t.Errorf("Round 3: expected accountID %d, got %d", accountID, foundAccID)
|
t.Errorf("Round 3: expected accountID %d, got %d", accountID, foundAccID)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Log("✓ Continuous conversation session matching works correctly!")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestGeminiSessionDifferentConversations 测试不同会话不会错误匹配
|
// TestGeminiSessionDifferentConversations 测试不同会话不会错误匹配
|
||||||
func TestGeminiSessionDifferentConversations(t *testing.T) {
|
func TestGeminiSessionDifferentConversations(t *testing.T) {
|
||||||
cache := newMockGeminiSessionCache()
|
store := NewDigestSessionStore()
|
||||||
groupID := int64(1)
|
groupID := int64(1)
|
||||||
prefixHash := "test_prefix_hash"
|
prefixHash := "test_prefix_hash"
|
||||||
|
|
||||||
@@ -135,7 +106,7 @@ func TestGeminiSessionDifferentConversations(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
chain1 := BuildGeminiDigestChain(req1)
|
chain1 := BuildGeminiDigestChain(req1)
|
||||||
cache.Save(groupID, prefixHash, chain1, "session-1", 100)
|
store.Save(groupID, prefixHash, chain1, "session-1", 100, "")
|
||||||
|
|
||||||
// 第二个完全不同的会话
|
// 第二个完全不同的会话
|
||||||
req2 := &antigravity.GeminiRequest{
|
req2 := &antigravity.GeminiRequest{
|
||||||
@@ -146,61 +117,29 @@ func TestGeminiSessionDifferentConversations(t *testing.T) {
|
|||||||
chain2 := BuildGeminiDigestChain(req2)
|
chain2 := BuildGeminiDigestChain(req2)
|
||||||
|
|
||||||
// 不同会话不应该匹配
|
// 不同会话不应该匹配
|
||||||
_, _, found := cache.Find(groupID, prefixHash, chain2)
|
_, _, _, found := store.Find(groupID, prefixHash, chain2)
|
||||||
if found {
|
if found {
|
||||||
t.Error("Different conversations should not match")
|
t.Error("Different conversations should not match")
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Log("✓ Different conversations are correctly isolated!")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestGeminiSessionPrefixMatchingOrder 测试前缀匹配的优先级(最长匹配优先)
|
// TestGeminiSessionPrefixMatchingOrder 测试前缀匹配的优先级(最长匹配优先)
|
||||||
func TestGeminiSessionPrefixMatchingOrder(t *testing.T) {
|
func TestGeminiSessionPrefixMatchingOrder(t *testing.T) {
|
||||||
cache := newMockGeminiSessionCache()
|
store := NewDigestSessionStore()
|
||||||
groupID := int64(1)
|
groupID := int64(1)
|
||||||
prefixHash := "test_prefix_hash"
|
prefixHash := "test_prefix_hash"
|
||||||
|
|
||||||
// 创建一个三轮对话
|
|
||||||
req := &antigravity.GeminiRequest{
|
|
||||||
SystemInstruction: &antigravity.GeminiContent{
|
|
||||||
Parts: []antigravity.GeminiPart{{Text: "System prompt"}},
|
|
||||||
},
|
|
||||||
Contents: []antigravity.GeminiContent{
|
|
||||||
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q1"}}},
|
|
||||||
{Role: "model", Parts: []antigravity.GeminiPart{{Text: "A1"}}},
|
|
||||||
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q2"}}},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
fullChain := BuildGeminiDigestChain(req)
|
|
||||||
prefixes := GenerateDigestChainPrefixes(fullChain)
|
|
||||||
|
|
||||||
t.Logf("Full chain: %s", fullChain)
|
|
||||||
t.Logf("Prefixes (longest first): %v", prefixes)
|
|
||||||
|
|
||||||
// 验证前缀生成顺序(从长到短)
|
|
||||||
if len(prefixes) != 4 {
|
|
||||||
t.Errorf("Expected 4 prefixes, got %d", len(prefixes))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 保存不同轮次的会话到不同账号
|
// 保存不同轮次的会话到不同账号
|
||||||
// 第一轮(最短前缀)-> 账号 1
|
store.Save(groupID, prefixHash, "s:sys-u:q1", "session-round1", 1, "")
|
||||||
cache.Save(groupID, prefixHash, prefixes[3], "session-round1", 1)
|
store.Save(groupID, prefixHash, "s:sys-u:q1-m:a1", "session-round2", 2, "")
|
||||||
// 第二轮 -> 账号 2
|
store.Save(groupID, prefixHash, "s:sys-u:q1-m:a1-u:q2", "session-round3", 3, "")
|
||||||
cache.Save(groupID, prefixHash, prefixes[2], "session-round2", 2)
|
|
||||||
// 第三轮(最长前缀,完整链)-> 账号 3
|
|
||||||
cache.Save(groupID, prefixHash, prefixes[0], "session-round3", 3)
|
|
||||||
|
|
||||||
// 查找应该返回最长匹配(账号 3)
|
// 查找更长的链,应该返回最长匹配(账号 3)
|
||||||
_, accID, found := cache.Find(groupID, prefixHash, fullChain)
|
_, accID, _, found := store.Find(groupID, prefixHash, "s:sys-u:q1-m:a1-u:q2-m:a2")
|
||||||
if !found {
|
if !found {
|
||||||
t.Error("Should find session")
|
t.Error("Should find session")
|
||||||
}
|
}
|
||||||
if accID != 3 {
|
if accID != 3 {
|
||||||
t.Errorf("Should match longest prefix (account 3), got account %d", accID)
|
t.Errorf("Should match longest prefix (account 3), got account %d", accID)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Log("✓ Longest prefix matching works correctly!")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 确保 context 包被使用(避免未使用的导入警告)
|
|
||||||
var _ = context.Background
|
|
||||||
|
|||||||
@@ -152,61 +152,6 @@ func TestGenerateGeminiPrefixHash(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGenerateDigestChainPrefixes(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
chain string
|
|
||||||
want []string
|
|
||||||
wantLen int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "empty",
|
|
||||||
chain: "",
|
|
||||||
wantLen: 0,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "single part",
|
|
||||||
chain: "u:abc123",
|
|
||||||
want: []string{"u:abc123"},
|
|
||||||
wantLen: 1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "two parts",
|
|
||||||
chain: "s:xyz-u:abc",
|
|
||||||
want: []string{"s:xyz-u:abc", "s:xyz"},
|
|
||||||
wantLen: 2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "four parts",
|
|
||||||
chain: "s:a-u:b-m:c-u:d",
|
|
||||||
want: []string{"s:a-u:b-m:c-u:d", "s:a-u:b-m:c", "s:a-u:b", "s:a"},
|
|
||||||
wantLen: 4,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
result := GenerateDigestChainPrefixes(tt.chain)
|
|
||||||
|
|
||||||
if len(result) != tt.wantLen {
|
|
||||||
t.Errorf("expected %d prefixes, got %d: %v", tt.wantLen, len(result), result)
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.want != nil {
|
|
||||||
for i, want := range tt.want {
|
|
||||||
if i >= len(result) {
|
|
||||||
t.Errorf("missing prefix at index %d", i)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if result[i] != want {
|
|
||||||
t.Errorf("prefix[%d]: expected %s, got %s", i, want, result[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseGeminiSessionValue(t *testing.T) {
|
func TestParseGeminiSessionValue(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -442,40 +387,3 @@ func TestGenerateGeminiDigestSessionKey(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBuildGeminiTrieKey(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
groupID int64
|
|
||||||
prefixHash string
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "normal",
|
|
||||||
groupID: 123,
|
|
||||||
prefixHash: "abcdef12",
|
|
||||||
want: "gemini:trie:123:abcdef12",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "zero group",
|
|
||||||
groupID: 0,
|
|
||||||
prefixHash: "xyz",
|
|
||||||
want: "gemini:trie:0:xyz",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty prefix",
|
|
||||||
groupID: 1,
|
|
||||||
prefixHash: "",
|
|
||||||
want: "gemini:trie:1:",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got := BuildGeminiTrieKey(tt.groupID, tt.prefixHash)
|
|
||||||
if got != tt.want {
|
|
||||||
t.Errorf("BuildGeminiTrieKey(%d, %q) = %q, want %q", tt.groupID, tt.prefixHash, got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
1213
backend/internal/service/generate_session_hash_test.go
Normal file
1213
backend/internal/service/generate_session_hash_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -318,110 +318,6 @@ func TestGetModelRateLimitRemainingTime(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetQuotaScopeRateLimitRemainingTime(t *testing.T) {
|
|
||||||
now := time.Now()
|
|
||||||
future10m := now.Add(10 * time.Minute).Format(time.RFC3339)
|
|
||||||
past := now.Add(-10 * time.Minute).Format(time.RFC3339)
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
account *Account
|
|
||||||
requestedModel string
|
|
||||||
minExpected time.Duration
|
|
||||||
maxExpected time.Duration
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "nil account",
|
|
||||||
account: nil,
|
|
||||||
requestedModel: "claude-sonnet-4-5",
|
|
||||||
minExpected: 0,
|
|
||||||
maxExpected: 0,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "non-antigravity platform",
|
|
||||||
account: &Account{
|
|
||||||
Platform: PlatformAnthropic,
|
|
||||||
Extra: map[string]any{
|
|
||||||
antigravityQuotaScopesKey: map[string]any{
|
|
||||||
"claude": map[string]any{
|
|
||||||
"rate_limit_reset_at": future10m,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
requestedModel: "claude-sonnet-4-5",
|
|
||||||
minExpected: 0,
|
|
||||||
maxExpected: 0,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "claude scope rate limited",
|
|
||||||
account: &Account{
|
|
||||||
Platform: PlatformAntigravity,
|
|
||||||
Extra: map[string]any{
|
|
||||||
antigravityQuotaScopesKey: map[string]any{
|
|
||||||
"claude": map[string]any{
|
|
||||||
"rate_limit_reset_at": future10m,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
requestedModel: "claude-sonnet-4-5",
|
|
||||||
minExpected: 9 * time.Minute,
|
|
||||||
maxExpected: 11 * time.Minute,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "gemini_text scope rate limited",
|
|
||||||
account: &Account{
|
|
||||||
Platform: PlatformAntigravity,
|
|
||||||
Extra: map[string]any{
|
|
||||||
antigravityQuotaScopesKey: map[string]any{
|
|
||||||
"gemini_text": map[string]any{
|
|
||||||
"rate_limit_reset_at": future10m,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
requestedModel: "gemini-3-flash",
|
|
||||||
minExpected: 9 * time.Minute,
|
|
||||||
maxExpected: 11 * time.Minute,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "expired scope rate limit",
|
|
||||||
account: &Account{
|
|
||||||
Platform: PlatformAntigravity,
|
|
||||||
Extra: map[string]any{
|
|
||||||
antigravityQuotaScopesKey: map[string]any{
|
|
||||||
"claude": map[string]any{
|
|
||||||
"rate_limit_reset_at": past,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
requestedModel: "claude-sonnet-4-5",
|
|
||||||
minExpected: 0,
|
|
||||||
maxExpected: 0,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "unsupported model",
|
|
||||||
account: &Account{
|
|
||||||
Platform: PlatformAntigravity,
|
|
||||||
},
|
|
||||||
requestedModel: "gpt-4",
|
|
||||||
minExpected: 0,
|
|
||||||
maxExpected: 0,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
result := tt.account.GetQuotaScopeRateLimitRemainingTime(tt.requestedModel)
|
|
||||||
if result < tt.minExpected || result > tt.maxExpected {
|
|
||||||
t.Errorf("GetQuotaScopeRateLimitRemainingTime() = %v, want between %v and %v", result, tt.minExpected, tt.maxExpected)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetRateLimitRemainingTime(t *testing.T) {
|
func TestGetRateLimitRemainingTime(t *testing.T) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
future15m := now.Add(15 * time.Minute).Format(time.RFC3339)
|
future15m := now.Add(15 * time.Minute).Format(time.RFC3339)
|
||||||
@@ -442,45 +338,19 @@ func TestGetRateLimitRemainingTime(t *testing.T) {
|
|||||||
maxExpected: 0,
|
maxExpected: 0,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "model remaining > scope remaining - returns model",
|
name: "model rate limited - 15 minutes",
|
||||||
account: &Account{
|
account: &Account{
|
||||||
Platform: PlatformAntigravity,
|
Platform: PlatformAntigravity,
|
||||||
Extra: map[string]any{
|
Extra: map[string]any{
|
||||||
modelRateLimitsKey: map[string]any{
|
modelRateLimitsKey: map[string]any{
|
||||||
"claude-sonnet-4-5": map[string]any{
|
"claude-sonnet-4-5": map[string]any{
|
||||||
"rate_limit_reset_at": future15m, // 15 分钟
|
"rate_limit_reset_at": future15m,
|
||||||
},
|
|
||||||
},
|
|
||||||
antigravityQuotaScopesKey: map[string]any{
|
|
||||||
"claude": map[string]any{
|
|
||||||
"rate_limit_reset_at": future5m, // 5 分钟
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
requestedModel: "claude-sonnet-4-5",
|
requestedModel: "claude-sonnet-4-5",
|
||||||
minExpected: 14 * time.Minute, // 应返回较大的 15 分钟
|
minExpected: 14 * time.Minute,
|
||||||
maxExpected: 16 * time.Minute,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "scope remaining > model remaining - returns scope",
|
|
||||||
account: &Account{
|
|
||||||
Platform: PlatformAntigravity,
|
|
||||||
Extra: map[string]any{
|
|
||||||
modelRateLimitsKey: map[string]any{
|
|
||||||
"claude-sonnet-4-5": map[string]any{
|
|
||||||
"rate_limit_reset_at": future5m, // 5 分钟
|
|
||||||
},
|
|
||||||
},
|
|
||||||
antigravityQuotaScopesKey: map[string]any{
|
|
||||||
"claude": map[string]any{
|
|
||||||
"rate_limit_reset_at": future15m, // 15 分钟
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
requestedModel: "claude-sonnet-4-5",
|
|
||||||
minExpected: 14 * time.Minute, // 应返回较大的 15 分钟
|
|
||||||
maxExpected: 16 * time.Minute,
|
maxExpected: 16 * time.Minute,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -499,22 +369,6 @@ func TestGetRateLimitRemainingTime(t *testing.T) {
|
|||||||
minExpected: 4 * time.Minute,
|
minExpected: 4 * time.Minute,
|
||||||
maxExpected: 6 * time.Minute,
|
maxExpected: 6 * time.Minute,
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "only scope rate limited",
|
|
||||||
account: &Account{
|
|
||||||
Platform: PlatformAntigravity,
|
|
||||||
Extra: map[string]any{
|
|
||||||
antigravityQuotaScopesKey: map[string]any{
|
|
||||||
"claude": map[string]any{
|
|
||||||
"rate_limit_reset_at": future5m,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
requestedModel: "claude-sonnet-4-5",
|
|
||||||
minExpected: 4 * time.Minute,
|
|
||||||
maxExpected: 6 * time.Minute,
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
name: "neither rate limited",
|
name: "neither rate limited",
|
||||||
account: &Account{
|
account: &Account{
|
||||||
|
|||||||
@@ -580,10 +580,6 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
type accountWithLoad struct {
|
|
||||||
account *Account
|
|
||||||
loadInfo *AccountLoadInfo
|
|
||||||
}
|
|
||||||
var available []accountWithLoad
|
var available []accountWithLoad
|
||||||
for _, acc := range candidates {
|
for _, acc := range candidates {
|
||||||
loadInfo := loadMap[acc.ID]
|
loadInfo := loadMap[acc.ID]
|
||||||
@@ -618,6 +614,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
|||||||
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
|
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
shuffleWithinSortGroups(available)
|
||||||
|
|
||||||
for _, item := range available {
|
for _, item := range available {
|
||||||
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
||||||
|
|||||||
@@ -204,30 +204,6 @@ func (c *stubGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID i
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *stubGatewayCache) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *stubGatewayCache) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *stubGatewayCache) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
|
||||||
return "", 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *stubGatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *stubGatewayCache) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
|
||||||
return "", 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *stubGatewayCache) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
|
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
resetAt := now.Add(10 * time.Minute)
|
resetAt := now.Add(10 * time.Minute)
|
||||||
|
|||||||
@@ -66,7 +66,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
|
|||||||
}
|
}
|
||||||
|
|
||||||
isAvailable := acc.Status == StatusActive && acc.Schedulable && !isRateLimited && !isOverloaded && !isTempUnsched
|
isAvailable := acc.Status == StatusActive && acc.Schedulable && !isRateLimited && !isOverloaded && !isTempUnsched
|
||||||
scopeRateLimits := acc.GetAntigravityScopeRateLimits()
|
|
||||||
|
|
||||||
if acc.Platform != "" {
|
if acc.Platform != "" {
|
||||||
if _, ok := platform[acc.Platform]; !ok {
|
if _, ok := platform[acc.Platform]; !ok {
|
||||||
@@ -85,14 +84,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
|
|||||||
if hasError {
|
if hasError {
|
||||||
p.ErrorCount++
|
p.ErrorCount++
|
||||||
}
|
}
|
||||||
if len(scopeRateLimits) > 0 {
|
|
||||||
if p.ScopeRateLimitCount == nil {
|
|
||||||
p.ScopeRateLimitCount = make(map[string]int64)
|
|
||||||
}
|
|
||||||
for scope := range scopeRateLimits {
|
|
||||||
p.ScopeRateLimitCount[scope]++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, grp := range acc.Groups {
|
for _, grp := range acc.Groups {
|
||||||
@@ -117,14 +108,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
|
|||||||
if hasError {
|
if hasError {
|
||||||
g.ErrorCount++
|
g.ErrorCount++
|
||||||
}
|
}
|
||||||
if len(scopeRateLimits) > 0 {
|
|
||||||
if g.ScopeRateLimitCount == nil {
|
|
||||||
g.ScopeRateLimitCount = make(map[string]int64)
|
|
||||||
}
|
|
||||||
for scope := range scopeRateLimits {
|
|
||||||
g.ScopeRateLimitCount[scope]++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
displayGroupID := int64(0)
|
displayGroupID := int64(0)
|
||||||
@@ -157,9 +140,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
|
|||||||
item.RateLimitRemainingSec = &remainingSec
|
item.RateLimitRemainingSec = &remainingSec
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(scopeRateLimits) > 0 {
|
|
||||||
item.ScopeRateLimits = scopeRateLimits
|
|
||||||
}
|
|
||||||
if isOverloaded && acc.OverloadUntil != nil {
|
if isOverloaded && acc.OverloadUntil != nil {
|
||||||
item.OverloadUntil = acc.OverloadUntil
|
item.OverloadUntil = acc.OverloadUntil
|
||||||
remainingSec := int64(time.Until(*acc.OverloadUntil).Seconds())
|
remainingSec := int64(time.Until(*acc.OverloadUntil).Seconds())
|
||||||
|
|||||||
@@ -50,24 +50,22 @@ type UserConcurrencyInfo struct {
|
|||||||
|
|
||||||
// PlatformAvailability aggregates account availability by platform.
|
// PlatformAvailability aggregates account availability by platform.
|
||||||
type PlatformAvailability struct {
|
type PlatformAvailability struct {
|
||||||
Platform string `json:"platform"`
|
Platform string `json:"platform"`
|
||||||
TotalAccounts int64 `json:"total_accounts"`
|
TotalAccounts int64 `json:"total_accounts"`
|
||||||
AvailableCount int64 `json:"available_count"`
|
AvailableCount int64 `json:"available_count"`
|
||||||
RateLimitCount int64 `json:"rate_limit_count"`
|
RateLimitCount int64 `json:"rate_limit_count"`
|
||||||
ScopeRateLimitCount map[string]int64 `json:"scope_rate_limit_count,omitempty"`
|
ErrorCount int64 `json:"error_count"`
|
||||||
ErrorCount int64 `json:"error_count"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GroupAvailability aggregates account availability by group.
|
// GroupAvailability aggregates account availability by group.
|
||||||
type GroupAvailability struct {
|
type GroupAvailability struct {
|
||||||
GroupID int64 `json:"group_id"`
|
GroupID int64 `json:"group_id"`
|
||||||
GroupName string `json:"group_name"`
|
GroupName string `json:"group_name"`
|
||||||
Platform string `json:"platform"`
|
Platform string `json:"platform"`
|
||||||
TotalAccounts int64 `json:"total_accounts"`
|
TotalAccounts int64 `json:"total_accounts"`
|
||||||
AvailableCount int64 `json:"available_count"`
|
AvailableCount int64 `json:"available_count"`
|
||||||
RateLimitCount int64 `json:"rate_limit_count"`
|
RateLimitCount int64 `json:"rate_limit_count"`
|
||||||
ScopeRateLimitCount map[string]int64 `json:"scope_rate_limit_count,omitempty"`
|
ErrorCount int64 `json:"error_count"`
|
||||||
ErrorCount int64 `json:"error_count"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AccountAvailability represents current availability for a single account.
|
// AccountAvailability represents current availability for a single account.
|
||||||
@@ -85,11 +83,10 @@ type AccountAvailability struct {
|
|||||||
IsOverloaded bool `json:"is_overloaded"`
|
IsOverloaded bool `json:"is_overloaded"`
|
||||||
HasError bool `json:"has_error"`
|
HasError bool `json:"has_error"`
|
||||||
|
|
||||||
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
|
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
|
||||||
RateLimitRemainingSec *int64 `json:"rate_limit_remaining_sec"`
|
RateLimitRemainingSec *int64 `json:"rate_limit_remaining_sec"`
|
||||||
ScopeRateLimits map[string]int64 `json:"scope_rate_limits,omitempty"`
|
OverloadUntil *time.Time `json:"overload_until"`
|
||||||
OverloadUntil *time.Time `json:"overload_until"`
|
OverloadRemainingSec *int64 `json:"overload_remaining_sec"`
|
||||||
OverloadRemainingSec *int64 `json:"overload_remaining_sec"`
|
ErrorMessage string `json:"error_message"`
|
||||||
ErrorMessage string `json:"error_message"`
|
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until,omitempty"`
|
||||||
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until,omitempty"`
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -528,7 +529,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry
|
|||||||
func extractRetryModelAndStream(reqType opsRetryRequestType, errorLog *OpsErrorLogDetail, body []byte) (model string, stream bool, err error) {
|
func extractRetryModelAndStream(reqType opsRetryRequestType, errorLog *OpsErrorLogDetail, body []byte) (model string, stream bool, err error) {
|
||||||
switch reqType {
|
switch reqType {
|
||||||
case opsRetryTypeMessages:
|
case opsRetryTypeMessages:
|
||||||
parsed, parseErr := ParseGatewayRequest(body)
|
parsed, parseErr := ParseGatewayRequest(body, domain.PlatformAnthropic)
|
||||||
if parseErr != nil {
|
if parseErr != nil {
|
||||||
return "", false, fmt.Errorf("failed to parse messages request body: %w", parseErr)
|
return "", false, fmt.Errorf("failed to parse messages request body: %w", parseErr)
|
||||||
}
|
}
|
||||||
@@ -596,7 +597,7 @@ func (s *OpsService) executeWithAccount(ctx context.Context, reqType opsRetryReq
|
|||||||
if s.gatewayService == nil {
|
if s.gatewayService == nil {
|
||||||
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "gateway service not available"}
|
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "gateway service not available"}
|
||||||
}
|
}
|
||||||
parsedReq, parseErr := ParseGatewayRequest(body)
|
parsedReq, parseErr := ParseGatewayRequest(body, domain.PlatformAnthropic)
|
||||||
if parseErr != nil {
|
if parseErr != nil {
|
||||||
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "failed to parse request body"}
|
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "failed to parse request body"}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -62,6 +62,32 @@ func (s *RateLimitService) SetTokenCacheInvalidator(invalidator TokenCacheInvali
|
|||||||
s.tokenCacheInvalidator = invalidator
|
s.tokenCacheInvalidator = invalidator
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ErrorPolicyResult 表示错误策略检查的结果
|
||||||
|
type ErrorPolicyResult int
|
||||||
|
|
||||||
|
const (
|
||||||
|
ErrorPolicyNone ErrorPolicyResult = iota // 未命中任何策略,继续默认逻辑
|
||||||
|
ErrorPolicySkipped // 自定义错误码开启但未命中,跳过处理
|
||||||
|
ErrorPolicyMatched // 自定义错误码命中,应停止调度
|
||||||
|
ErrorPolicyTempUnscheduled // 临时不可调度规则命中
|
||||||
|
)
|
||||||
|
|
||||||
|
// CheckErrorPolicy 检查自定义错误码和临时不可调度规则。
|
||||||
|
// 自定义错误码开启时覆盖后续所有逻辑(包括临时不可调度)。
|
||||||
|
func (s *RateLimitService) CheckErrorPolicy(ctx context.Context, account *Account, statusCode int, responseBody []byte) ErrorPolicyResult {
|
||||||
|
if account.IsCustomErrorCodesEnabled() {
|
||||||
|
if account.ShouldHandleErrorCode(statusCode) {
|
||||||
|
return ErrorPolicyMatched
|
||||||
|
}
|
||||||
|
slog.Info("account_error_code_skipped", "account_id", account.ID, "status_code", statusCode)
|
||||||
|
return ErrorPolicySkipped
|
||||||
|
}
|
||||||
|
if s.tryTempUnschedulable(ctx, account, statusCode, responseBody) {
|
||||||
|
return ErrorPolicyTempUnscheduled
|
||||||
|
}
|
||||||
|
return ErrorPolicyNone
|
||||||
|
}
|
||||||
|
|
||||||
// HandleUpstreamError 处理上游错误响应,标记账号状态
|
// HandleUpstreamError 处理上游错误响应,标记账号状态
|
||||||
// 返回是否应该停止该账号的调度
|
// 返回是否应该停止该账号的调度
|
||||||
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {
|
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {
|
||||||
|
|||||||
318
backend/internal/service/scheduler_shuffle_test.go
Normal file
318
backend/internal/service/scheduler_shuffle_test.go
Normal file
@@ -0,0 +1,318 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ============ shuffleWithinSortGroups 测试 ============
|
||||||
|
|
||||||
|
func TestShuffleWithinSortGroups_Empty(t *testing.T) {
|
||||||
|
shuffleWithinSortGroups(nil)
|
||||||
|
shuffleWithinSortGroups([]accountWithLoad{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShuffleWithinSortGroups_SingleElement(t *testing.T) {
|
||||||
|
accounts := []accountWithLoad{
|
||||||
|
{account: &Account{ID: 1, Priority: 1}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
|
||||||
|
}
|
||||||
|
shuffleWithinSortGroups(accounts)
|
||||||
|
require.Equal(t, int64(1), accounts[0].account.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShuffleWithinSortGroups_DifferentGroups_OrderPreserved(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
earlier := now.Add(-1 * time.Hour)
|
||||||
|
|
||||||
|
accounts := []accountWithLoad{
|
||||||
|
{account: &Account{ID: 1, Priority: 1, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
|
||||||
|
{account: &Account{ID: 2, Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 20}},
|
||||||
|
{account: &Account{ID: 3, Priority: 2, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
|
||||||
|
}
|
||||||
|
|
||||||
|
// 每个元素都属于不同组(Priority 或 LoadRate 或 LastUsedAt 不同),顺序不变
|
||||||
|
for i := 0; i < 20; i++ {
|
||||||
|
cpy := make([]accountWithLoad, len(accounts))
|
||||||
|
copy(cpy, accounts)
|
||||||
|
shuffleWithinSortGroups(cpy)
|
||||||
|
require.Equal(t, int64(1), cpy[0].account.ID)
|
||||||
|
require.Equal(t, int64(2), cpy[1].account.ID)
|
||||||
|
require.Equal(t, int64(3), cpy[2].account.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShuffleWithinSortGroups_SameGroup_Shuffled(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
// 同一秒的时间戳视为同一组
|
||||||
|
sameSecond := time.Unix(now.Unix(), 0)
|
||||||
|
sameSecond2 := time.Unix(now.Unix(), 500_000_000) // 同一秒但不同纳秒
|
||||||
|
|
||||||
|
accounts := []accountWithLoad{
|
||||||
|
{account: &Account{ID: 1, Priority: 1, LastUsedAt: &sameSecond}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
|
||||||
|
{account: &Account{ID: 2, Priority: 1, LastUsedAt: &sameSecond2}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
|
||||||
|
{account: &Account{ID: 3, Priority: 1, LastUsedAt: &sameSecond}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
|
||||||
|
}
|
||||||
|
|
||||||
|
// 多次执行,验证所有 ID 都出现在第一个位置(说明确实被打乱了)
|
||||||
|
seen := map[int64]bool{}
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
cpy := make([]accountWithLoad, len(accounts))
|
||||||
|
copy(cpy, accounts)
|
||||||
|
shuffleWithinSortGroups(cpy)
|
||||||
|
seen[cpy[0].account.ID] = true
|
||||||
|
// 无论怎么打乱,所有 ID 都应在候选中
|
||||||
|
ids := map[int64]bool{}
|
||||||
|
for _, a := range cpy {
|
||||||
|
ids[a.account.ID] = true
|
||||||
|
}
|
||||||
|
require.True(t, ids[1] && ids[2] && ids[3])
|
||||||
|
}
|
||||||
|
// 至少 2 个不同的 ID 出现在首位(随机性验证)
|
||||||
|
require.GreaterOrEqual(t, len(seen), 2, "shuffle should produce different orderings")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShuffleWithinSortGroups_NilLastUsedAt_SameGroup(t *testing.T) {
|
||||||
|
accounts := []accountWithLoad{
|
||||||
|
{account: &Account{ID: 1, Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}},
|
||||||
|
{account: &Account{ID: 2, Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}},
|
||||||
|
{account: &Account{ID: 3, Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}},
|
||||||
|
}
|
||||||
|
|
||||||
|
seen := map[int64]bool{}
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
cpy := make([]accountWithLoad, len(accounts))
|
||||||
|
copy(cpy, accounts)
|
||||||
|
shuffleWithinSortGroups(cpy)
|
||||||
|
seen[cpy[0].account.ID] = true
|
||||||
|
}
|
||||||
|
require.GreaterOrEqual(t, len(seen), 2, "nil LastUsedAt accounts should be shuffled")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShuffleWithinSortGroups_MixedGroups(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
earlier := now.Add(-1 * time.Hour)
|
||||||
|
sameAsNow := time.Unix(now.Unix(), 0)
|
||||||
|
|
||||||
|
// 组1: Priority=1, LoadRate=10, LastUsedAt=earlier (ID 1) — 单元素组
|
||||||
|
// 组2: Priority=1, LoadRate=20, LastUsedAt=now (ID 2, 3) — 双元素组
|
||||||
|
// 组3: Priority=2, LoadRate=10, LastUsedAt=earlier (ID 4) — 单元素组
|
||||||
|
accounts := []accountWithLoad{
|
||||||
|
{account: &Account{ID: 1, Priority: 1, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
|
||||||
|
{account: &Account{ID: 2, Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 20}},
|
||||||
|
{account: &Account{ID: 3, Priority: 1, LastUsedAt: &sameAsNow}, loadInfo: &AccountLoadInfo{LoadRate: 20}},
|
||||||
|
{account: &Account{ID: 4, Priority: 2, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < 20; i++ {
|
||||||
|
cpy := make([]accountWithLoad, len(accounts))
|
||||||
|
copy(cpy, accounts)
|
||||||
|
shuffleWithinSortGroups(cpy)
|
||||||
|
|
||||||
|
// 组间顺序不变
|
||||||
|
require.Equal(t, int64(1), cpy[0].account.ID, "group 1 position fixed")
|
||||||
|
require.Equal(t, int64(4), cpy[3].account.ID, "group 3 position fixed")
|
||||||
|
|
||||||
|
// 组2 内部可以打乱,但仍在位置 1 和 2
|
||||||
|
mid := map[int64]bool{cpy[1].account.ID: true, cpy[2].account.ID: true}
|
||||||
|
require.True(t, mid[2] && mid[3], "group 2 elements should stay in positions 1-2")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============ shuffleWithinPriorityAndLastUsed 测试 ============
|
||||||
|
|
||||||
|
func TestShuffleWithinPriorityAndLastUsed_Empty(t *testing.T) {
|
||||||
|
shuffleWithinPriorityAndLastUsed(nil)
|
||||||
|
shuffleWithinPriorityAndLastUsed([]*Account{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShuffleWithinPriorityAndLastUsed_SingleElement(t *testing.T) {
|
||||||
|
accounts := []*Account{{ID: 1, Priority: 1}}
|
||||||
|
shuffleWithinPriorityAndLastUsed(accounts)
|
||||||
|
require.Equal(t, int64(1), accounts[0].ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShuffleWithinPriorityAndLastUsed_SameGroup_Shuffled(t *testing.T) {
|
||||||
|
accounts := []*Account{
|
||||||
|
{ID: 1, Priority: 1, LastUsedAt: nil},
|
||||||
|
{ID: 2, Priority: 1, LastUsedAt: nil},
|
||||||
|
{ID: 3, Priority: 1, LastUsedAt: nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
seen := map[int64]bool{}
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
cpy := make([]*Account, len(accounts))
|
||||||
|
copy(cpy, accounts)
|
||||||
|
shuffleWithinPriorityAndLastUsed(cpy)
|
||||||
|
seen[cpy[0].ID] = true
|
||||||
|
}
|
||||||
|
require.GreaterOrEqual(t, len(seen), 2, "same group should be shuffled")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShuffleWithinPriorityAndLastUsed_DifferentPriority_OrderPreserved(t *testing.T) {
|
||||||
|
accounts := []*Account{
|
||||||
|
{ID: 1, Priority: 1, LastUsedAt: nil},
|
||||||
|
{ID: 2, Priority: 2, LastUsedAt: nil},
|
||||||
|
{ID: 3, Priority: 3, LastUsedAt: nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < 20; i++ {
|
||||||
|
cpy := make([]*Account, len(accounts))
|
||||||
|
copy(cpy, accounts)
|
||||||
|
shuffleWithinPriorityAndLastUsed(cpy)
|
||||||
|
require.Equal(t, int64(1), cpy[0].ID)
|
||||||
|
require.Equal(t, int64(2), cpy[1].ID)
|
||||||
|
require.Equal(t, int64(3), cpy[2].ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShuffleWithinPriorityAndLastUsed_DifferentLastUsedAt_OrderPreserved(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
earlier := now.Add(-1 * time.Hour)
|
||||||
|
|
||||||
|
accounts := []*Account{
|
||||||
|
{ID: 1, Priority: 1, LastUsedAt: nil},
|
||||||
|
{ID: 2, Priority: 1, LastUsedAt: &earlier},
|
||||||
|
{ID: 3, Priority: 1, LastUsedAt: &now},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < 20; i++ {
|
||||||
|
cpy := make([]*Account, len(accounts))
|
||||||
|
copy(cpy, accounts)
|
||||||
|
shuffleWithinPriorityAndLastUsed(cpy)
|
||||||
|
require.Equal(t, int64(1), cpy[0].ID)
|
||||||
|
require.Equal(t, int64(2), cpy[1].ID)
|
||||||
|
require.Equal(t, int64(3), cpy[2].ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============ sameLastUsedAt 测试 ============
|
||||||
|
|
||||||
|
func TestSameLastUsedAt(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
sameSecond := time.Unix(now.Unix(), 0)
|
||||||
|
sameSecondDiffNano := time.Unix(now.Unix(), 999_999_999)
|
||||||
|
differentSecond := now.Add(1 * time.Second)
|
||||||
|
|
||||||
|
t.Run("both nil", func(t *testing.T) {
|
||||||
|
require.True(t, sameLastUsedAt(nil, nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("one nil one not", func(t *testing.T) {
|
||||||
|
require.False(t, sameLastUsedAt(nil, &now))
|
||||||
|
require.False(t, sameLastUsedAt(&now, nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("same second different nanoseconds", func(t *testing.T) {
|
||||||
|
require.True(t, sameLastUsedAt(&sameSecond, &sameSecondDiffNano))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("different seconds", func(t *testing.T) {
|
||||||
|
require.False(t, sameLastUsedAt(&now, &differentSecond))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("exact same time", func(t *testing.T) {
|
||||||
|
require.True(t, sameLastUsedAt(&now, &now))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============ sameAccountWithLoadGroup 测试 ============
|
||||||
|
|
||||||
|
func TestSameAccountWithLoadGroup(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
sameSecond := time.Unix(now.Unix(), 0)
|
||||||
|
|
||||||
|
t.Run("same group", func(t *testing.T) {
|
||||||
|
a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
|
||||||
|
b := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &sameSecond}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
|
||||||
|
require.True(t, sameAccountWithLoadGroup(a, b))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("different priority", func(t *testing.T) {
|
||||||
|
a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
|
||||||
|
b := accountWithLoad{account: &Account{Priority: 2, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
|
||||||
|
require.False(t, sameAccountWithLoadGroup(a, b))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("different load rate", func(t *testing.T) {
|
||||||
|
a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
|
||||||
|
b := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 20}}
|
||||||
|
require.False(t, sameAccountWithLoadGroup(a, b))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("different last used at", func(t *testing.T) {
|
||||||
|
later := now.Add(1 * time.Second)
|
||||||
|
a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
|
||||||
|
b := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &later}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
|
||||||
|
require.False(t, sameAccountWithLoadGroup(a, b))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("both nil LastUsedAt", func(t *testing.T) {
|
||||||
|
a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}}
|
||||||
|
b := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}}
|
||||||
|
require.True(t, sameAccountWithLoadGroup(a, b))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============ sameAccountGroup 测试 ============
|
||||||
|
|
||||||
|
func TestSameAccountGroup(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
t.Run("same group", func(t *testing.T) {
|
||||||
|
a := &Account{Priority: 1, LastUsedAt: nil}
|
||||||
|
b := &Account{Priority: 1, LastUsedAt: nil}
|
||||||
|
require.True(t, sameAccountGroup(a, b))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("different priority", func(t *testing.T) {
|
||||||
|
a := &Account{Priority: 1, LastUsedAt: nil}
|
||||||
|
b := &Account{Priority: 2, LastUsedAt: nil}
|
||||||
|
require.False(t, sameAccountGroup(a, b))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("different LastUsedAt", func(t *testing.T) {
|
||||||
|
later := now.Add(1 * time.Second)
|
||||||
|
a := &Account{Priority: 1, LastUsedAt: &now}
|
||||||
|
b := &Account{Priority: 1, LastUsedAt: &later}
|
||||||
|
require.False(t, sameAccountGroup(a, b))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============ sortAccountsByPriorityAndLastUsed 集成随机化测试 ============
|
||||||
|
|
||||||
|
func TestSortAccountsByPriorityAndLastUsed_WithShuffle(t *testing.T) {
|
||||||
|
t.Run("same priority and nil LastUsedAt are shuffled", func(t *testing.T) {
|
||||||
|
accounts := []*Account{
|
||||||
|
{ID: 1, Priority: 1, LastUsedAt: nil},
|
||||||
|
{ID: 2, Priority: 1, LastUsedAt: nil},
|
||||||
|
{ID: 3, Priority: 1, LastUsedAt: nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
seen := map[int64]bool{}
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
cpy := make([]*Account, len(accounts))
|
||||||
|
copy(cpy, accounts)
|
||||||
|
sortAccountsByPriorityAndLastUsed(cpy, false)
|
||||||
|
seen[cpy[0].ID] = true
|
||||||
|
}
|
||||||
|
require.GreaterOrEqual(t, len(seen), 2, "identical sort keys should produce different orderings after shuffle")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("different priorities still sorted correctly", func(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
accounts := []*Account{
|
||||||
|
{ID: 3, Priority: 3, LastUsedAt: &now},
|
||||||
|
{ID: 1, Priority: 1, LastUsedAt: &now},
|
||||||
|
{ID: 2, Priority: 2, LastUsedAt: &now},
|
||||||
|
}
|
||||||
|
|
||||||
|
sortAccountsByPriorityAndLastUsed(accounts, false)
|
||||||
|
require.Equal(t, int64(1), accounts[0].ID)
|
||||||
|
require.Equal(t, int64(2), accounts[1].ID)
|
||||||
|
require.Equal(t, int64(3), accounts[2].ID)
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -275,4 +275,5 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewUsageCache,
|
NewUsageCache,
|
||||||
NewTotpService,
|
NewTotpService,
|
||||||
NewErrorPassthroughService,
|
NewErrorPassthroughService,
|
||||||
|
NewDigestSessionStore,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -47,13 +47,15 @@ services:
|
|||||||
|
|
||||||
# =======================================================================
|
# =======================================================================
|
||||||
# Database Configuration (PostgreSQL)
|
# Database Configuration (PostgreSQL)
|
||||||
|
# Default: uses local postgres container
|
||||||
|
# External DB: set DATABASE_HOST and DATABASE_SSLMODE in .env
|
||||||
# =======================================================================
|
# =======================================================================
|
||||||
- DATABASE_HOST=postgres
|
- DATABASE_HOST=${DATABASE_HOST:-postgres}
|
||||||
- DATABASE_PORT=5432
|
- DATABASE_PORT=${DATABASE_PORT:-5432}
|
||||||
- DATABASE_USER=${POSTGRES_USER:-sub2api}
|
- DATABASE_USER=${POSTGRES_USER:-sub2api}
|
||||||
- DATABASE_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
|
- DATABASE_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
|
||||||
- DATABASE_DBNAME=${POSTGRES_DB:-sub2api}
|
- DATABASE_DBNAME=${POSTGRES_DB:-sub2api}
|
||||||
- DATABASE_SSLMODE=disable
|
- DATABASE_SSLMODE=${DATABASE_SSLMODE:-disable}
|
||||||
|
|
||||||
# =======================================================================
|
# =======================================================================
|
||||||
# Redis Configuration
|
# Redis Configuration
|
||||||
@@ -128,8 +130,6 @@ services:
|
|||||||
# Examples: http://host:port, socks5://host:port
|
# Examples: http://host:port, socks5://host:port
|
||||||
- UPDATE_PROXY_URL=${UPDATE_PROXY_URL:-}
|
- UPDATE_PROXY_URL=${UPDATE_PROXY_URL:-}
|
||||||
depends_on:
|
depends_on:
|
||||||
postgres:
|
|
||||||
condition: service_healthy
|
|
||||||
redis:
|
redis:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
networks:
|
networks:
|
||||||
@@ -141,35 +141,6 @@ services:
|
|||||||
retries: 3
|
retries: 3
|
||||||
start_period: 30s
|
start_period: 30s
|
||||||
|
|
||||||
# ===========================================================================
|
|
||||||
# PostgreSQL Database
|
|
||||||
# ===========================================================================
|
|
||||||
postgres:
|
|
||||||
image: postgres:18-alpine
|
|
||||||
container_name: sub2api-postgres
|
|
||||||
restart: unless-stopped
|
|
||||||
ulimits:
|
|
||||||
nofile:
|
|
||||||
soft: 100000
|
|
||||||
hard: 100000
|
|
||||||
volumes:
|
|
||||||
- postgres_data:/var/lib/postgresql/data
|
|
||||||
environment:
|
|
||||||
- POSTGRES_USER=${POSTGRES_USER:-sub2api}
|
|
||||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
|
|
||||||
- POSTGRES_DB=${POSTGRES_DB:-sub2api}
|
|
||||||
- TZ=${TZ:-Asia/Shanghai}
|
|
||||||
networks:
|
|
||||||
- sub2api-network
|
|
||||||
healthcheck:
|
|
||||||
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-sub2api} -d ${POSTGRES_DB:-sub2api}"]
|
|
||||||
interval: 10s
|
|
||||||
timeout: 5s
|
|
||||||
retries: 5
|
|
||||||
start_period: 10s
|
|
||||||
# 注意:不暴露端口到宿主机,应用通过内部网络连接
|
|
||||||
# 如需调试,可临时添加:ports: ["127.0.0.1:5433:5432"]
|
|
||||||
|
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
# Redis Cache
|
# Redis Cache
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
@@ -209,8 +180,6 @@ services:
|
|||||||
volumes:
|
volumes:
|
||||||
sub2api_data:
|
sub2api_data:
|
||||||
driver: local
|
driver: local
|
||||||
postgres_data:
|
|
||||||
driver: local
|
|
||||||
redis_data:
|
redis_data:
|
||||||
driver: local
|
driver: local
|
||||||
|
|
||||||
|
|||||||
BIN
frontend/public/wechat-qr.jpg
Normal file
BIN
frontend/public/wechat-qr.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 148 KiB |
@@ -376,7 +376,6 @@ export interface PlatformAvailability {
|
|||||||
total_accounts: number
|
total_accounts: number
|
||||||
available_count: number
|
available_count: number
|
||||||
rate_limit_count: number
|
rate_limit_count: number
|
||||||
scope_rate_limit_count?: Record<string, number>
|
|
||||||
error_count: number
|
error_count: number
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -387,7 +386,6 @@ export interface GroupAvailability {
|
|||||||
total_accounts: number
|
total_accounts: number
|
||||||
available_count: number
|
available_count: number
|
||||||
rate_limit_count: number
|
rate_limit_count: number
|
||||||
scope_rate_limit_count?: Record<string, number>
|
|
||||||
error_count: number
|
error_count: number
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -402,7 +400,6 @@ export interface AccountAvailability {
|
|||||||
is_rate_limited: boolean
|
is_rate_limited: boolean
|
||||||
rate_limit_reset_at?: string
|
rate_limit_reset_at?: string
|
||||||
rate_limit_remaining_sec?: number
|
rate_limit_remaining_sec?: number
|
||||||
scope_rate_limits?: Record<string, number>
|
|
||||||
is_overloaded: boolean
|
is_overloaded: boolean
|
||||||
overload_until?: string
|
overload_until?: string
|
||||||
overload_remaining_sec?: number
|
overload_remaining_sec?: number
|
||||||
|
|||||||
@@ -76,26 +76,6 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Scope Rate Limit Indicators (Antigravity) -->
|
|
||||||
<template v-if="activeScopeRateLimits.length > 0">
|
|
||||||
<div v-for="item in activeScopeRateLimits" :key="item.scope" class="group relative">
|
|
||||||
<span
|
|
||||||
class="inline-flex items-center gap-1 rounded bg-orange-100 px-1.5 py-0.5 text-xs font-medium text-orange-700 dark:bg-orange-900/30 dark:text-orange-400"
|
|
||||||
>
|
|
||||||
<Icon name="exclamationTriangle" size="xs" :stroke-width="2" />
|
|
||||||
{{ formatScopeName(item.scope) }}
|
|
||||||
</span>
|
|
||||||
<!-- Tooltip -->
|
|
||||||
<div
|
|
||||||
class="pointer-events-none absolute bottom-full left-1/2 z-50 mb-2 -translate-x-1/2 whitespace-nowrap rounded bg-gray-900 px-2 py-1 text-xs text-white opacity-0 transition-opacity group-hover:opacity-100 dark:bg-gray-700"
|
|
||||||
>
|
|
||||||
{{ t('admin.accounts.status.scopeRateLimitedUntil', { scope: formatScopeName(item.scope), time: formatTime(item.reset_at) }) }}
|
|
||||||
<div
|
|
||||||
class="absolute left-1/2 top-full -translate-x-1/2 border-4 border-transparent border-t-gray-900 dark:border-t-gray-700" ></div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</template>
|
|
||||||
|
|
||||||
<!-- Model Rate Limit Indicators (Antigravity OAuth Smart Retry) -->
|
<!-- Model Rate Limit Indicators (Antigravity OAuth Smart Retry) -->
|
||||||
<template v-if="activeModelRateLimits.length > 0">
|
<template v-if="activeModelRateLimits.length > 0">
|
||||||
<div v-for="item in activeModelRateLimits" :key="item.model" class="group relative">
|
<div v-for="item in activeModelRateLimits" :key="item.model" class="group relative">
|
||||||
@@ -160,16 +140,6 @@ const isRateLimited = computed(() => {
|
|||||||
return new Date(props.account.rate_limit_reset_at) > new Date()
|
return new Date(props.account.rate_limit_reset_at) > new Date()
|
||||||
})
|
})
|
||||||
|
|
||||||
// Computed: active scope rate limits (Antigravity)
|
|
||||||
const activeScopeRateLimits = computed(() => {
|
|
||||||
const scopeLimits = props.account.scope_rate_limits
|
|
||||||
if (!scopeLimits) return []
|
|
||||||
const now = new Date()
|
|
||||||
return Object.entries(scopeLimits)
|
|
||||||
.filter(([, info]) => new Date(info.reset_at) > now)
|
|
||||||
.map(([scope, info]) => ({ scope, reset_at: info.reset_at }))
|
|
||||||
})
|
|
||||||
|
|
||||||
// Computed: active model rate limits (Antigravity OAuth Smart Retry)
|
// Computed: active model rate limits (Antigravity OAuth Smart Retry)
|
||||||
const activeModelRateLimits = computed(() => {
|
const activeModelRateLimits = computed(() => {
|
||||||
const modelLimits = (props.account.extra as Record<string, unknown> | undefined)?.model_rate_limits as
|
const modelLimits = (props.account.extra as Record<string, unknown> | undefined)?.model_rate_limits as
|
||||||
|
|||||||
@@ -1038,10 +1038,7 @@
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Custom Error Codes Section -->
|
<!-- Custom Error Codes Section -->
|
||||||
<div
|
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||||
v-if="form.platform !== 'gemini'"
|
|
||||||
class="border-t border-gray-200 pt-4 dark:border-dark-600"
|
|
||||||
>
|
|
||||||
<div class="mb-3 flex items-center justify-between">
|
<div class="mb-3 flex items-center justify-between">
|
||||||
<div>
|
<div>
|
||||||
<label class="input-label mb-0">{{ t('admin.accounts.customErrorCodes') }}</label>
|
<label class="input-label mb-0">{{ t('admin.accounts.customErrorCodes') }}</label>
|
||||||
|
|||||||
104
frontend/src/components/common/WechatServiceButton.vue
Normal file
104
frontend/src/components/common/WechatServiceButton.vue
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
<template>
|
||||||
|
<!-- 悬浮按钮 - 使用主题色 -->
|
||||||
|
<button
|
||||||
|
@click="showModal = true"
|
||||||
|
class="fixed bottom-6 right-6 z-50 flex items-center gap-2 rounded-full bg-gradient-to-r from-primary-500 to-primary-600 px-4 py-3 text-white shadow-lg shadow-primary-500/25 transition-all hover:from-primary-600 hover:to-primary-700 hover:shadow-xl hover:shadow-primary-500/30"
|
||||||
|
>
|
||||||
|
<svg class="h-5 w-5" viewBox="0 0 24 24" fill="currentColor">
|
||||||
|
<path d="M8.691 2.188C3.891 2.188 0 5.476 0 9.53c0 2.212 1.17 4.203 3.002 5.55a.59.59 0 01.213.665l-.39 1.48c-.019.07-.048.141-.048.213 0 .163.13.295.29.295a.328.328 0 00.186-.059l2.114-1.225a.87.87 0 01.415-.106.807.807 0 01.213.026 10.07 10.07 0 002.696.37c.262 0 .52-.011.776-.028a5.91 5.91 0 01-.193-1.479c0-3.644 3.374-6.6 7.536-6.6.262 0 .52.011.776.028-.628-3.513-4.27-6.472-8.885-6.472zM5.785 5.97a1.1 1.1 0 110 2.2 1.1 1.1 0 010-2.2zm5.813 0a1.1 1.1 0 110 2.2 1.1 1.1 0 010-2.2zm5.192 2.642c-3.703 0-6.71 2.567-6.71 5.73 0 3.163 3.007 5.73 6.71 5.73a7.9 7.9 0 002.126-.288.644.644 0 01.17-.022.69.69 0 01.329.085l1.672.97a.262.262 0 00.147.046c.128 0 .23-.104.23-.233a.403.403 0 00-.038-.168l-.309-1.17a.468.468 0 01.168-.527c1.449-1.065 2.374-2.643 2.374-4.423 0-3.163-3.007-5.73-6.71-5.73h-.159zm-2.434 3.34a.88.88 0 110 1.76.88.88 0 010-1.76zm4.868 0a.88.88 0 110 1.76.88.88 0 010-1.76z"/>
|
||||||
|
</svg>
|
||||||
|
<span class="text-sm font-medium">客服</span>
|
||||||
|
</button>
|
||||||
|
|
||||||
|
<!-- 弹窗 -->
|
||||||
|
<Teleport to="body">
|
||||||
|
<Transition name="fade">
|
||||||
|
<div
|
||||||
|
v-if="showModal"
|
||||||
|
class="fixed inset-0 z-[100] flex items-center justify-center bg-black/50 p-4 backdrop-blur-sm"
|
||||||
|
@click.self="showModal = false"
|
||||||
|
>
|
||||||
|
<Transition name="scale">
|
||||||
|
<div
|
||||||
|
v-if="showModal"
|
||||||
|
class="relative w-full max-w-sm rounded-2xl bg-white p-6 shadow-2xl dark:bg-dark-700"
|
||||||
|
>
|
||||||
|
<!-- 关闭按钮 -->
|
||||||
|
<button
|
||||||
|
@click="showModal = false"
|
||||||
|
class="absolute right-4 top-4 text-gray-400 transition-colors hover:text-gray-600 dark:text-dark-400 dark:hover:text-dark-200"
|
||||||
|
>
|
||||||
|
<svg class="h-5 w-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||||
|
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M6 18L18 6M6 6l12 12" />
|
||||||
|
</svg>
|
||||||
|
</button>
|
||||||
|
|
||||||
|
<!-- 标题 -->
|
||||||
|
<div class="mb-4 flex items-center gap-3">
|
||||||
|
<div class="flex h-10 w-10 items-center justify-center rounded-full bg-gradient-to-br from-primary-500 to-primary-600">
|
||||||
|
<svg class="h-6 w-6 text-white" viewBox="0 0 24 24" fill="currentColor">
|
||||||
|
<path d="M8.691 2.188C3.891 2.188 0 5.476 0 9.53c0 2.212 1.17 4.203 3.002 5.55a.59.59 0 01.213.665l-.39 1.48c-.019.07-.048.141-.048.213 0 .163.13.295.29.295a.328.328 0 00.186-.059l2.114-1.225a.87.87 0 01.415-.106.807.807 0 01.213.026 10.07 10.07 0 002.696.37c.262 0 .52-.011.776-.028a5.91 5.91 0 01-.193-1.479c0-3.644 3.374-6.6 7.536-6.6.262 0 .52.011.776.028-.628-3.513-4.27-6.472-8.885-6.472zM5.785 5.97a1.1 1.1 0 110 2.2 1.1 1.1 0 010-2.2zm5.813 0a1.1 1.1 0 110 2.2 1.1 1.1 0 010-2.2zm5.192 2.642c-3.703 0-6.71 2.567-6.71 5.73 0 3.163 3.007 5.73 6.71 5.73a7.9 7.9 0 002.126-.288.644.644 0 01.17-.022.69.69 0 01.329.085l1.672.97a.262.262 0 00.147.046c.128 0 .23-.104.23-.233a.403.403 0 00-.038-.168l-.309-1.17a.468.468 0 01.168-.527c1.449-1.065 2.374-2.643 2.374-4.423 0-3.163-3.007-5.73-6.71-5.73h-.159zm-2.434 3.34a.88.88 0 110 1.76.88.88 0 010-1.76zm4.868 0a.88.88 0 110 1.76.88.88 0 010-1.76z"/>
|
||||||
|
</svg>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<h3 class="text-lg font-semibold text-gray-900 dark:text-white">联系客服</h3>
|
||||||
|
<p class="text-sm text-gray-500 dark:text-dark-400">扫码添加好友</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 二维码卡片 -->
|
||||||
|
<div class="mb-4 overflow-hidden rounded-xl border border-primary-100 bg-gradient-to-br from-primary-50 to-white p-3 dark:border-primary-800/30 dark:from-primary-900/10 dark:to-dark-800">
|
||||||
|
<img
|
||||||
|
src="/wechat-qr.jpg"
|
||||||
|
alt="微信二维码"
|
||||||
|
class="w-full rounded-lg"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 提示文字 -->
|
||||||
|
<div class="text-center">
|
||||||
|
<p class="mb-2 text-sm font-medium text-primary-600 dark:text-primary-400">
|
||||||
|
微信扫码添加客服
|
||||||
|
</p>
|
||||||
|
<p class="flex items-center justify-center gap-1 text-xs text-gray-500 dark:text-dark-400">
|
||||||
|
<svg class="h-4 w-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||||
|
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 8v4l3 3m6-3a9 9 0 11-18 0 9 9 0 0118 0z" />
|
||||||
|
</svg>
|
||||||
|
工作时间:周一至周五 9:00-18:00
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</Transition>
|
||||||
|
</div>
|
||||||
|
</Transition>
|
||||||
|
</Teleport>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<script setup lang="ts">
|
||||||
|
import { ref } from 'vue'
|
||||||
|
|
||||||
|
const showModal = ref(false)
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<style scoped>
|
||||||
|
.fade-enter-active,
|
||||||
|
.fade-leave-active {
|
||||||
|
transition: opacity 0.2s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.fade-enter-from,
|
||||||
|
.fade-leave-to {
|
||||||
|
opacity: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.scale-enter-active,
|
||||||
|
.scale-leave-active {
|
||||||
|
transition: all 0.2s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.scale-enter-from,
|
||||||
|
.scale-leave-to {
|
||||||
|
opacity: 0;
|
||||||
|
transform: scale(0.95);
|
||||||
|
}
|
||||||
|
</style>
|
||||||
@@ -121,23 +121,6 @@
|
|||||||
<Icon name="key" size="sm" />
|
<Icon name="key" size="sm" />
|
||||||
{{ t('nav.apiKeys') }}
|
{{ t('nav.apiKeys') }}
|
||||||
</router-link>
|
</router-link>
|
||||||
|
|
||||||
<a
|
|
||||||
href="https://github.com/Wei-Shaw/sub2api"
|
|
||||||
target="_blank"
|
|
||||||
rel="noopener noreferrer"
|
|
||||||
@click="closeDropdown"
|
|
||||||
class="dropdown-item"
|
|
||||||
>
|
|
||||||
<svg class="h-4 w-4" fill="currentColor" viewBox="0 0 24 24">
|
|
||||||
<path
|
|
||||||
fill-rule="evenodd"
|
|
||||||
clip-rule="evenodd"
|
|
||||||
d="M12 2C6.477 2 2 6.477 2 12c0 4.42 2.865 8.17 6.839 9.49.5.092.682-.217.682-.482 0-.237-.008-.866-.013-1.7-2.782.604-3.369-1.34-3.369-1.34-.454-1.156-1.11-1.464-1.11-1.464-.908-.62.069-.608.069-.608 1.003.07 1.531 1.03 1.531 1.03.892 1.529 2.341 1.087 2.91.831.092-.646.35-1.086.636-1.336-2.22-.253-4.555-1.11-4.555-4.943 0-1.091.39-1.984 1.029-2.683-.103-.253-.446-1.27.098-2.647 0 0 .84-.269 2.75 1.025A9.578 9.578 0 0112 6.836c.85.004 1.705.114 2.504.336 1.909-1.294 2.747-1.025 2.747-1.025.546 1.377.203 2.394.1 2.647.64.699 1.028 1.592 1.028 2.683 0 3.842-2.339 4.687-4.566 4.935.359.309.678.919.678 1.852 0 1.336-.012 2.415-.012 2.743 0 .267.18.578.688.48C19.138 20.167 22 16.418 22 12c0-5.523-4.477-10-10-10z"
|
|
||||||
/>
|
|
||||||
</svg>
|
|
||||||
{{ t('nav.github') }}
|
|
||||||
</a>
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Contact Support (only show if configured) -->
|
<!-- Contact Support (only show if configured) -->
|
||||||
|
|||||||
@@ -1356,7 +1356,6 @@ export default {
|
|||||||
overloaded: 'Overloaded',
|
overloaded: 'Overloaded',
|
||||||
tempUnschedulable: 'Temp Unschedulable',
|
tempUnschedulable: 'Temp Unschedulable',
|
||||||
rateLimitedUntil: 'Rate limited until {time}',
|
rateLimitedUntil: 'Rate limited until {time}',
|
||||||
scopeRateLimitedUntil: '{scope} rate limited until {time}',
|
|
||||||
modelRateLimitedUntil: '{model} rate limited until {time}',
|
modelRateLimitedUntil: '{model} rate limited until {time}',
|
||||||
overloadedUntil: 'Overloaded until {time}',
|
overloadedUntil: 'Overloaded until {time}',
|
||||||
viewTempUnschedDetails: 'View temp unschedulable details'
|
viewTempUnschedDetails: 'View temp unschedulable details'
|
||||||
@@ -3059,7 +3058,6 @@ export default {
|
|||||||
empty: 'No data',
|
empty: 'No data',
|
||||||
queued: 'Queue {count}',
|
queued: 'Queue {count}',
|
||||||
rateLimited: 'Rate-limited {count}',
|
rateLimited: 'Rate-limited {count}',
|
||||||
scopeRateLimitedTooltip: '{scope} rate-limited ({count} accounts)',
|
|
||||||
errorAccounts: 'Errors {count}',
|
errorAccounts: 'Errors {count}',
|
||||||
loadFailed: 'Failed to load concurrency data'
|
loadFailed: 'Failed to load concurrency data'
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -1492,7 +1492,6 @@ export default {
|
|||||||
overloaded: '过载中',
|
overloaded: '过载中',
|
||||||
tempUnschedulable: '临时不可调度',
|
tempUnschedulable: '临时不可调度',
|
||||||
rateLimitedUntil: '限流中,重置时间:{time}',
|
rateLimitedUntil: '限流中,重置时间:{time}',
|
||||||
scopeRateLimitedUntil: '{scope} 限流中,重置时间:{time}',
|
|
||||||
modelRateLimitedUntil: '{model} 限流至 {time}',
|
modelRateLimitedUntil: '{model} 限流至 {time}',
|
||||||
overloadedUntil: '负载过重,重置时间:{time}',
|
overloadedUntil: '负载过重,重置时间:{time}',
|
||||||
viewTempUnschedDetails: '查看临时不可调度详情'
|
viewTempUnschedDetails: '查看临时不可调度详情'
|
||||||
@@ -3232,7 +3231,6 @@ export default {
|
|||||||
empty: '暂无数据',
|
empty: '暂无数据',
|
||||||
queued: '队列 {count}',
|
queued: '队列 {count}',
|
||||||
rateLimited: '限流 {count}',
|
rateLimited: '限流 {count}',
|
||||||
scopeRateLimitedTooltip: '{scope} 限流中 ({count} 个账号)',
|
|
||||||
errorAccounts: '异常 {count}',
|
errorAccounts: '异常 {count}',
|
||||||
loadFailed: '加载并发数据失败'
|
loadFailed: '加载并发数据失败'
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -591,9 +591,6 @@ export interface Account {
|
|||||||
temp_unschedulable_until: string | null
|
temp_unschedulable_until: string | null
|
||||||
temp_unschedulable_reason: string | null
|
temp_unschedulable_reason: string | null
|
||||||
|
|
||||||
// Antigravity scope 级限流状态
|
|
||||||
scope_rate_limits?: Record<string, { reset_at: string; remaining_sec: number }>
|
|
||||||
|
|
||||||
// Session window fields (5-hour window)
|
// Session window fields (5-hour window)
|
||||||
session_window_start: string | null
|
session_window_start: string | null
|
||||||
session_window_end: string | null
|
session_window_end: string | null
|
||||||
|
|||||||
@@ -122,8 +122,11 @@
|
|||||||
>
|
>
|
||||||
{{ siteName }}
|
{{ siteName }}
|
||||||
</h1>
|
</h1>
|
||||||
<p class="mb-8 text-lg text-gray-600 dark:text-dark-300 md:text-xl">
|
<p class="mb-3 text-xl font-semibold text-primary-600 dark:text-primary-400 md:text-2xl">
|
||||||
{{ siteSubtitle }}
|
{{ t('home.heroSubtitle') }}
|
||||||
|
</p>
|
||||||
|
<p class="mb-8 text-base text-gray-600 dark:text-dark-300 md:text-lg">
|
||||||
|
{{ t('home.heroDescription') }}
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
<!-- CTA Button -->
|
<!-- CTA Button -->
|
||||||
@@ -177,7 +180,7 @@
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Feature Tags - Centered -->
|
<!-- Feature Tags - Centered -->
|
||||||
<div class="mb-12 flex flex-wrap items-center justify-center gap-4 md:gap-6">
|
<div class="mb-16 flex flex-wrap items-center justify-center gap-4 md:gap-6">
|
||||||
<div
|
<div
|
||||||
class="inline-flex items-center gap-2.5 rounded-full border border-gray-200/50 bg-white/80 px-5 py-2.5 shadow-sm backdrop-blur-sm dark:border-dark-700/50 dark:bg-dark-800/80"
|
class="inline-flex items-center gap-2.5 rounded-full border border-gray-200/50 bg-white/80 px-5 py-2.5 shadow-sm backdrop-blur-sm dark:border-dark-700/50 dark:bg-dark-800/80"
|
||||||
>
|
>
|
||||||
@@ -204,6 +207,63 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- Pain Points Section -->
|
||||||
|
<div class="mb-16">
|
||||||
|
<h2 class="mb-8 text-center text-2xl font-bold text-gray-900 dark:text-white md:text-3xl">
|
||||||
|
{{ t('home.painPoints.title') }}
|
||||||
|
</h2>
|
||||||
|
<div class="grid gap-4 sm:grid-cols-2 lg:grid-cols-4">
|
||||||
|
<!-- Pain Point 1: Expensive -->
|
||||||
|
<div class="rounded-xl border border-red-200/50 bg-red-50/50 p-5 dark:border-red-900/30 dark:bg-red-950/20">
|
||||||
|
<div class="mb-3 flex h-10 w-10 items-center justify-center rounded-lg bg-red-100 dark:bg-red-900/30">
|
||||||
|
<svg class="h-5 w-5 text-red-500" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||||
|
<path stroke-linecap="round" stroke-linejoin="round" d="M12 8c-1.657 0-3 .895-3 2s1.343 2 3 2 3 .895 3 2-1.343 2-3 2m0-8c1.11 0 2.08.402 2.599 1M12 8V7m0 1v8m0 0v1m0-1c-1.11 0-2.08-.402-2.599-1M21 12a9 9 0 11-18 0 9 9 0 0118 0z" />
|
||||||
|
</svg>
|
||||||
|
</div>
|
||||||
|
<h3 class="mb-1.5 font-semibold text-gray-900 dark:text-white">{{ t('home.painPoints.items.expensive.title') }}</h3>
|
||||||
|
<p class="text-sm text-gray-600 dark:text-dark-400">{{ t('home.painPoints.items.expensive.desc') }}</p>
|
||||||
|
</div>
|
||||||
|
<!-- Pain Point 2: Complex -->
|
||||||
|
<div class="rounded-xl border border-orange-200/50 bg-orange-50/50 p-5 dark:border-orange-900/30 dark:bg-orange-950/20">
|
||||||
|
<div class="mb-3 flex h-10 w-10 items-center justify-center rounded-lg bg-orange-100 dark:bg-orange-900/30">
|
||||||
|
<svg class="h-5 w-5 text-orange-500" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||||
|
<path stroke-linecap="round" stroke-linejoin="round" d="M19 11H5m14 0a2 2 0 012 2v6a2 2 0 01-2 2H5a2 2 0 01-2-2v-6a2 2 0 012-2m14 0V9a2 2 0 00-2-2M5 11V9a2 2 0 012-2m0 0V5a2 2 0 012-2h6a2 2 0 012 2v2M7 7h10" />
|
||||||
|
</svg>
|
||||||
|
</div>
|
||||||
|
<h3 class="mb-1.5 font-semibold text-gray-900 dark:text-white">{{ t('home.painPoints.items.complex.title') }}</h3>
|
||||||
|
<p class="text-sm text-gray-600 dark:text-dark-400">{{ t('home.painPoints.items.complex.desc') }}</p>
|
||||||
|
</div>
|
||||||
|
<!-- Pain Point 3: Unstable -->
|
||||||
|
<div class="rounded-xl border border-yellow-200/50 bg-yellow-50/50 p-5 dark:border-yellow-900/30 dark:bg-yellow-950/20">
|
||||||
|
<div class="mb-3 flex h-10 w-10 items-center justify-center rounded-lg bg-yellow-100 dark:bg-yellow-900/30">
|
||||||
|
<svg class="h-5 w-5 text-yellow-600" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||||
|
<path stroke-linecap="round" stroke-linejoin="round" d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z" />
|
||||||
|
</svg>
|
||||||
|
</div>
|
||||||
|
<h3 class="mb-1.5 font-semibold text-gray-900 dark:text-white">{{ t('home.painPoints.items.unstable.title') }}</h3>
|
||||||
|
<p class="text-sm text-gray-600 dark:text-dark-400">{{ t('home.painPoints.items.unstable.desc') }}</p>
|
||||||
|
</div>
|
||||||
|
<!-- Pain Point 4: No Control -->
|
||||||
|
<div class="rounded-xl border border-gray-200/50 bg-gray-50/50 p-5 dark:border-dark-700/50 dark:bg-dark-800/50">
|
||||||
|
<div class="mb-3 flex h-10 w-10 items-center justify-center rounded-lg bg-gray-100 dark:bg-dark-700">
|
||||||
|
<svg class="h-5 w-5 text-gray-500" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||||
|
<path stroke-linecap="round" stroke-linejoin="round" d="M18.364 18.364A9 9 0 005.636 5.636m12.728 12.728A9 9 0 015.636 5.636m12.728 12.728L5.636 5.636" />
|
||||||
|
</svg>
|
||||||
|
</div>
|
||||||
|
<h3 class="mb-1.5 font-semibold text-gray-900 dark:text-white">{{ t('home.painPoints.items.noControl.title') }}</h3>
|
||||||
|
<p class="text-sm text-gray-600 dark:text-dark-400">{{ t('home.painPoints.items.noControl.desc') }}</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Solutions Section Title -->
|
||||||
|
<div class="mb-8 text-center">
|
||||||
|
<h2 class="mb-2 text-2xl font-bold text-gray-900 dark:text-white md:text-3xl">
|
||||||
|
{{ t('home.solutions.title') }}
|
||||||
|
</h2>
|
||||||
|
<p class="text-gray-600 dark:text-dark-400">{{ t('home.solutions.subtitle') }}</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
<!-- Features Grid -->
|
<!-- Features Grid -->
|
||||||
<div class="mb-12 grid gap-6 md:grid-cols-3">
|
<div class="mb-12 grid gap-6 md:grid-cols-3">
|
||||||
<!-- Feature 1: Unified Gateway -->
|
<!-- Feature 1: Unified Gateway -->
|
||||||
@@ -369,6 +429,77 @@
|
|||||||
>
|
>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- Comparison Table -->
|
||||||
|
<div class="mb-16">
|
||||||
|
<h2 class="mb-8 text-center text-2xl font-bold text-gray-900 dark:text-white md:text-3xl">
|
||||||
|
{{ t('home.comparison.title') }}
|
||||||
|
</h2>
|
||||||
|
<div class="overflow-x-auto">
|
||||||
|
<table class="w-full rounded-xl border border-gray-200/50 bg-white/60 backdrop-blur-sm dark:border-dark-700/50 dark:bg-dark-800/60">
|
||||||
|
<thead>
|
||||||
|
<tr class="border-b border-gray-200/50 dark:border-dark-700/50">
|
||||||
|
<th class="px-6 py-4 text-left text-sm font-semibold text-gray-900 dark:text-white">{{ t('home.comparison.headers.feature') }}</th>
|
||||||
|
<th class="px-6 py-4 text-center text-sm font-semibold text-gray-500 dark:text-dark-400">{{ t('home.comparison.headers.official') }}</th>
|
||||||
|
<th class="px-6 py-4 text-center text-sm font-semibold text-primary-600 dark:text-primary-400">{{ t('home.comparison.headers.us') }}</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody class="divide-y divide-gray-200/50 dark:divide-dark-700/50">
|
||||||
|
<tr>
|
||||||
|
<td class="px-6 py-4 text-sm font-medium text-gray-900 dark:text-white">{{ t('home.comparison.items.pricing.feature') }}</td>
|
||||||
|
<td class="px-6 py-4 text-center text-sm text-gray-500 dark:text-dark-400">{{ t('home.comparison.items.pricing.official') }}</td>
|
||||||
|
<td class="px-6 py-4 text-center text-sm font-medium text-primary-600 dark:text-primary-400">{{ t('home.comparison.items.pricing.us') }}</td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td class="px-6 py-4 text-sm font-medium text-gray-900 dark:text-white">{{ t('home.comparison.items.models.feature') }}</td>
|
||||||
|
<td class="px-6 py-4 text-center text-sm text-gray-500 dark:text-dark-400">{{ t('home.comparison.items.models.official') }}</td>
|
||||||
|
<td class="px-6 py-4 text-center text-sm font-medium text-primary-600 dark:text-primary-400">{{ t('home.comparison.items.models.us') }}</td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td class="px-6 py-4 text-sm font-medium text-gray-900 dark:text-white">{{ t('home.comparison.items.management.feature') }}</td>
|
||||||
|
<td class="px-6 py-4 text-center text-sm text-gray-500 dark:text-dark-400">{{ t('home.comparison.items.management.official') }}</td>
|
||||||
|
<td class="px-6 py-4 text-center text-sm font-medium text-primary-600 dark:text-primary-400">{{ t('home.comparison.items.management.us') }}</td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td class="px-6 py-4 text-sm font-medium text-gray-900 dark:text-white">{{ t('home.comparison.items.stability.feature') }}</td>
|
||||||
|
<td class="px-6 py-4 text-center text-sm text-gray-500 dark:text-dark-400">{{ t('home.comparison.items.stability.official') }}</td>
|
||||||
|
<td class="px-6 py-4 text-center text-sm font-medium text-primary-600 dark:text-primary-400">{{ t('home.comparison.items.stability.us') }}</td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td class="px-6 py-4 text-sm font-medium text-gray-900 dark:text-white">{{ t('home.comparison.items.control.feature') }}</td>
|
||||||
|
<td class="px-6 py-4 text-center text-sm text-gray-500 dark:text-dark-400">{{ t('home.comparison.items.control.official') }}</td>
|
||||||
|
<td class="px-6 py-4 text-center text-sm font-medium text-primary-600 dark:text-primary-400">{{ t('home.comparison.items.control.us') }}</td>
|
||||||
|
</tr>
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- CTA Section -->
|
||||||
|
<div class="mb-8 rounded-2xl bg-gradient-to-r from-primary-500 to-primary-600 p-8 text-center shadow-xl shadow-primary-500/20 md:p-12">
|
||||||
|
<h2 class="mb-3 text-2xl font-bold text-white md:text-3xl">
|
||||||
|
{{ t('home.cta.title') }}
|
||||||
|
</h2>
|
||||||
|
<p class="mb-6 text-primary-100">
|
||||||
|
{{ t('home.cta.description') }}
|
||||||
|
</p>
|
||||||
|
<router-link
|
||||||
|
v-if="!isAuthenticated"
|
||||||
|
to="/register"
|
||||||
|
class="inline-flex items-center gap-2 rounded-full bg-white px-8 py-3 font-semibold text-primary-600 shadow-lg transition-all hover:bg-gray-50 hover:shadow-xl"
|
||||||
|
>
|
||||||
|
{{ t('home.cta.button') }}
|
||||||
|
<Icon name="arrowRight" size="md" :stroke-width="2" />
|
||||||
|
</router-link>
|
||||||
|
<router-link
|
||||||
|
v-else
|
||||||
|
:to="dashboardPath"
|
||||||
|
class="inline-flex items-center gap-2 rounded-full bg-white px-8 py-3 font-semibold text-primary-600 shadow-lg transition-all hover:bg-gray-50 hover:shadow-xl"
|
||||||
|
>
|
||||||
|
{{ t('home.goToDashboard') }}
|
||||||
|
<Icon name="arrowRight" size="md" :stroke-width="2" />
|
||||||
|
</router-link>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</main>
|
</main>
|
||||||
|
|
||||||
@@ -380,27 +511,20 @@
|
|||||||
<p class="text-sm text-gray-500 dark:text-dark-400">
|
<p class="text-sm text-gray-500 dark:text-dark-400">
|
||||||
© {{ currentYear }} {{ siteName }}. {{ t('home.footer.allRightsReserved') }}
|
© {{ currentYear }} {{ siteName }}. {{ t('home.footer.allRightsReserved') }}
|
||||||
</p>
|
</p>
|
||||||
<div class="flex items-center gap-4">
|
<a
|
||||||
<a
|
v-if="docUrl"
|
||||||
v-if="docUrl"
|
:href="docUrl"
|
||||||
:href="docUrl"
|
target="_blank"
|
||||||
target="_blank"
|
rel="noopener noreferrer"
|
||||||
rel="noopener noreferrer"
|
class="text-sm text-gray-500 transition-colors hover:text-gray-700 dark:text-dark-400 dark:hover:text-white"
|
||||||
class="text-sm text-gray-500 transition-colors hover:text-gray-700 dark:text-dark-400 dark:hover:text-white"
|
>
|
||||||
>
|
{{ t('home.docs') }}
|
||||||
{{ t('home.docs') }}
|
</a>
|
||||||
</a>
|
|
||||||
<a
|
|
||||||
:href="githubUrl"
|
|
||||||
target="_blank"
|
|
||||||
rel="noopener noreferrer"
|
|
||||||
class="text-sm text-gray-500 transition-colors hover:text-gray-700 dark:text-dark-400 dark:hover:text-white"
|
|
||||||
>
|
|
||||||
GitHub
|
|
||||||
</a>
|
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
</footer>
|
</footer>
|
||||||
|
|
||||||
|
<!-- 微信客服悬浮按钮 -->
|
||||||
|
<WechatServiceButton />
|
||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
@@ -410,6 +534,7 @@ import { useI18n } from 'vue-i18n'
|
|||||||
import { useAuthStore, useAppStore } from '@/stores'
|
import { useAuthStore, useAppStore } from '@/stores'
|
||||||
import LocaleSwitcher from '@/components/common/LocaleSwitcher.vue'
|
import LocaleSwitcher from '@/components/common/LocaleSwitcher.vue'
|
||||||
import Icon from '@/components/icons/Icon.vue'
|
import Icon from '@/components/icons/Icon.vue'
|
||||||
|
import WechatServiceButton from '@/components/common/WechatServiceButton.vue'
|
||||||
|
|
||||||
const { t } = useI18n()
|
const { t } = useI18n()
|
||||||
|
|
||||||
@@ -419,7 +544,6 @@ const appStore = useAppStore()
|
|||||||
// Site settings - directly from appStore (already initialized from injected config)
|
// Site settings - directly from appStore (already initialized from injected config)
|
||||||
const siteName = computed(() => appStore.cachedPublicSettings?.site_name || appStore.siteName || 'Sub2API')
|
const siteName = computed(() => appStore.cachedPublicSettings?.site_name || appStore.siteName || 'Sub2API')
|
||||||
const siteLogo = computed(() => appStore.cachedPublicSettings?.site_logo || appStore.siteLogo || '')
|
const siteLogo = computed(() => appStore.cachedPublicSettings?.site_logo || appStore.siteLogo || '')
|
||||||
const siteSubtitle = computed(() => appStore.cachedPublicSettings?.site_subtitle || 'AI API Gateway Platform')
|
|
||||||
const docUrl = computed(() => appStore.cachedPublicSettings?.doc_url || appStore.docUrl || '')
|
const docUrl = computed(() => appStore.cachedPublicSettings?.doc_url || appStore.docUrl || '')
|
||||||
const homeContent = computed(() => appStore.cachedPublicSettings?.home_content || '')
|
const homeContent = computed(() => appStore.cachedPublicSettings?.home_content || '')
|
||||||
|
|
||||||
@@ -432,9 +556,6 @@ const isHomeContentUrl = computed(() => {
|
|||||||
// Theme
|
// Theme
|
||||||
const isDark = ref(document.documentElement.classList.contains('dark'))
|
const isDark = ref(document.documentElement.classList.contains('dark'))
|
||||||
|
|
||||||
// GitHub URL
|
|
||||||
const githubUrl = 'https://github.com/Wei-Shaw/sub2api'
|
|
||||||
|
|
||||||
// Auth state
|
// Auth state
|
||||||
const isAuthenticated = computed(() => authStore.isAuthenticated)
|
const isAuthenticated = computed(() => authStore.isAuthenticated)
|
||||||
const isAdmin = computed(() => authStore.isAdmin)
|
const isAdmin = computed(() => authStore.isAdmin)
|
||||||
|
|||||||
@@ -56,7 +56,6 @@ interface SummaryRow {
|
|||||||
total_accounts: number
|
total_accounts: number
|
||||||
available_accounts: number
|
available_accounts: number
|
||||||
rate_limited_accounts: number
|
rate_limited_accounts: number
|
||||||
scope_rate_limit_count?: Record<string, number>
|
|
||||||
error_accounts: number
|
error_accounts: number
|
||||||
// 并发统计
|
// 并发统计
|
||||||
total_concurrency: number
|
total_concurrency: number
|
||||||
@@ -122,7 +121,6 @@ const platformRows = computed((): SummaryRow[] => {
|
|||||||
total_accounts: totalAccounts,
|
total_accounts: totalAccounts,
|
||||||
available_accounts: availableAccounts,
|
available_accounts: availableAccounts,
|
||||||
rate_limited_accounts: safeNumber(avail.rate_limit_count),
|
rate_limited_accounts: safeNumber(avail.rate_limit_count),
|
||||||
scope_rate_limit_count: avail.scope_rate_limit_count,
|
|
||||||
error_accounts: safeNumber(avail.error_count),
|
error_accounts: safeNumber(avail.error_count),
|
||||||
total_concurrency: totalConcurrency,
|
total_concurrency: totalConcurrency,
|
||||||
used_concurrency: usedConcurrency,
|
used_concurrency: usedConcurrency,
|
||||||
@@ -162,7 +160,6 @@ const groupRows = computed((): SummaryRow[] => {
|
|||||||
total_accounts: totalAccounts,
|
total_accounts: totalAccounts,
|
||||||
available_accounts: availableAccounts,
|
available_accounts: availableAccounts,
|
||||||
rate_limited_accounts: safeNumber(avail.rate_limit_count),
|
rate_limited_accounts: safeNumber(avail.rate_limit_count),
|
||||||
scope_rate_limit_count: avail.scope_rate_limit_count,
|
|
||||||
error_accounts: safeNumber(avail.error_count),
|
error_accounts: safeNumber(avail.error_count),
|
||||||
total_concurrency: totalConcurrency,
|
total_concurrency: totalConcurrency,
|
||||||
used_concurrency: usedConcurrency,
|
used_concurrency: usedConcurrency,
|
||||||
@@ -329,15 +326,6 @@ function formatDuration(seconds: number): string {
|
|||||||
return `${hours}h`
|
return `${hours}h`
|
||||||
}
|
}
|
||||||
|
|
||||||
function formatScopeName(scope: string): string {
|
|
||||||
const names: Record<string, string> = {
|
|
||||||
claude: 'Claude',
|
|
||||||
gemini_text: 'Gemini',
|
|
||||||
gemini_image: 'Image'
|
|
||||||
}
|
|
||||||
return names[scope] || scope
|
|
||||||
}
|
|
||||||
|
|
||||||
watch(
|
watch(
|
||||||
() => realtimeEnabled.value,
|
() => realtimeEnabled.value,
|
||||||
async (enabled) => {
|
async (enabled) => {
|
||||||
@@ -505,18 +493,6 @@ watch(
|
|||||||
{{ t('admin.ops.concurrency.rateLimited', { count: row.rate_limited_accounts }) }}
|
{{ t('admin.ops.concurrency.rateLimited', { count: row.rate_limited_accounts }) }}
|
||||||
</span>
|
</span>
|
||||||
|
|
||||||
<!-- Scope 限流 (仅 Antigravity) -->
|
|
||||||
<template v-if="row.scope_rate_limit_count && Object.keys(row.scope_rate_limit_count).length > 0">
|
|
||||||
<span
|
|
||||||
v-for="(count, scope) in row.scope_rate_limit_count"
|
|
||||||
:key="scope"
|
|
||||||
class="rounded-full bg-orange-100 px-1.5 py-0.5 font-semibold text-orange-700 dark:bg-orange-900/30 dark:text-orange-400"
|
|
||||||
:title="t('admin.ops.concurrency.scopeRateLimitedTooltip', { scope, count })"
|
|
||||||
>
|
|
||||||
{{ formatScopeName(scope as string) }} {{ count }}
|
|
||||||
</span>
|
|
||||||
</template>
|
|
||||||
|
|
||||||
<!-- 异常账号 -->
|
<!-- 异常账号 -->
|
||||||
<span
|
<span
|
||||||
v-if="row.error_accounts > 0"
|
v-if="row.error_accounts > 0"
|
||||||
|
|||||||
127
stress_test_gemini_session.sh
Normal file
127
stress_test_gemini_session.sh
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Gemini 粘性会话压力测试脚本
|
||||||
|
# 测试目标:验证不同会话分配不同账号,同一会话保持同一账号
|
||||||
|
|
||||||
|
BASE_URL="http://host.clicodeplus.com:8080"
|
||||||
|
API_KEY="sk-32ad0a3197e528c840ea84f0dc6b2056dd3fead03526b5c605a60709bd408f7e"
|
||||||
|
MODEL="gemini-2.5-flash"
|
||||||
|
|
||||||
|
# 创建临时目录存放结果
|
||||||
|
RESULT_DIR="/tmp/gemini_stress_test_$(date +%s)"
|
||||||
|
mkdir -p "$RESULT_DIR"
|
||||||
|
|
||||||
|
echo "=========================================="
|
||||||
|
echo "Gemini 粘性会话压力测试"
|
||||||
|
echo "结果目录: $RESULT_DIR"
|
||||||
|
echo "=========================================="
|
||||||
|
|
||||||
|
# 函数:发送请求并记录
|
||||||
|
send_request() {
|
||||||
|
local session_id=$1
|
||||||
|
local round=$2
|
||||||
|
local system_prompt=$3
|
||||||
|
local contents=$4
|
||||||
|
local output_file="$RESULT_DIR/session_${session_id}_round_${round}.json"
|
||||||
|
|
||||||
|
local request_body=$(cat <<EOF
|
||||||
|
{
|
||||||
|
"systemInstruction": {
|
||||||
|
"parts": [{"text": "$system_prompt"}]
|
||||||
|
},
|
||||||
|
"contents": $contents
|
||||||
|
}
|
||||||
|
EOF
|
||||||
|
)
|
||||||
|
|
||||||
|
curl -s -X POST "${BASE_URL}/v1beta/models/${MODEL}:generateContent" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-H "x-goog-api-key: ${API_KEY}" \
|
||||||
|
-d "$request_body" > "$output_file" 2>&1
|
||||||
|
|
||||||
|
echo "[Session $session_id Round $round] 完成"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 会话1:数学计算器(累加序列)
|
||||||
|
run_session_1() {
|
||||||
|
local sys_prompt="你是一个数学计算器,只返回计算结果数字,不要任何解释"
|
||||||
|
|
||||||
|
# Round 1: 1+1=?
|
||||||
|
send_request 1 1 "$sys_prompt" '[{"role":"user","parts":[{"text":"1+1=?"}]}]'
|
||||||
|
|
||||||
|
# Round 2: 继续 2+2=?(累加历史)
|
||||||
|
send_request 1 2 "$sys_prompt" '[{"role":"user","parts":[{"text":"1+1=?"}]},{"role":"model","parts":[{"text":"2"}]},{"role":"user","parts":[{"text":"2+2=?"}]}]'
|
||||||
|
|
||||||
|
# Round 3: 继续 3+3=?
|
||||||
|
send_request 1 3 "$sys_prompt" '[{"role":"user","parts":[{"text":"1+1=?"}]},{"role":"model","parts":[{"text":"2"}]},{"role":"user","parts":[{"text":"2+2=?"}]},{"role":"model","parts":[{"text":"4"}]},{"role":"user","parts":[{"text":"3+3=?"}]}]'
|
||||||
|
|
||||||
|
# Round 4: 批量计算 10+10, 20+20, 30+30
|
||||||
|
send_request 1 4 "$sys_prompt" '[{"role":"user","parts":[{"text":"1+1=?"}]},{"role":"model","parts":[{"text":"2"}]},{"role":"user","parts":[{"text":"2+2=?"}]},{"role":"model","parts":[{"text":"4"}]},{"role":"user","parts":[{"text":"3+3=?"}]},{"role":"model","parts":[{"text":"6"}]},{"role":"user","parts":[{"text":"计算: 10+10=? 20+20=? 30+30=?"}]}]'
|
||||||
|
}
|
||||||
|
|
||||||
|
# 会话2:英文翻译器(不同系统提示词 = 不同会话)
|
||||||
|
run_session_2() {
|
||||||
|
local sys_prompt="你是一个英文翻译器,将中文翻译成英文,只返回翻译结果"
|
||||||
|
|
||||||
|
send_request 2 1 "$sys_prompt" '[{"role":"user","parts":[{"text":"你好"}]}]'
|
||||||
|
send_request 2 2 "$sys_prompt" '[{"role":"user","parts":[{"text":"你好"}]},{"role":"model","parts":[{"text":"Hello"}]},{"role":"user","parts":[{"text":"世界"}]}]'
|
||||||
|
send_request 2 3 "$sys_prompt" '[{"role":"user","parts":[{"text":"你好"}]},{"role":"model","parts":[{"text":"Hello"}]},{"role":"user","parts":[{"text":"世界"}]},{"role":"model","parts":[{"text":"World"}]},{"role":"user","parts":[{"text":"早上好"}]}]'
|
||||||
|
}
|
||||||
|
|
||||||
|
# 会话3:日文翻译器
|
||||||
|
run_session_3() {
|
||||||
|
local sys_prompt="你是一个日文翻译器,将中文翻译成日文,只返回翻译结果"
|
||||||
|
|
||||||
|
send_request 3 1 "$sys_prompt" '[{"role":"user","parts":[{"text":"你好"}]}]'
|
||||||
|
send_request 3 2 "$sys_prompt" '[{"role":"user","parts":[{"text":"你好"}]},{"role":"model","parts":[{"text":"こんにちは"}]},{"role":"user","parts":[{"text":"谢谢"}]}]'
|
||||||
|
send_request 3 3 "$sys_prompt" '[{"role":"user","parts":[{"text":"你好"}]},{"role":"model","parts":[{"text":"こんにちは"}]},{"role":"user","parts":[{"text":"谢谢"}]},{"role":"model","parts":[{"text":"ありがとう"}]},{"role":"user","parts":[{"text":"再见"}]}]'
|
||||||
|
}
|
||||||
|
|
||||||
|
# 会话4:乘法计算器(另一个数学会话,但系统提示词不同)
|
||||||
|
run_session_4() {
|
||||||
|
local sys_prompt="你是一个乘法专用计算器,只计算乘法,返回数字结果"
|
||||||
|
|
||||||
|
send_request 4 1 "$sys_prompt" '[{"role":"user","parts":[{"text":"2*3=?"}]}]'
|
||||||
|
send_request 4 2 "$sys_prompt" '[{"role":"user","parts":[{"text":"2*3=?"}]},{"role":"model","parts":[{"text":"6"}]},{"role":"user","parts":[{"text":"4*5=?"}]}]'
|
||||||
|
send_request 4 3 "$sys_prompt" '[{"role":"user","parts":[{"text":"2*3=?"}]},{"role":"model","parts":[{"text":"6"}]},{"role":"user","parts":[{"text":"4*5=?"}]},{"role":"model","parts":[{"text":"20"}]},{"role":"user","parts":[{"text":"计算: 10*10=? 20*20=?"}]}]'
|
||||||
|
}
|
||||||
|
|
||||||
|
# 会话5:诗人(完全不同的角色)
|
||||||
|
run_session_5() {
|
||||||
|
local sys_prompt="你是一位诗人,用简短的诗句回应每个话题,每次只写一句诗"
|
||||||
|
|
||||||
|
send_request 5 1 "$sys_prompt" '[{"role":"user","parts":[{"text":"春天"}]}]'
|
||||||
|
send_request 5 2 "$sys_prompt" '[{"role":"user","parts":[{"text":"春天"}]},{"role":"model","parts":[{"text":"春风拂面花满枝"}]},{"role":"user","parts":[{"text":"夏天"}]}]'
|
||||||
|
send_request 5 3 "$sys_prompt" '[{"role":"user","parts":[{"text":"春天"}]},{"role":"model","parts":[{"text":"春风拂面花满枝"}]},{"role":"user","parts":[{"text":"夏天"}]},{"role":"model","parts":[{"text":"蝉鸣蛙声伴荷香"}]},{"role":"user","parts":[{"text":"秋天"}]}]'
|
||||||
|
}
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "开始并发测试 5 个独立会话..."
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# 并发运行所有会话
|
||||||
|
run_session_1 &
|
||||||
|
run_session_2 &
|
||||||
|
run_session_3 &
|
||||||
|
run_session_4 &
|
||||||
|
run_session_5 &
|
||||||
|
|
||||||
|
# 等待所有后台任务完成
|
||||||
|
wait
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "=========================================="
|
||||||
|
echo "所有请求完成,结果保存在: $RESULT_DIR"
|
||||||
|
echo "=========================================="
|
||||||
|
|
||||||
|
# 显示结果摘要
|
||||||
|
echo ""
|
||||||
|
echo "响应摘要:"
|
||||||
|
for f in "$RESULT_DIR"/*.json; do
|
||||||
|
filename=$(basename "$f")
|
||||||
|
response=$(cat "$f" | head -c 200)
|
||||||
|
echo "[$filename]: ${response}..."
|
||||||
|
done
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "请检查服务器日志确认账号分配情况"
|
||||||
Reference in New Issue
Block a user