mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-02 22:42:14 +08:00
feat: squash merge all changes from develop-0.1.75
Squash of 124 commits from the legacy develop branch (develop-0.1.75) onto a clean v0.1.75 upstream base, to simplify future upstream merges. Key changes included: - Refactor scope-level rate limiting to model-level rate limiting - Antigravity gateway service improvements (smart retry, error policy) - Digest session store (flat cache replacing Trie-based store) - Client disconnect detection during streaming - Gemini messages compatibility service enhancements - Scheduler shuffle for thundering herd prevention - Session hash generation improvements - Frontend customizations (WeChat service, HomeView, etc.) - Ops monitoring scope cleanup
This commit is contained in:
2
.github/workflows/backend-ci.yml
vendored
2
.github/workflows/backend-ci.yml
vendored
@@ -17,6 +17,7 @@ jobs:
|
||||
go-version-file: backend/go.mod
|
||||
check-latest: false
|
||||
cache: true
|
||||
cache-dependency-path: backend/go.sum
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.25.7'
|
||||
@@ -36,6 +37,7 @@ jobs:
|
||||
go-version-file: backend/go.mod
|
||||
check-latest: false
|
||||
cache: true
|
||||
cache-dependency-path: backend/go.sum
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.25.7'
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -78,6 +78,7 @@ Desktop.ini
|
||||
# ===================
|
||||
tmp/
|
||||
temp/
|
||||
logs/
|
||||
*.tmp
|
||||
*.temp
|
||||
*.log
|
||||
|
||||
723
AGENTS.md
Normal file
723
AGENTS.md
Normal file
@@ -0,0 +1,723 @@
|
||||
# Sub2API 开发说明
|
||||
|
||||
## 版本管理策略
|
||||
|
||||
### 版本号规则
|
||||
|
||||
我们在官方版本号后面添加自己的小版本号:
|
||||
|
||||
- 官方版本:`v0.1.68`
|
||||
- 我们的版本:`v0.1.68.1`、`v0.1.68.2`(递增)
|
||||
|
||||
### 分支策略
|
||||
|
||||
| 分支 | 说明 |
|
||||
|------|------|
|
||||
| `main` | 我们的主分支,包含所有定制功能 |
|
||||
| `release/custom-X.Y.Z` | 基于官方 `vX.Y.Z` 的发布分支 |
|
||||
| `upstream/main` | 上游官方仓库 |
|
||||
|
||||
---
|
||||
|
||||
## 发布流程(基于新官方版本)
|
||||
|
||||
当官方发布新版本(如 `v0.1.69`)时:
|
||||
|
||||
### 1. 同步上游并创建发布分支
|
||||
|
||||
```bash
|
||||
# 获取上游最新代码
|
||||
git fetch upstream --tags
|
||||
|
||||
# 基于官方标签创建新的发布分支
|
||||
git checkout v0.1.69 -b release/custom-0.1.69
|
||||
|
||||
# 合并我们的 main 分支(包含所有定制功能)
|
||||
git merge main --no-edit
|
||||
|
||||
# 解决可能的冲突后继续
|
||||
```
|
||||
|
||||
### 2. 更新版本号并打标签
|
||||
|
||||
```bash
|
||||
# 更新版本号文件
|
||||
echo "0.1.69.1" > backend/cmd/server/VERSION
|
||||
git add backend/cmd/server/VERSION
|
||||
git commit -m "chore: bump version to 0.1.69.1"
|
||||
|
||||
# 打上我们自己的标签
|
||||
git tag v0.1.69.1
|
||||
|
||||
# 推送分支和标签
|
||||
git push origin release/custom-0.1.69
|
||||
git push origin v0.1.69.1
|
||||
```
|
||||
|
||||
### 3. 更新 main 分支
|
||||
|
||||
```bash
|
||||
# 将发布分支合并回 main,保持 main 包含最新定制功能
|
||||
git checkout main
|
||||
git merge release/custom-0.1.69
|
||||
git push origin main
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 热修复发布(在现有版本上修复)
|
||||
|
||||
当需要在当前版本上发布修复时:
|
||||
|
||||
```bash
|
||||
# 在当前发布分支上修复
|
||||
git checkout release/custom-0.1.68
|
||||
# ... 进行修复 ...
|
||||
git commit -m "fix: 修复描述"
|
||||
|
||||
# 递增小版本号
|
||||
echo "0.1.68.2" > backend/cmd/server/VERSION
|
||||
git add backend/cmd/server/VERSION
|
||||
git commit -m "chore: bump version to 0.1.68.2"
|
||||
|
||||
# 打标签并推送
|
||||
git tag v0.1.68.2
|
||||
git push origin release/custom-0.1.68
|
||||
git push origin v0.1.68.2
|
||||
|
||||
# 同步修复到 main
|
||||
git checkout main
|
||||
git cherry-pick <fix-commit-hash>
|
||||
git push origin main
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 服务器部署流程
|
||||
|
||||
### 前置条件
|
||||
|
||||
- 本地已配置 SSH 别名 `clicodeplus` 连接到服务器
|
||||
- 服务器部署目录:`/root/sub2api`(正式)、`/root/sub2api-beta`(测试)
|
||||
- 服务器使用 Docker Compose 部署
|
||||
|
||||
### 部署环境说明
|
||||
|
||||
| 环境 | 目录 | 端口 | 数据库 | 容器名 |
|
||||
|------|------|------|--------|--------|
|
||||
| 正式 | `/root/sub2api` | 8080 | `sub2api` | `sub2api` |
|
||||
| Beta | `/root/sub2api-beta` | 8084 | `beta` | `sub2api-beta` |
|
||||
|
||||
### 外部数据库
|
||||
|
||||
正式和 Beta 环境**共用外部 PostgreSQL 数据库**(非容器内数据库),配置在 `.env` 文件中:
|
||||
- `DATABASE_HOST`:外部数据库地址
|
||||
- `DATABASE_SSLMODE`:SSL 模式(通常为 `require`)
|
||||
- `POSTGRES_USER` / `POSTGRES_DB`:用户名和数据库名
|
||||
|
||||
#### 数据库操作命令
|
||||
|
||||
通过 SSH 在服务器上执行数据库操作:
|
||||
|
||||
```bash
|
||||
# 正式环境 - 查询迁移记录
|
||||
ssh clicodeplus "source /root/sub2api/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c 'SELECT * FROM schema_migrations ORDER BY applied_at DESC LIMIT 5;'"
|
||||
|
||||
# Beta 环境 - 查询迁移记录
|
||||
ssh clicodeplus "source /root/sub2api-beta/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c 'SELECT * FROM schema_migrations ORDER BY applied_at DESC LIMIT 5;'"
|
||||
|
||||
# Beta 环境 - 清除指定迁移记录(重新执行迁移)
|
||||
ssh clicodeplus "source /root/sub2api-beta/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c \"DELETE FROM schema_migrations WHERE filename LIKE '%049%';\""
|
||||
|
||||
# Beta 环境 - 更新账号数据
|
||||
ssh clicodeplus "source /root/sub2api-beta/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c \"UPDATE accounts SET credentials = credentials - 'model_mapping' WHERE platform = 'antigravity';\""
|
||||
```
|
||||
|
||||
> **注意**:使用 `source .env` 加载环境变量,避免在命令行中暴露密码。
|
||||
|
||||
### 部署步骤
|
||||
|
||||
**重要:每次部署都必须递增版本号!**
|
||||
|
||||
#### 0. 递增版本号(本地操作)
|
||||
|
||||
每次部署前,先在本地递增小版本号:
|
||||
|
||||
```bash
|
||||
# 查看当前版本号
|
||||
cat backend/cmd/server/VERSION
|
||||
# 假设当前是 0.1.69.1
|
||||
|
||||
# 递增版本号
|
||||
echo "0.1.69.2" > backend/cmd/server/VERSION
|
||||
git add backend/cmd/server/VERSION
|
||||
git commit -m "chore: bump version to 0.1.69.2"
|
||||
git push origin release/custom-0.1.69
|
||||
```
|
||||
|
||||
#### 1. 服务器拉取代码
|
||||
|
||||
```bash
|
||||
ssh clicodeplus "cd /root/sub2api && git fetch fork && git checkout -B release/custom-0.1.69 fork/release/custom-0.1.69"
|
||||
```
|
||||
|
||||
#### 2. 服务器构建镜像
|
||||
|
||||
```bash
|
||||
ssh clicodeplus "cd /root/sub2api && docker build --no-cache -t sub2api:latest -f Dockerfile ."
|
||||
```
|
||||
|
||||
#### 3. 更新镜像标签并重启服务
|
||||
|
||||
```bash
|
||||
ssh clicodeplus "docker tag sub2api:latest weishaw/sub2api:latest"
|
||||
ssh clicodeplus "cd /root/sub2api/deploy && docker compose up -d --force-recreate sub2api"
|
||||
```
|
||||
|
||||
#### 4. 验证部署
|
||||
|
||||
```bash
|
||||
# 查看启动日志
|
||||
ssh clicodeplus "docker logs sub2api --tail 20"
|
||||
|
||||
# 确认版本号(必须与步骤 0 中设置的版本号一致)
|
||||
ssh clicodeplus "cat /root/sub2api/backend/cmd/server/VERSION"
|
||||
|
||||
# 检查容器状态
|
||||
ssh clicodeplus "docker ps | grep sub2api"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Beta 并行部署(不影响现网)
|
||||
|
||||
目标:在同一台服务器上并行启动一个 beta 实例(例如端口 `8084`),**严禁改动/重启**现网实例(默认目录 `/root/sub2api`)。
|
||||
|
||||
### 设计原则
|
||||
|
||||
- **新目录**:beta 使用独立目录,例如 `/root/sub2api-beta`。
|
||||
- **敏感信息只放 `.env`**:beta 的数据库密码、JWT_SECRET 等只写入 `/root/sub2api-beta/deploy/.env`,不要提交到 git。
|
||||
- **独立 Compose Project**:通过 `docker compose -p sub2api-beta ...` 启动,确保 network/volume 隔离。
|
||||
- **独立端口**:通过 `.env` 的 `SERVER_PORT` 映射宿主机端口(例如 `8084:8080`)。
|
||||
|
||||
### 前置检查
|
||||
|
||||
```bash
|
||||
# 1) 确保 8084 未被占用
|
||||
ssh clicodeplus "ss -ltnp | grep :8084 || echo '8084 is free'"
|
||||
|
||||
# 2) 确认现网容器还在(只读检查)
|
||||
ssh clicodeplus "docker ps --format 'table {{.Names}}\t{{.Image}}\t{{.Ports}}' | sed -n '1,200p'"
|
||||
```
|
||||
|
||||
### 首次部署步骤
|
||||
|
||||
```bash
|
||||
# 0) 进入服务器
|
||||
ssh clicodeplus
|
||||
|
||||
# 1) 克隆代码到新目录(示例使用你的 fork)
|
||||
cd /root
|
||||
git clone https://github.com/touwaeriol/sub2api.git sub2api-beta
|
||||
cd /root/sub2api-beta
|
||||
git checkout release/custom-0.1.71
|
||||
|
||||
# 2) 准备 beta 的 .env(敏感信息只写这里)
|
||||
cd /root/sub2api-beta/deploy
|
||||
|
||||
# 推荐:从现网 .env 复制,保证除 DB 名/用户/端口外完全一致
|
||||
cp -f /root/sub2api/deploy/.env ./.env
|
||||
|
||||
# 仅修改以下三项(其他保持不变)
|
||||
perl -pi -e 's/^SERVER_PORT=.*/SERVER_PORT=8084/' ./.env
|
||||
perl -pi -e 's/^POSTGRES_USER=.*/POSTGRES_USER=beta/' ./.env
|
||||
perl -pi -e 's/^POSTGRES_DB=.*/POSTGRES_DB=beta/' ./.env
|
||||
|
||||
# 3) 写 compose override(避免与现网容器名冲突,镜像使用本地构建的 sub2api:beta)
|
||||
cat > docker-compose.override.yml <<'YAML'
|
||||
services:
|
||||
sub2api:
|
||||
image: sub2api:beta
|
||||
container_name: sub2api-beta
|
||||
redis:
|
||||
container_name: sub2api-beta-redis
|
||||
YAML
|
||||
|
||||
# 4) 构建 beta 镜像(基于当前代码)
|
||||
cd /root/sub2api-beta
|
||||
docker build -t sub2api:beta -f Dockerfile .
|
||||
|
||||
# 5) 启动 beta(独立 project,确保不影响现网)
|
||||
cd /root/sub2api-beta/deploy
|
||||
docker compose -p sub2api-beta --env-file .env -f docker-compose.yml -f docker-compose.override.yml up -d
|
||||
|
||||
# 6) 验证 beta
|
||||
curl -fsS http://127.0.0.1:8084/health
|
||||
docker logs sub2api-beta --tail 50
|
||||
```
|
||||
|
||||
### 数据库配置约定(beta)
|
||||
|
||||
- 数据库地址/SSL/密码:与现网一致(从现网 `.env` 复制即可)。
|
||||
- 仅修改:
|
||||
- `POSTGRES_USER=beta`
|
||||
- `POSTGRES_DB=beta`
|
||||
|
||||
注意:需要数据库侧已存在 `beta` 用户与 `beta` 数据库,并授予权限;否则容器会启动失败并不断重启。
|
||||
|
||||
### 更新 beta(拉代码 + 仅重建 beta 容器)
|
||||
|
||||
```bash
|
||||
ssh clicodeplus "set -e; cd /root/sub2api-beta && git fetch --all --tags && git checkout -f release/custom-0.1.71 && git reset --hard origin/release/custom-0.1.71"
|
||||
ssh clicodeplus "cd /root/sub2api-beta && docker build -t sub2api:beta -f Dockerfile ."
|
||||
ssh clicodeplus "cd /root/sub2api-beta/deploy && docker compose -p sub2api-beta --env-file .env -f docker-compose.yml -f docker-compose.override.yml up -d --no-deps --force-recreate sub2api"
|
||||
ssh clicodeplus "curl -fsS http://127.0.0.1:8084/health"
|
||||
```
|
||||
|
||||
### 停止/回滚 beta(只影响 beta)
|
||||
|
||||
```bash
|
||||
ssh clicodeplus "cd /root/sub2api-beta/deploy && docker compose -p sub2api-beta -f docker-compose.yml -f docker-compose.override.yml down"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 服务器首次部署
|
||||
|
||||
### 1. 克隆代码并配置远程仓库
|
||||
|
||||
```bash
|
||||
ssh clicodeplus
|
||||
cd /root
|
||||
git clone https://github.com/Wei-Shaw/sub2api.git
|
||||
cd sub2api
|
||||
|
||||
# 添加 fork 仓库
|
||||
git remote add fork https://github.com/touwaeriol/sub2api.git
|
||||
```
|
||||
|
||||
### 2. 切换到定制分支并配置环境
|
||||
|
||||
```bash
|
||||
git fetch fork
|
||||
git checkout -B release/custom-0.1.69 fork/release/custom-0.1.69
|
||||
|
||||
cd deploy
|
||||
cp .env.example .env
|
||||
vim .env # 配置 DATABASE_URL, REDIS_URL, JWT_SECRET 等
|
||||
```
|
||||
|
||||
### 3. 构建并启动
|
||||
|
||||
```bash
|
||||
cd /root/sub2api
|
||||
docker build -t sub2api:latest -f Dockerfile .
|
||||
docker tag sub2api:latest weishaw/sub2api:latest
|
||||
cd deploy && docker compose up -d
|
||||
```
|
||||
|
||||
### 6. 启动服务
|
||||
|
||||
```bash
|
||||
# 进入 deploy 目录
|
||||
cd deploy
|
||||
|
||||
# 启动所有服务(PostgreSQL、Redis、sub2api)
|
||||
docker compose up -d
|
||||
|
||||
# 查看服务状态
|
||||
docker compose ps
|
||||
```
|
||||
|
||||
### 7. 验证部署
|
||||
|
||||
```bash
|
||||
# 查看应用日志
|
||||
docker logs sub2api --tail 50
|
||||
|
||||
# 检查健康状态
|
||||
curl http://localhost:8080/health
|
||||
|
||||
# 确认版本号
|
||||
cat /root/sub2api/backend/cmd/server/VERSION
|
||||
```
|
||||
|
||||
### 8. 常用运维命令
|
||||
|
||||
```bash
|
||||
# 查看实时日志
|
||||
docker logs -f sub2api
|
||||
|
||||
# 重启服务
|
||||
docker compose restart sub2api
|
||||
|
||||
# 停止所有服务
|
||||
docker compose down
|
||||
|
||||
# 停止并删除数据卷(慎用!会删除数据库数据)
|
||||
docker compose down -v
|
||||
|
||||
# 查看资源使用情况
|
||||
docker stats sub2api
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 定制功能说明
|
||||
|
||||
当前定制分支包含以下功能(相对于官方版本):
|
||||
|
||||
### UI/UX 定制
|
||||
|
||||
| 功能 | 说明 |
|
||||
|------|------|
|
||||
| 首页优化 | 面向用户的价值主张设计 |
|
||||
| 移除 GitHub 链接 | 用户菜单中不显示 GitHub 导航 |
|
||||
| 微信客服按钮 | 首页悬浮微信客服入口 |
|
||||
| 限流时间精确显示 | 账号限流时间显示精确到秒 |
|
||||
|
||||
### Antigravity 平台增强
|
||||
|
||||
| 功能 | 说明 |
|
||||
|------|------|
|
||||
| Scope 级别限流 | 按配额域(claude/gemini_text/gemini_image)独立限流,避免整个账号被锁定 |
|
||||
| 模型级别限流 | 按具体模型(如 claude-opus-4-5)独立限流,更精细的限流控制 |
|
||||
| 限流预检查 | 调度时预检查账号/模型限流状态,避免选中已限流账号 |
|
||||
| 秒级冷却时间 | 支持 429 响应的秒级精确冷却时间 |
|
||||
| 身份注入优化 | 模型身份信息注入 + 静默边界防止身份泄露 |
|
||||
| thoughtSignature 修复 | Gemini 3 函数调用 400 错误修复 |
|
||||
| max_tokens 自动修正 | 自动修正 max_tokens <= budget_tokens 导致的 400 错误 |
|
||||
|
||||
### 调度算法优化
|
||||
|
||||
| 功能 | 说明 |
|
||||
|------|------|
|
||||
| 分层过滤选择 | 调度算法从全排序改为分层过滤,提升性能 |
|
||||
| LRU 随机选择 | 相同 LRU 时间时随机选择,避免账号集中 |
|
||||
| 限流等待阈值配置化 | 可配置的限流等待阈值 |
|
||||
|
||||
### 运维增强
|
||||
|
||||
| 功能 | 说明 |
|
||||
|------|------|
|
||||
| Scope 限流统计 | 运维界面展示 Antigravity 账号 scope 级别限流统计 |
|
||||
| 账号限流状态显示 | 账号列表显示 scope 和模型级别限流状态 |
|
||||
| 清除限流按钮增强 | 有 scope/模型限流时也显示清除限流按钮 |
|
||||
|
||||
### 其他修复
|
||||
|
||||
| 功能 | 说明 |
|
||||
|------|------|
|
||||
| .gitattributes | 确保迁移文件使用 LF 换行符(解决 Windows 下 SQL 摘要不一致) |
|
||||
| 部署配置优化 | DATABASE_HOST 和 DATABASE_SSLMODE 可通过 .env 配置 |
|
||||
|
||||
---
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. **前端必须打包进镜像**:使用 `docker build` 在服务器上构建,Dockerfile 会自动编译前端并 embed 到后端二进制中
|
||||
|
||||
2. **镜像标签**:docker-compose.yml 使用 `weishaw/sub2api:latest`,本地构建后需要 `docker tag` 覆盖
|
||||
|
||||
3. **Windows 换行符问题**:已通过 `.gitattributes` 解决,确保 `*.sql` 文件始终使用 LF
|
||||
|
||||
4. **版本号管理**:每次发布必须更新 `backend/cmd/server/VERSION` 并打标签
|
||||
|
||||
5. **合并冲突**:合并上游新版本时,重点关注以下文件可能的冲突:
|
||||
- `backend/internal/service/antigravity_gateway_service.go`
|
||||
- `backend/internal/service/gateway_service.go`
|
||||
- `backend/internal/pkg/antigravity/request_transformer.go`
|
||||
|
||||
---
|
||||
|
||||
## Go 代码规范
|
||||
|
||||
### 1. 函数设计
|
||||
|
||||
#### 单一职责原则
|
||||
- **函数行数**:单个函数常规不应超过 **30 行**,超过时应拆分为子函数。若某段逻辑确实不可拆分(如复杂的状态机、协议解析等),可以例外,但需添加注释说明原因
|
||||
- **嵌套层级**:避免超过 3 层嵌套,使用 early return 减少嵌套
|
||||
|
||||
```go
|
||||
// ❌ 不推荐:深层嵌套
|
||||
func process(data []Item) {
|
||||
for _, item := range data {
|
||||
if item.Valid {
|
||||
if item.Type == "A" {
|
||||
if item.Status == "active" {
|
||||
// 业务逻辑...
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ✅ 推荐:early return
|
||||
func process(data []Item) {
|
||||
for _, item := range data {
|
||||
if !item.Valid {
|
||||
continue
|
||||
}
|
||||
if item.Type != "A" {
|
||||
continue
|
||||
}
|
||||
if item.Status != "active" {
|
||||
continue
|
||||
}
|
||||
// 业务逻辑...
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### 复杂逻辑提取
|
||||
将复杂的条件判断或处理逻辑提取为独立函数:
|
||||
|
||||
```go
|
||||
// ❌ 不推荐:内联复杂逻辑
|
||||
if resp.StatusCode == 429 || resp.StatusCode == 503 {
|
||||
// 80+ 行处理逻辑...
|
||||
}
|
||||
|
||||
// ✅ 推荐:提取为独立函数
|
||||
result := handleRateLimitResponse(resp, params)
|
||||
switch result.action {
|
||||
case actionRetry:
|
||||
continue
|
||||
case actionBreak:
|
||||
return result.resp, nil
|
||||
}
|
||||
```
|
||||
|
||||
### 2. 重复代码消除
|
||||
|
||||
#### 配置获取模式
|
||||
将重复的配置获取逻辑提取为方法:
|
||||
|
||||
```go
|
||||
// ❌ 不推荐:重复代码
|
||||
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
|
||||
maxBytes := 2048
|
||||
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
|
||||
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
}
|
||||
|
||||
// ✅ 推荐:提取为方法
|
||||
func (s *Service) getLogConfig() (logBody bool, maxBytes int) {
|
||||
maxBytes = 2048
|
||||
if s.settingService == nil || s.settingService.cfg == nil {
|
||||
return false, maxBytes
|
||||
}
|
||||
cfg := s.settingService.cfg.Gateway
|
||||
if cfg.LogUpstreamErrorBodyMaxBytes > 0 {
|
||||
maxBytes = cfg.LogUpstreamErrorBodyMaxBytes
|
||||
}
|
||||
return cfg.LogUpstreamErrorBody, maxBytes
|
||||
}
|
||||
```
|
||||
|
||||
### 3. 常量管理
|
||||
|
||||
#### 避免魔法数字
|
||||
所有硬编码的数值都应定义为常量:
|
||||
|
||||
```go
|
||||
// ❌ 不推荐
|
||||
if retryDelay >= 10*time.Second {
|
||||
resetAt := time.Now().Add(30 * time.Second)
|
||||
}
|
||||
|
||||
// ✅ 推荐
|
||||
const (
|
||||
rateLimitThreshold = 10 * time.Second
|
||||
defaultRateLimitDuration = 30 * time.Second
|
||||
)
|
||||
|
||||
if retryDelay >= rateLimitThreshold {
|
||||
resetAt := time.Now().Add(defaultRateLimitDuration)
|
||||
}
|
||||
```
|
||||
|
||||
#### 注释引用常量名
|
||||
在注释中引用常量名而非硬编码值:
|
||||
|
||||
```go
|
||||
// ❌ 不推荐
|
||||
// < 10s: 等待后重试
|
||||
|
||||
// ✅ 推荐
|
||||
// < rateLimitThreshold: 等待后重试
|
||||
```
|
||||
|
||||
### 4. 错误处理
|
||||
|
||||
#### 使用结构化日志
|
||||
优先使用 `slog` 进行结构化日志记录:
|
||||
|
||||
```go
|
||||
// ❌ 不推荐
|
||||
log.Printf("%s status=%d model_rate_limit_failed model=%s error=%v", prefix, statusCode, modelName, err)
|
||||
|
||||
// ✅ 推荐
|
||||
slog.Error("failed to set model rate limit",
|
||||
"prefix", prefix,
|
||||
"status_code", statusCode,
|
||||
"model", modelName,
|
||||
"error", err,
|
||||
)
|
||||
```
|
||||
|
||||
### 5. 测试规范
|
||||
|
||||
#### Mock 函数签名同步
|
||||
修改函数签名时,必须同步更新所有测试中的 mock 函数:
|
||||
|
||||
```go
|
||||
// 如果修改了 handleError 签名
|
||||
handleError func(..., groupID int64, sessionHash string) *Result
|
||||
|
||||
// 必须同步更新测试中的 mock
|
||||
handleError: func(..., groupID int64, sessionHash string) *Result {
|
||||
return nil
|
||||
},
|
||||
```
|
||||
|
||||
#### 测试构建标签
|
||||
统一使用测试构建标签:
|
||||
|
||||
```go
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
```
|
||||
|
||||
### 6. 时间格式解析
|
||||
|
||||
#### 使用标准库
|
||||
优先使用 `time.ParseDuration`,支持所有 Go duration 格式:
|
||||
|
||||
```go
|
||||
// ❌ 不推荐:手动限制格式
|
||||
if !strings.HasSuffix(delay, "s") || strings.Contains(delay, "m") {
|
||||
continue
|
||||
}
|
||||
|
||||
// ✅ 推荐:使用标准库
|
||||
dur, err := time.ParseDuration(delay) // 支持 "0.5s", "4m50s", "1h30m" 等
|
||||
```
|
||||
|
||||
### 7. 接口设计
|
||||
|
||||
#### 接口隔离原则
|
||||
定义最小化接口,只包含必需的方法:
|
||||
|
||||
```go
|
||||
// ❌ 不推荐:使用过于宽泛的接口
|
||||
type AccountRepository interface {
|
||||
// 20+ 个方法...
|
||||
}
|
||||
|
||||
// ✅ 推荐:定义最小化接口
|
||||
type ModelRateLimiter interface {
|
||||
SetModelRateLimit(ctx context.Context, id int64, modelKey string, resetAt time.Time) error
|
||||
}
|
||||
```
|
||||
|
||||
### 8. 并发安全
|
||||
|
||||
#### 共享数据保护
|
||||
访问可能被并发修改的数据时,确保线程安全:
|
||||
|
||||
```go
|
||||
// 如果 Account.Extra 可能被并发修改
|
||||
// 需要使用互斥锁或原子操作保护读取
|
||||
func (a *Account) GetRateLimitRemainingTime(model string) time.Duration {
|
||||
a.mu.RLock()
|
||||
defer a.mu.RUnlock()
|
||||
// 读取 Extra 字段...
|
||||
}
|
||||
```
|
||||
|
||||
### 9. 命名规范
|
||||
|
||||
#### 一致的命名风格
|
||||
- 常量使用 camelCase:`rateLimitThreshold`
|
||||
- 类型使用 PascalCase:`AntigravityQuotaScope`
|
||||
- 同一概念使用统一命名:`Threshold` 或 `Limit`,不要混用
|
||||
|
||||
```go
|
||||
// ❌ 不推荐:命名不一致
|
||||
antigravitySmartRetryMinWait // 使用 Min
|
||||
antigravityRateLimitThreshold // 使用 Threshold
|
||||
|
||||
// ✅ 推荐:统一风格
|
||||
antigravityMinRetryWait
|
||||
antigravityRateLimitThreshold
|
||||
```
|
||||
|
||||
### 10. 代码审查清单
|
||||
|
||||
在提交代码前,检查以下项目:
|
||||
|
||||
- [ ] 函数是否超过 30 行?(不可拆分的逻辑除外,需注释说明)
|
||||
- [ ] 嵌套是否超过 3 层?
|
||||
- [ ] 是否有重复代码可以提取?
|
||||
- [ ] 是否使用了魔法数字?
|
||||
- [ ] Mock 函数签名是否与实际函数一致?
|
||||
- [ ] 测试是否覆盖了新增逻辑?
|
||||
- [ ] 日志是否包含足够的上下文信息?
|
||||
- [ ] 是否考虑了并发安全?
|
||||
|
||||
---
|
||||
|
||||
## CI 检查与发布门禁
|
||||
|
||||
### GitHub Actions 检查项
|
||||
|
||||
本项目有 4 个 CI 任务,**任何代码推送或发布前都必须全部通过**:
|
||||
|
||||
| Workflow | Job | 说明 | 本地验证命令 |
|
||||
|----------|-----|------|-------------|
|
||||
| CI | `test` | 单元测试 + 集成测试 | `cd backend && make test-unit && make test-integration` |
|
||||
| CI | `golangci-lint` | Go 代码静态检查(golangci-lint v2.7) | `cd backend && golangci-lint run --timeout=5m` |
|
||||
| Security Scan | `backend-security` | govulncheck + gosec 安全扫描 | `cd backend && govulncheck ./... && gosec -severity high -confidence high ./...` |
|
||||
| Security Scan | `frontend-security` | pnpm audit 前端依赖安全检查 | `cd frontend && pnpm audit --prod --audit-level=high` |
|
||||
|
||||
### 向上游提交 PR
|
||||
|
||||
PR 目标是上游官方仓库,**只包含通用功能改动**(bug fix、新功能、性能优化等)。
|
||||
|
||||
**以下文件禁止出现在 PR 中**(属于我们 fork 的定制化内容):
|
||||
- `CLAUDE.md`、`AGENTS.md` — 我们的开发文档
|
||||
- `backend/cmd/server/VERSION` — 我们的版本号文件
|
||||
- UI 定制改动(GitHub 链接移除、微信客服按钮、首页定制等)
|
||||
- 部署配置(`deploy/` 目录下的定制修改)
|
||||
|
||||
**PR 流程**:
|
||||
1. 从 `develop` 创建功能分支,只包含要提交给上游的改动
|
||||
2. 推送分支后,**等待 4 个 CI job 全部通过**
|
||||
3. 确认通过后再创建 PR
|
||||
4. 使用 `gh run list --repo touwaeriol/sub2api --branch <branch>` 检查状态
|
||||
|
||||
### 自有分支推送(develop / main)
|
||||
|
||||
推送到我们自己的 `develop` 或 `main` 分支时,包含所有改动(定制化 + 通用功能)。
|
||||
|
||||
**推送流程**:
|
||||
1. 本地运行 `cd backend && make test-unit` 确保单元测试通过
|
||||
2. 本地运行 `cd backend && gofmt -l ./...` 确保格式正确
|
||||
3. 推送后确认 CI 和 Security Scan 两个 workflow 的 4 个 job 全部绿色 ✅
|
||||
4. 任何 job 失败必须立即修复,**禁止在 CI 未通过的状态下继续后续操作**
|
||||
|
||||
### 发布版本
|
||||
|
||||
1. 确保 `main` 分支最新提交的 4 个 CI job 全部通过
|
||||
2. 递增 `backend/cmd/server/VERSION`,提交并推送
|
||||
3. 打 tag 推送后,确认 tag 触发的 3 个 workflow(CI、Security Scan、Release)全部通过
|
||||
4. **Release workflow 失败时禁止部署** — 必须先修复问题,删除旧 tag,重新打 tag
|
||||
5. 使用 `gh run list --repo touwaeriol/sub2api --limit 10` 确认状态
|
||||
|
||||
### 常见 CI 失败原因及修复
|
||||
- **gofmt**:struct 字段对齐不一致 → 运行 `gofmt -w <file>` 修复
|
||||
- **golangci-lint**:未使用的变量/导入 → 删除或使用 `_` 忽略
|
||||
- **test 失败**:mock 函数签名不一致 → 同步更新 mock
|
||||
- **gosec**:安全漏洞 → 根据提示修复或添加例外
|
||||
723
CLAUDE.md
Normal file
723
CLAUDE.md
Normal file
@@ -0,0 +1,723 @@
|
||||
# Sub2API 开发说明
|
||||
|
||||
## 版本管理策略
|
||||
|
||||
### 版本号规则
|
||||
|
||||
我们在官方版本号后面添加自己的小版本号:
|
||||
|
||||
- 官方版本:`v0.1.68`
|
||||
- 我们的版本:`v0.1.68.1`、`v0.1.68.2`(递增)
|
||||
|
||||
### 分支策略
|
||||
|
||||
| 分支 | 说明 |
|
||||
|------|------|
|
||||
| `main` | 我们的主分支,包含所有定制功能 |
|
||||
| `release/custom-X.Y.Z` | 基于官方 `vX.Y.Z` 的发布分支 |
|
||||
| `upstream/main` | 上游官方仓库 |
|
||||
|
||||
---
|
||||
|
||||
## 发布流程(基于新官方版本)
|
||||
|
||||
当官方发布新版本(如 `v0.1.69`)时:
|
||||
|
||||
### 1. 同步上游并创建发布分支
|
||||
|
||||
```bash
|
||||
# 获取上游最新代码
|
||||
git fetch upstream --tags
|
||||
|
||||
# 基于官方标签创建新的发布分支
|
||||
git checkout v0.1.69 -b release/custom-0.1.69
|
||||
|
||||
# 合并我们的 main 分支(包含所有定制功能)
|
||||
git merge main --no-edit
|
||||
|
||||
# 解决可能的冲突后继续
|
||||
```
|
||||
|
||||
### 2. 更新版本号并打标签
|
||||
|
||||
```bash
|
||||
# 更新版本号文件
|
||||
echo "0.1.69.1" > backend/cmd/server/VERSION
|
||||
git add backend/cmd/server/VERSION
|
||||
git commit -m "chore: bump version to 0.1.69.1"
|
||||
|
||||
# 打上我们自己的标签
|
||||
git tag v0.1.69.1
|
||||
|
||||
# 推送分支和标签
|
||||
git push origin release/custom-0.1.69
|
||||
git push origin v0.1.69.1
|
||||
```
|
||||
|
||||
### 3. 更新 main 分支
|
||||
|
||||
```bash
|
||||
# 将发布分支合并回 main,保持 main 包含最新定制功能
|
||||
git checkout main
|
||||
git merge release/custom-0.1.69
|
||||
git push origin main
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 热修复发布(在现有版本上修复)
|
||||
|
||||
当需要在当前版本上发布修复时:
|
||||
|
||||
```bash
|
||||
# 在当前发布分支上修复
|
||||
git checkout release/custom-0.1.68
|
||||
# ... 进行修复 ...
|
||||
git commit -m "fix: 修复描述"
|
||||
|
||||
# 递增小版本号
|
||||
echo "0.1.68.2" > backend/cmd/server/VERSION
|
||||
git add backend/cmd/server/VERSION
|
||||
git commit -m "chore: bump version to 0.1.68.2"
|
||||
|
||||
# 打标签并推送
|
||||
git tag v0.1.68.2
|
||||
git push origin release/custom-0.1.68
|
||||
git push origin v0.1.68.2
|
||||
|
||||
# 同步修复到 main
|
||||
git checkout main
|
||||
git cherry-pick <fix-commit-hash>
|
||||
git push origin main
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 服务器部署流程
|
||||
|
||||
### 前置条件
|
||||
|
||||
- 本地已配置 SSH 别名 `clicodeplus` 连接到服务器
|
||||
- 服务器部署目录:`/root/sub2api`(正式)、`/root/sub2api-beta`(测试)
|
||||
- 服务器使用 Docker Compose 部署
|
||||
|
||||
### 部署环境说明
|
||||
|
||||
| 环境 | 目录 | 端口 | 数据库 | 容器名 |
|
||||
|------|------|------|--------|--------|
|
||||
| 正式 | `/root/sub2api` | 8080 | `sub2api` | `sub2api` |
|
||||
| Beta | `/root/sub2api-beta` | 8084 | `beta` | `sub2api-beta` |
|
||||
|
||||
### 外部数据库
|
||||
|
||||
正式和 Beta 环境**共用外部 PostgreSQL 数据库**(非容器内数据库),配置在 `.env` 文件中:
|
||||
- `DATABASE_HOST`:外部数据库地址
|
||||
- `DATABASE_SSLMODE`:SSL 模式(通常为 `require`)
|
||||
- `POSTGRES_USER` / `POSTGRES_DB`:用户名和数据库名
|
||||
|
||||
#### 数据库操作命令
|
||||
|
||||
通过 SSH 在服务器上执行数据库操作:
|
||||
|
||||
```bash
|
||||
# 正式环境 - 查询迁移记录
|
||||
ssh clicodeplus "source /root/sub2api/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c 'SELECT * FROM schema_migrations ORDER BY applied_at DESC LIMIT 5;'"
|
||||
|
||||
# Beta 环境 - 查询迁移记录
|
||||
ssh clicodeplus "source /root/sub2api-beta/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c 'SELECT * FROM schema_migrations ORDER BY applied_at DESC LIMIT 5;'"
|
||||
|
||||
# Beta 环境 - 清除指定迁移记录(重新执行迁移)
|
||||
ssh clicodeplus "source /root/sub2api-beta/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c \"DELETE FROM schema_migrations WHERE filename LIKE '%049%';\""
|
||||
|
||||
# Beta 环境 - 更新账号数据
|
||||
ssh clicodeplus "source /root/sub2api-beta/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c \"UPDATE accounts SET credentials = credentials - 'model_mapping' WHERE platform = 'antigravity';\""
|
||||
```
|
||||
|
||||
> **注意**:使用 `source .env` 加载环境变量,避免在命令行中暴露密码。
|
||||
|
||||
### 部署步骤
|
||||
|
||||
**重要:每次部署都必须递增版本号!**
|
||||
|
||||
#### 0. 递增版本号(本地操作)
|
||||
|
||||
每次部署前,先在本地递增小版本号:
|
||||
|
||||
```bash
|
||||
# 查看当前版本号
|
||||
cat backend/cmd/server/VERSION
|
||||
# 假设当前是 0.1.69.1
|
||||
|
||||
# 递增版本号
|
||||
echo "0.1.69.2" > backend/cmd/server/VERSION
|
||||
git add backend/cmd/server/VERSION
|
||||
git commit -m "chore: bump version to 0.1.69.2"
|
||||
git push origin release/custom-0.1.69
|
||||
```
|
||||
|
||||
#### 1. 服务器拉取代码
|
||||
|
||||
```bash
|
||||
ssh clicodeplus "cd /root/sub2api && git fetch fork && git checkout -B release/custom-0.1.69 fork/release/custom-0.1.69"
|
||||
```
|
||||
|
||||
#### 2. 服务器构建镜像
|
||||
|
||||
```bash
|
||||
ssh clicodeplus "cd /root/sub2api && docker build --no-cache -t sub2api:latest -f Dockerfile ."
|
||||
```
|
||||
|
||||
#### 3. 更新镜像标签并重启服务
|
||||
|
||||
```bash
|
||||
ssh clicodeplus "docker tag sub2api:latest weishaw/sub2api:latest"
|
||||
ssh clicodeplus "cd /root/sub2api/deploy && docker compose up -d --force-recreate sub2api"
|
||||
```
|
||||
|
||||
#### 4. 验证部署
|
||||
|
||||
```bash
|
||||
# 查看启动日志
|
||||
ssh clicodeplus "docker logs sub2api --tail 20"
|
||||
|
||||
# 确认版本号(必须与步骤 0 中设置的版本号一致)
|
||||
ssh clicodeplus "cat /root/sub2api/backend/cmd/server/VERSION"
|
||||
|
||||
# 检查容器状态
|
||||
ssh clicodeplus "docker ps | grep sub2api"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Beta 并行部署(不影响现网)
|
||||
|
||||
目标:在同一台服务器上并行启动一个 beta 实例(例如端口 `8084`),**严禁改动/重启**现网实例(默认目录 `/root/sub2api`)。
|
||||
|
||||
### 设计原则
|
||||
|
||||
- **新目录**:beta 使用独立目录,例如 `/root/sub2api-beta`。
|
||||
- **敏感信息只放 `.env`**:beta 的数据库密码、JWT_SECRET 等只写入 `/root/sub2api-beta/deploy/.env`,不要提交到 git。
|
||||
- **独立 Compose Project**:通过 `docker compose -p sub2api-beta ...` 启动,确保 network/volume 隔离。
|
||||
- **独立端口**:通过 `.env` 的 `SERVER_PORT` 映射宿主机端口(例如 `8084:8080`)。
|
||||
|
||||
### 前置检查
|
||||
|
||||
```bash
|
||||
# 1) 确保 8084 未被占用
|
||||
ssh clicodeplus "ss -ltnp | grep :8084 || echo '8084 is free'"
|
||||
|
||||
# 2) 确认现网容器还在(只读检查)
|
||||
ssh clicodeplus "docker ps --format 'table {{.Names}}\t{{.Image}}\t{{.Ports}}' | sed -n '1,200p'"
|
||||
```
|
||||
|
||||
### 首次部署步骤
|
||||
|
||||
```bash
|
||||
# 0) 进入服务器
|
||||
ssh clicodeplus
|
||||
|
||||
# 1) 克隆代码到新目录(示例使用你的 fork)
|
||||
cd /root
|
||||
git clone https://github.com/touwaeriol/sub2api.git sub2api-beta
|
||||
cd /root/sub2api-beta
|
||||
git checkout release/custom-0.1.71
|
||||
|
||||
# 2) 准备 beta 的 .env(敏感信息只写这里)
|
||||
cd /root/sub2api-beta/deploy
|
||||
|
||||
# 推荐:从现网 .env 复制,保证除 DB 名/用户/端口外完全一致
|
||||
cp -f /root/sub2api/deploy/.env ./.env
|
||||
|
||||
# 仅修改以下三项(其他保持不变)
|
||||
perl -pi -e 's/^SERVER_PORT=.*/SERVER_PORT=8084/' ./.env
|
||||
perl -pi -e 's/^POSTGRES_USER=.*/POSTGRES_USER=beta/' ./.env
|
||||
perl -pi -e 's/^POSTGRES_DB=.*/POSTGRES_DB=beta/' ./.env
|
||||
|
||||
# 3) 写 compose override(避免与现网容器名冲突,镜像使用本地构建的 sub2api:beta)
|
||||
cat > docker-compose.override.yml <<'YAML'
|
||||
services:
|
||||
sub2api:
|
||||
image: sub2api:beta
|
||||
container_name: sub2api-beta
|
||||
redis:
|
||||
container_name: sub2api-beta-redis
|
||||
YAML
|
||||
|
||||
# 4) 构建 beta 镜像(基于当前代码)
|
||||
cd /root/sub2api-beta
|
||||
docker build -t sub2api:beta -f Dockerfile .
|
||||
|
||||
# 5) 启动 beta(独立 project,确保不影响现网)
|
||||
cd /root/sub2api-beta/deploy
|
||||
docker compose -p sub2api-beta --env-file .env -f docker-compose.yml -f docker-compose.override.yml up -d
|
||||
|
||||
# 6) 验证 beta
|
||||
curl -fsS http://127.0.0.1:8084/health
|
||||
docker logs sub2api-beta --tail 50
|
||||
```
|
||||
|
||||
### 数据库配置约定(beta)
|
||||
|
||||
- 数据库地址/SSL/密码:与现网一致(从现网 `.env` 复制即可)。
|
||||
- 仅修改:
|
||||
- `POSTGRES_USER=beta`
|
||||
- `POSTGRES_DB=beta`
|
||||
|
||||
注意:需要数据库侧已存在 `beta` 用户与 `beta` 数据库,并授予权限;否则容器会启动失败并不断重启。
|
||||
|
||||
### 更新 beta(拉代码 + 仅重建 beta 容器)
|
||||
|
||||
```bash
|
||||
ssh clicodeplus "set -e; cd /root/sub2api-beta && git fetch --all --tags && git checkout -f release/custom-0.1.71 && git reset --hard origin/release/custom-0.1.71"
|
||||
ssh clicodeplus "cd /root/sub2api-beta && docker build -t sub2api:beta -f Dockerfile ."
|
||||
ssh clicodeplus "cd /root/sub2api-beta/deploy && docker compose -p sub2api-beta --env-file .env -f docker-compose.yml -f docker-compose.override.yml up -d --no-deps --force-recreate sub2api"
|
||||
ssh clicodeplus "curl -fsS http://127.0.0.1:8084/health"
|
||||
```
|
||||
|
||||
### 停止/回滚 beta(只影响 beta)
|
||||
|
||||
```bash
|
||||
ssh clicodeplus "cd /root/sub2api-beta/deploy && docker compose -p sub2api-beta -f docker-compose.yml -f docker-compose.override.yml down"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 服务器首次部署
|
||||
|
||||
### 1. 克隆代码并配置远程仓库
|
||||
|
||||
```bash
|
||||
ssh clicodeplus
|
||||
cd /root
|
||||
git clone https://github.com/Wei-Shaw/sub2api.git
|
||||
cd sub2api
|
||||
|
||||
# 添加 fork 仓库
|
||||
git remote add fork https://github.com/touwaeriol/sub2api.git
|
||||
```
|
||||
|
||||
### 2. 切换到定制分支并配置环境
|
||||
|
||||
```bash
|
||||
git fetch fork
|
||||
git checkout -B release/custom-0.1.69 fork/release/custom-0.1.69
|
||||
|
||||
cd deploy
|
||||
cp .env.example .env
|
||||
vim .env # 配置 DATABASE_URL, REDIS_URL, JWT_SECRET 等
|
||||
```
|
||||
|
||||
### 3. 构建并启动
|
||||
|
||||
```bash
|
||||
cd /root/sub2api
|
||||
docker build -t sub2api:latest -f Dockerfile .
|
||||
docker tag sub2api:latest weishaw/sub2api:latest
|
||||
cd deploy && docker compose up -d
|
||||
```
|
||||
|
||||
### 6. 启动服务
|
||||
|
||||
```bash
|
||||
# 进入 deploy 目录
|
||||
cd deploy
|
||||
|
||||
# 启动所有服务(PostgreSQL、Redis、sub2api)
|
||||
docker compose up -d
|
||||
|
||||
# 查看服务状态
|
||||
docker compose ps
|
||||
```
|
||||
|
||||
### 7. 验证部署
|
||||
|
||||
```bash
|
||||
# 查看应用日志
|
||||
docker logs sub2api --tail 50
|
||||
|
||||
# 检查健康状态
|
||||
curl http://localhost:8080/health
|
||||
|
||||
# 确认版本号
|
||||
cat /root/sub2api/backend/cmd/server/VERSION
|
||||
```
|
||||
|
||||
### 8. 常用运维命令
|
||||
|
||||
```bash
|
||||
# 查看实时日志
|
||||
docker logs -f sub2api
|
||||
|
||||
# 重启服务
|
||||
docker compose restart sub2api
|
||||
|
||||
# 停止所有服务
|
||||
docker compose down
|
||||
|
||||
# 停止并删除数据卷(慎用!会删除数据库数据)
|
||||
docker compose down -v
|
||||
|
||||
# 查看资源使用情况
|
||||
docker stats sub2api
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 定制功能说明
|
||||
|
||||
当前定制分支包含以下功能(相对于官方版本):
|
||||
|
||||
### UI/UX 定制
|
||||
|
||||
| 功能 | 说明 |
|
||||
|------|------|
|
||||
| 首页优化 | 面向用户的价值主张设计 |
|
||||
| 移除 GitHub 链接 | 用户菜单中不显示 GitHub 导航 |
|
||||
| 微信客服按钮 | 首页悬浮微信客服入口 |
|
||||
| 限流时间精确显示 | 账号限流时间显示精确到秒 |
|
||||
|
||||
### Antigravity 平台增强
|
||||
|
||||
| 功能 | 说明 |
|
||||
|------|------|
|
||||
| Scope 级别限流 | 按配额域(claude/gemini_text/gemini_image)独立限流,避免整个账号被锁定 |
|
||||
| 模型级别限流 | 按具体模型(如 claude-opus-4-5)独立限流,更精细的限流控制 |
|
||||
| 限流预检查 | 调度时预检查账号/模型限流状态,避免选中已限流账号 |
|
||||
| 秒级冷却时间 | 支持 429 响应的秒级精确冷却时间 |
|
||||
| 身份注入优化 | 模型身份信息注入 + 静默边界防止身份泄露 |
|
||||
| thoughtSignature 修复 | Gemini 3 函数调用 400 错误修复 |
|
||||
| max_tokens 自动修正 | 自动修正 max_tokens <= budget_tokens 导致的 400 错误 |
|
||||
|
||||
### 调度算法优化
|
||||
|
||||
| 功能 | 说明 |
|
||||
|------|------|
|
||||
| 分层过滤选择 | 调度算法从全排序改为分层过滤,提升性能 |
|
||||
| LRU 随机选择 | 相同 LRU 时间时随机选择,避免账号集中 |
|
||||
| 限流等待阈值配置化 | 可配置的限流等待阈值 |
|
||||
|
||||
### 运维增强
|
||||
|
||||
| 功能 | 说明 |
|
||||
|------|------|
|
||||
| Scope 限流统计 | 运维界面展示 Antigravity 账号 scope 级别限流统计 |
|
||||
| 账号限流状态显示 | 账号列表显示 scope 和模型级别限流状态 |
|
||||
| 清除限流按钮增强 | 有 scope/模型限流时也显示清除限流按钮 |
|
||||
|
||||
### 其他修复
|
||||
|
||||
| 功能 | 说明 |
|
||||
|------|------|
|
||||
| .gitattributes | 确保迁移文件使用 LF 换行符(解决 Windows 下 SQL 摘要不一致) |
|
||||
| 部署配置优化 | DATABASE_HOST 和 DATABASE_SSLMODE 可通过 .env 配置 |
|
||||
|
||||
---
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. **前端必须打包进镜像**:使用 `docker build` 在服务器上构建,Dockerfile 会自动编译前端并 embed 到后端二进制中
|
||||
|
||||
2. **镜像标签**:docker-compose.yml 使用 `weishaw/sub2api:latest`,本地构建后需要 `docker tag` 覆盖
|
||||
|
||||
3. **Windows 换行符问题**:已通过 `.gitattributes` 解决,确保 `*.sql` 文件始终使用 LF
|
||||
|
||||
4. **版本号管理**:每次发布必须更新 `backend/cmd/server/VERSION` 并打标签
|
||||
|
||||
5. **合并冲突**:合并上游新版本时,重点关注以下文件可能的冲突:
|
||||
- `backend/internal/service/antigravity_gateway_service.go`
|
||||
- `backend/internal/service/gateway_service.go`
|
||||
- `backend/internal/pkg/antigravity/request_transformer.go`
|
||||
|
||||
---
|
||||
|
||||
## Go 代码规范
|
||||
|
||||
### 1. 函数设计
|
||||
|
||||
#### 单一职责原则
|
||||
- **函数行数**:单个函数常规不应超过 **30 行**,超过时应拆分为子函数。若某段逻辑确实不可拆分(如复杂的状态机、协议解析等),可以例外,但需添加注释说明原因
|
||||
- **嵌套层级**:避免超过 3 层嵌套,使用 early return 减少嵌套
|
||||
|
||||
```go
|
||||
// ❌ 不推荐:深层嵌套
|
||||
func process(data []Item) {
|
||||
for _, item := range data {
|
||||
if item.Valid {
|
||||
if item.Type == "A" {
|
||||
if item.Status == "active" {
|
||||
// 业务逻辑...
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ✅ 推荐:early return
|
||||
func process(data []Item) {
|
||||
for _, item := range data {
|
||||
if !item.Valid {
|
||||
continue
|
||||
}
|
||||
if item.Type != "A" {
|
||||
continue
|
||||
}
|
||||
if item.Status != "active" {
|
||||
continue
|
||||
}
|
||||
// 业务逻辑...
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### 复杂逻辑提取
|
||||
将复杂的条件判断或处理逻辑提取为独立函数:
|
||||
|
||||
```go
|
||||
// ❌ 不推荐:内联复杂逻辑
|
||||
if resp.StatusCode == 429 || resp.StatusCode == 503 {
|
||||
// 80+ 行处理逻辑...
|
||||
}
|
||||
|
||||
// ✅ 推荐:提取为独立函数
|
||||
result := handleRateLimitResponse(resp, params)
|
||||
switch result.action {
|
||||
case actionRetry:
|
||||
continue
|
||||
case actionBreak:
|
||||
return result.resp, nil
|
||||
}
|
||||
```
|
||||
|
||||
### 2. 重复代码消除
|
||||
|
||||
#### 配置获取模式
|
||||
将重复的配置获取逻辑提取为方法:
|
||||
|
||||
```go
|
||||
// ❌ 不推荐:重复代码
|
||||
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
|
||||
maxBytes := 2048
|
||||
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
|
||||
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
}
|
||||
|
||||
// ✅ 推荐:提取为方法
|
||||
func (s *Service) getLogConfig() (logBody bool, maxBytes int) {
|
||||
maxBytes = 2048
|
||||
if s.settingService == nil || s.settingService.cfg == nil {
|
||||
return false, maxBytes
|
||||
}
|
||||
cfg := s.settingService.cfg.Gateway
|
||||
if cfg.LogUpstreamErrorBodyMaxBytes > 0 {
|
||||
maxBytes = cfg.LogUpstreamErrorBodyMaxBytes
|
||||
}
|
||||
return cfg.LogUpstreamErrorBody, maxBytes
|
||||
}
|
||||
```
|
||||
|
||||
### 3. 常量管理
|
||||
|
||||
#### 避免魔法数字
|
||||
所有硬编码的数值都应定义为常量:
|
||||
|
||||
```go
|
||||
// ❌ 不推荐
|
||||
if retryDelay >= 10*time.Second {
|
||||
resetAt := time.Now().Add(30 * time.Second)
|
||||
}
|
||||
|
||||
// ✅ 推荐
|
||||
const (
|
||||
rateLimitThreshold = 10 * time.Second
|
||||
defaultRateLimitDuration = 30 * time.Second
|
||||
)
|
||||
|
||||
if retryDelay >= rateLimitThreshold {
|
||||
resetAt := time.Now().Add(defaultRateLimitDuration)
|
||||
}
|
||||
```
|
||||
|
||||
#### 注释引用常量名
|
||||
在注释中引用常量名而非硬编码值:
|
||||
|
||||
```go
|
||||
// ❌ 不推荐
|
||||
// < 10s: 等待后重试
|
||||
|
||||
// ✅ 推荐
|
||||
// < rateLimitThreshold: 等待后重试
|
||||
```
|
||||
|
||||
### 4. 错误处理
|
||||
|
||||
#### 使用结构化日志
|
||||
优先使用 `slog` 进行结构化日志记录:
|
||||
|
||||
```go
|
||||
// ❌ 不推荐
|
||||
log.Printf("%s status=%d model_rate_limit_failed model=%s error=%v", prefix, statusCode, modelName, err)
|
||||
|
||||
// ✅ 推荐
|
||||
slog.Error("failed to set model rate limit",
|
||||
"prefix", prefix,
|
||||
"status_code", statusCode,
|
||||
"model", modelName,
|
||||
"error", err,
|
||||
)
|
||||
```
|
||||
|
||||
### 5. 测试规范
|
||||
|
||||
#### Mock 函数签名同步
|
||||
修改函数签名时,必须同步更新所有测试中的 mock 函数:
|
||||
|
||||
```go
|
||||
// 如果修改了 handleError 签名
|
||||
handleError func(..., groupID int64, sessionHash string) *Result
|
||||
|
||||
// 必须同步更新测试中的 mock
|
||||
handleError: func(..., groupID int64, sessionHash string) *Result {
|
||||
return nil
|
||||
},
|
||||
```
|
||||
|
||||
#### 测试构建标签
|
||||
统一使用测试构建标签:
|
||||
|
||||
```go
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
```
|
||||
|
||||
### 6. 时间格式解析
|
||||
|
||||
#### 使用标准库
|
||||
优先使用 `time.ParseDuration`,支持所有 Go duration 格式:
|
||||
|
||||
```go
|
||||
// ❌ 不推荐:手动限制格式
|
||||
if !strings.HasSuffix(delay, "s") || strings.Contains(delay, "m") {
|
||||
continue
|
||||
}
|
||||
|
||||
// ✅ 推荐:使用标准库
|
||||
dur, err := time.ParseDuration(delay) // 支持 "0.5s", "4m50s", "1h30m" 等
|
||||
```
|
||||
|
||||
### 7. 接口设计
|
||||
|
||||
#### 接口隔离原则
|
||||
定义最小化接口,只包含必需的方法:
|
||||
|
||||
```go
|
||||
// ❌ 不推荐:使用过于宽泛的接口
|
||||
type AccountRepository interface {
|
||||
// 20+ 个方法...
|
||||
}
|
||||
|
||||
// ✅ 推荐:定义最小化接口
|
||||
type ModelRateLimiter interface {
|
||||
SetModelRateLimit(ctx context.Context, id int64, modelKey string, resetAt time.Time) error
|
||||
}
|
||||
```
|
||||
|
||||
### 8. 并发安全
|
||||
|
||||
#### 共享数据保护
|
||||
访问可能被并发修改的数据时,确保线程安全:
|
||||
|
||||
```go
|
||||
// 如果 Account.Extra 可能被并发修改
|
||||
// 需要使用互斥锁或原子操作保护读取
|
||||
func (a *Account) GetRateLimitRemainingTime(model string) time.Duration {
|
||||
a.mu.RLock()
|
||||
defer a.mu.RUnlock()
|
||||
// 读取 Extra 字段...
|
||||
}
|
||||
```
|
||||
|
||||
### 9. 命名规范
|
||||
|
||||
#### 一致的命名风格
|
||||
- 常量使用 camelCase:`rateLimitThreshold`
|
||||
- 类型使用 PascalCase:`AntigravityQuotaScope`
|
||||
- 同一概念使用统一命名:`Threshold` 或 `Limit`,不要混用
|
||||
|
||||
```go
|
||||
// ❌ 不推荐:命名不一致
|
||||
antigravitySmartRetryMinWait // 使用 Min
|
||||
antigravityRateLimitThreshold // 使用 Threshold
|
||||
|
||||
// ✅ 推荐:统一风格
|
||||
antigravityMinRetryWait
|
||||
antigravityRateLimitThreshold
|
||||
```
|
||||
|
||||
### 10. 代码审查清单
|
||||
|
||||
在提交代码前,检查以下项目:
|
||||
|
||||
- [ ] 函数是否超过 30 行?(不可拆分的逻辑除外,需注释说明)
|
||||
- [ ] 嵌套是否超过 3 层?
|
||||
- [ ] 是否有重复代码可以提取?
|
||||
- [ ] 是否使用了魔法数字?
|
||||
- [ ] Mock 函数签名是否与实际函数一致?
|
||||
- [ ] 测试是否覆盖了新增逻辑?
|
||||
- [ ] 日志是否包含足够的上下文信息?
|
||||
- [ ] 是否考虑了并发安全?
|
||||
|
||||
---
|
||||
|
||||
## CI 检查与发布门禁
|
||||
|
||||
### GitHub Actions 检查项
|
||||
|
||||
本项目有 4 个 CI 任务,**任何代码推送或发布前都必须全部通过**:
|
||||
|
||||
| Workflow | Job | 说明 | 本地验证命令 |
|
||||
|----------|-----|------|-------------|
|
||||
| CI | `test` | 单元测试 + 集成测试 | `cd backend && make test-unit && make test-integration` |
|
||||
| CI | `golangci-lint` | Go 代码静态检查(golangci-lint v2.7) | `cd backend && golangci-lint run --timeout=5m` |
|
||||
| Security Scan | `backend-security` | govulncheck + gosec 安全扫描 | `cd backend && govulncheck ./... && gosec -severity high -confidence high ./...` |
|
||||
| Security Scan | `frontend-security` | pnpm audit 前端依赖安全检查 | `cd frontend && pnpm audit --prod --audit-level=high` |
|
||||
|
||||
### 向上游提交 PR
|
||||
|
||||
PR 目标是上游官方仓库,**只包含通用功能改动**(bug fix、新功能、性能优化等)。
|
||||
|
||||
**以下文件禁止出现在 PR 中**(属于我们 fork 的定制化内容):
|
||||
- `CLAUDE.md`、`AGENTS.md` — 我们的开发文档
|
||||
- `backend/cmd/server/VERSION` — 我们的版本号文件
|
||||
- UI 定制改动(GitHub 链接移除、微信客服按钮、首页定制等)
|
||||
- 部署配置(`deploy/` 目录下的定制修改)
|
||||
|
||||
**PR 流程**:
|
||||
1. 从 `develop` 创建功能分支,只包含要提交给上游的改动
|
||||
2. 推送分支后,**等待 4 个 CI job 全部通过**
|
||||
3. 确认通过后再创建 PR
|
||||
4. 使用 `gh run list --repo touwaeriol/sub2api --branch <branch>` 检查状态
|
||||
|
||||
### 自有分支推送(develop / main)
|
||||
|
||||
推送到我们自己的 `develop` 或 `main` 分支时,包含所有改动(定制化 + 通用功能)。
|
||||
|
||||
**推送流程**:
|
||||
1. 本地运行 `cd backend && make test-unit` 确保单元测试通过
|
||||
2. 本地运行 `cd backend && gofmt -l ./...` 确保格式正确
|
||||
3. 推送后确认 CI 和 Security Scan 两个 workflow 的 4 个 job 全部绿色 ✅
|
||||
4. 任何 job 失败必须立即修复,**禁止在 CI 未通过的状态下继续后续操作**
|
||||
|
||||
### 发布版本
|
||||
|
||||
1. 确保 `main` 分支最新提交的 4 个 CI job 全部通过
|
||||
2. 递增 `backend/cmd/server/VERSION`,提交并推送
|
||||
3. 打 tag 推送后,确认 tag 触发的 3 个 workflow(CI、Security Scan、Release)全部通过
|
||||
4. **Release workflow 失败时禁止部署** — 必须先修复问题,删除旧 tag,重新打 tag
|
||||
5. 使用 `gh run list --repo touwaeriol/sub2api --limit 10` 确认状态
|
||||
|
||||
### 常见 CI 失败原因及修复
|
||||
- **gofmt**:struct 字段对齐不一致 → 运行 `gofmt -w <file>` 修复
|
||||
- **golangci-lint**:未使用的变量/导入 → 删除或使用 `_` 忽略
|
||||
- **test 失败**:mock 函数签名不一致 → 同步更新 mock
|
||||
- **gosec**:安全漏洞 → 根据提示修复或添加例外
|
||||
@@ -1 +1 @@
|
||||
0.1.74.7
|
||||
0.1.75.7
|
||||
@@ -154,7 +154,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
identityService := service.NewIdentityService(identityCache)
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache)
|
||||
digestSessionStore := service.NewDigestSessionStore()
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, digestSessionStore)
|
||||
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
|
||||
@@ -103,6 +103,7 @@ require (
|
||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||
github.com/opencontainers/image-spec v1.1.1 // indirect
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
|
||||
@@ -207,6 +207,8 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8
|
||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
|
||||
@@ -2,11 +2,6 @@ package dto
|
||||
|
||||
import "time"
|
||||
|
||||
type ScopeRateLimitInfo struct {
|
||||
ResetAt time.Time `json:"reset_at"`
|
||||
RemainingSec int64 `json:"remaining_sec"`
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID int64 `json:"id"`
|
||||
Email string `json:"email"`
|
||||
@@ -126,9 +121,6 @@ type Account struct {
|
||||
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
|
||||
OverloadUntil *time.Time `json:"overload_until"`
|
||||
|
||||
// Antigravity scope 级限流状态(从 extra 提取)
|
||||
ScopeRateLimits map[string]ScopeRateLimitInfo `json:"scope_rate_limits,omitempty"`
|
||||
|
||||
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until"`
|
||||
TempUnschedulableReason string `json:"temp_unschedulable_reason"`
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
@@ -114,7 +115,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
setOpsRequestContext(c, "", false, body)
|
||||
|
||||
parsedReq, err := service.ParseGatewayRequest(body)
|
||||
parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
@@ -203,6 +204,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 计算粘性会话hash
|
||||
parsedReq.SessionContext = &service.SessionContext{
|
||||
ClientIP: ip.GetClientIP(c),
|
||||
UserAgent: c.GetHeader("User-Agent"),
|
||||
APIKeyID: apiKey.ID,
|
||||
}
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
|
||||
// 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则使用分组平台
|
||||
@@ -335,7 +341,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if failoverErr.ForceCacheBilling {
|
||||
if needForceCacheBilling(hasBoundSession, failoverErr) {
|
||||
forceCacheBilling = true
|
||||
}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
@@ -344,6 +350,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
switchCount++
|
||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
|
||||
return
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
// 错误响应已在Forward中处理,这里只记录日志
|
||||
@@ -530,7 +541,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if failoverErr.ForceCacheBilling {
|
||||
if needForceCacheBilling(hasBoundSession, failoverErr) {
|
||||
forceCacheBilling = true
|
||||
}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
@@ -539,6 +550,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
switchCount++
|
||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
|
||||
return
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
// 错误响应已在Forward中处理,这里只记录日志
|
||||
@@ -801,6 +817,27 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
|
||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||
}
|
||||
|
||||
// needForceCacheBilling 判断 failover 时是否需要强制缓存计费
|
||||
// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费
|
||||
func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool {
|
||||
return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling)
|
||||
}
|
||||
|
||||
// sleepFailoverDelay 账号切换线性递增延时:第1次0s、第2次1s、第3次2s…
|
||||
// 返回 false 表示 context 已取消。
|
||||
func sleepFailoverDelay(ctx context.Context, switchCount int) bool {
|
||||
delay := time.Duration(switchCount-1) * time.Second
|
||||
if delay <= 0 {
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case <-time.After(delay):
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
|
||||
statusCode := failoverErr.StatusCode
|
||||
responseBody := failoverErr.ResponseBody
|
||||
@@ -934,7 +971,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
|
||||
setOpsRequestContext(c, "", false, body)
|
||||
|
||||
parsedReq, err := service.ParseGatewayRequest(body)
|
||||
parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
@@ -962,6 +999,11 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 计算粘性会话 hash
|
||||
parsedReq.SessionContext = &service.SessionContext{
|
||||
ClientIP: ip.GetClientIP(c),
|
||||
UserAgent: c.GetHeader("User-Agent"),
|
||||
APIKeyID: apiKey.ID,
|
||||
}
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
|
||||
// 选择支持该模型的账号
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
|
||||
@@ -30,13 +31,6 @@ import (
|
||||
// 匹配格式: /Users/xxx/.gemini/tmp/[64位十六进制哈希]
|
||||
var geminiCLITmpDirRegex = regexp.MustCompile(`/\.gemini/tmp/([A-Fa-f0-9]{64})`)
|
||||
|
||||
func isGeminiCLIRequest(c *gin.Context, body []byte) bool {
|
||||
if strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id")) != "" {
|
||||
return true
|
||||
}
|
||||
return geminiCLITmpDirRegex.Match(body)
|
||||
}
|
||||
|
||||
// GeminiV1BetaListModels proxies:
|
||||
// GET /v1beta/models
|
||||
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
|
||||
@@ -239,7 +233,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
sessionHash := extractGeminiCLISessionHash(c, body)
|
||||
if sessionHash == "" {
|
||||
// Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端)
|
||||
parsedReq, _ := service.ParseGatewayRequest(body)
|
||||
parsedReq, _ := service.ParseGatewayRequest(body, domain.PlatformGemini)
|
||||
if parsedReq != nil {
|
||||
parsedReq.SessionContext = &service.SessionContext{
|
||||
ClientIP: ip.GetClientIP(c),
|
||||
UserAgent: c.GetHeader("User-Agent"),
|
||||
APIKeyID: apiKey.ID,
|
||||
}
|
||||
}
|
||||
sessionHash = h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
}
|
||||
sessionKey := sessionHash
|
||||
@@ -258,6 +259,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
var geminiDigestChain string
|
||||
var geminiPrefixHash string
|
||||
var geminiSessionUUID string
|
||||
var matchedDigestChain string
|
||||
useDigestFallback := sessionBoundAccountID == 0
|
||||
|
||||
if useDigestFallback {
|
||||
@@ -284,13 +286,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
)
|
||||
|
||||
// 查找会话
|
||||
foundUUID, foundAccountID, found := h.gatewayService.FindGeminiSession(
|
||||
foundUUID, foundAccountID, foundMatchedChain, found := h.gatewayService.FindGeminiSession(
|
||||
c.Request.Context(),
|
||||
derefGroupID(apiKey.GroupID),
|
||||
geminiPrefixHash,
|
||||
geminiDigestChain,
|
||||
)
|
||||
if found {
|
||||
matchedDigestChain = foundMatchedChain
|
||||
sessionBoundAccountID = foundAccountID
|
||||
geminiSessionUUID = foundUUID
|
||||
log.Printf("[Gemini] Digest fallback matched: uuid=%s, accountID=%d, chain=%s",
|
||||
@@ -316,7 +319,6 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
|
||||
// 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号
|
||||
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
|
||||
isCLI := isGeminiCLIRequest(c, body)
|
||||
cleanedForUnknownBinding := false
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
||||
@@ -344,10 +346,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID)
|
||||
body = service.CleanGeminiNativeThoughtSignatures(body)
|
||||
sessionBoundAccountID = account.ID
|
||||
} else if sessionKey != "" && sessionBoundAccountID == 0 && isCLI && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
|
||||
// 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,CLI 继续携带旧签名。
|
||||
} else if sessionKey != "" && sessionBoundAccountID == 0 && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
|
||||
// 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,客户端继续携带旧签名。
|
||||
// 为避免第一次转发就 400,这里做一次确定性清理,让新账号重新生成签名链路。
|
||||
log.Printf("[Gemini] Sticky session binding missing for CLI request, cleaning thoughtSignature proactively")
|
||||
log.Printf("[Gemini] Sticky session binding missing, cleaning thoughtSignature proactively")
|
||||
body = service.CleanGeminiNativeThoughtSignatures(body)
|
||||
cleanedForUnknownBinding = true
|
||||
sessionBoundAccountID = account.ID
|
||||
@@ -422,7 +424,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if failoverErr.ForceCacheBilling {
|
||||
if needForceCacheBilling(hasBoundSession, failoverErr) {
|
||||
forceCacheBilling = true
|
||||
}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
@@ -433,6 +435,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
lastFailoverErr = failoverErr
|
||||
switchCount++
|
||||
log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
|
||||
return
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
// ForwardNative already wrote the response
|
||||
@@ -453,6 +460,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
geminiDigestChain,
|
||||
geminiSessionUUID,
|
||||
account.ID,
|
||||
matchedDigestChain,
|
||||
); err != nil {
|
||||
log.Printf("[Gemini] Failed to save digest session: %v", err)
|
||||
}
|
||||
|
||||
@@ -798,53 +798,6 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
|
||||
now := time.Now().UTC()
|
||||
payload := map[string]string{
|
||||
"rate_limited_at": now.Format(time.RFC3339),
|
||||
"rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339),
|
||||
}
|
||||
raw, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
scopeKey := string(scope)
|
||||
client := clientFromContext(ctx, r.client)
|
||||
result, err := client.ExecContext(
|
||||
ctx,
|
||||
`UPDATE accounts SET
|
||||
extra = jsonb_set(
|
||||
jsonb_set(COALESCE(extra, '{}'::jsonb), '{antigravity_quota_scopes}'::text[], COALESCE(extra->'antigravity_quota_scopes', '{}'::jsonb), true),
|
||||
ARRAY['antigravity_quota_scopes', $1]::text[],
|
||||
$2::jsonb,
|
||||
true
|
||||
),
|
||||
updated_at = NOW(),
|
||||
last_used_at = NOW()
|
||||
WHERE id = $3 AND deleted_at IS NULL`,
|
||||
scopeKey,
|
||||
raw,
|
||||
id,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue quota scope failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||
if scope == "" {
|
||||
return nil
|
||||
|
||||
@@ -11,63 +11,6 @@ import (
|
||||
|
||||
const stickySessionPrefix = "sticky_session:"
|
||||
|
||||
// Gemini Trie Lua 脚本
|
||||
const (
|
||||
// geminiTrieFindScript 查找最长前缀匹配的 Lua 脚本
|
||||
// KEYS[1] = trie key
|
||||
// ARGV[1] = digestChain (如 "u:a-m:b-u:c-m:d")
|
||||
// ARGV[2] = TTL seconds (用于刷新)
|
||||
// 返回: 最长匹配的 value (uuid:accountID) 或 nil
|
||||
// 查找成功时自动刷新 TTL,防止活跃会话意外过期
|
||||
geminiTrieFindScript = `
|
||||
local chain = ARGV[1]
|
||||
local ttl = tonumber(ARGV[2])
|
||||
local lastMatch = nil
|
||||
local path = ""
|
||||
|
||||
for part in string.gmatch(chain, "[^-]+") do
|
||||
path = path == "" and part or path .. "-" .. part
|
||||
local val = redis.call('HGET', KEYS[1], path)
|
||||
if val and val ~= "" then
|
||||
lastMatch = val
|
||||
end
|
||||
end
|
||||
|
||||
if lastMatch then
|
||||
redis.call('EXPIRE', KEYS[1], ttl)
|
||||
end
|
||||
|
||||
return lastMatch
|
||||
`
|
||||
|
||||
// geminiTrieSaveScript 保存会话到 Trie 的 Lua 脚本
|
||||
// KEYS[1] = trie key
|
||||
// ARGV[1] = digestChain
|
||||
// ARGV[2] = value (uuid:accountID)
|
||||
// ARGV[3] = TTL seconds
|
||||
geminiTrieSaveScript = `
|
||||
local chain = ARGV[1]
|
||||
local value = ARGV[2]
|
||||
local ttl = tonumber(ARGV[3])
|
||||
local path = ""
|
||||
|
||||
for part in string.gmatch(chain, "[^-]+") do
|
||||
path = path == "" and part or path .. "-" .. part
|
||||
end
|
||||
redis.call('HSET', KEYS[1], path, value)
|
||||
redis.call('EXPIRE', KEYS[1], ttl)
|
||||
return "OK"
|
||||
`
|
||||
)
|
||||
|
||||
// 模型负载统计相关常量
|
||||
const (
|
||||
modelLoadKeyPrefix = "ag:model_load:" // 模型调用次数 key 前缀
|
||||
modelLastUsedKeyPrefix = "ag:model_last_used:" // 模型最后调度时间 key 前缀
|
||||
modelLoadTTL = 24 * time.Hour // 调用次数 TTL(24 小时无调用后清零)
|
||||
modelLastUsedTTL = 24 * time.Hour // 最后调度时间 TTL
|
||||
)
|
||||
|
||||
type gatewayCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
@@ -108,171 +51,3 @@ func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64
|
||||
key := buildSessionKey(groupID, sessionHash)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
// ============ Antigravity 模型负载统计方法 ============
|
||||
|
||||
// modelLoadKey 构建模型调用次数 key
|
||||
// 格式: ag:model_load:{accountID}:{model}
|
||||
func modelLoadKey(accountID int64, model string) string {
|
||||
return fmt.Sprintf("%s%d:%s", modelLoadKeyPrefix, accountID, model)
|
||||
}
|
||||
|
||||
// modelLastUsedKey 构建模型最后调度时间 key
|
||||
// 格式: ag:model_last_used:{accountID}:{model}
|
||||
func modelLastUsedKey(accountID int64, model string) string {
|
||||
return fmt.Sprintf("%s%d:%s", modelLastUsedKeyPrefix, accountID, model)
|
||||
}
|
||||
|
||||
// IncrModelCallCount 增加模型调用次数并更新最后调度时间
|
||||
// 返回更新后的调用次数
|
||||
func (c *gatewayCache) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
|
||||
loadKey := modelLoadKey(accountID, model)
|
||||
lastUsedKey := modelLastUsedKey(accountID, model)
|
||||
|
||||
pipe := c.rdb.Pipeline()
|
||||
incrCmd := pipe.Incr(ctx, loadKey)
|
||||
pipe.Expire(ctx, loadKey, modelLoadTTL) // 每次调用刷新 TTL
|
||||
pipe.Set(ctx, lastUsedKey, time.Now().Unix(), modelLastUsedTTL)
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return incrCmd.Val(), nil
|
||||
}
|
||||
|
||||
// GetModelLoadBatch 批量获取账号的模型负载信息
|
||||
func (c *gatewayCache) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*service.ModelLoadInfo, error) {
|
||||
if len(accountIDs) == 0 {
|
||||
return make(map[int64]*service.ModelLoadInfo), nil
|
||||
}
|
||||
|
||||
loadCmds, lastUsedCmds := c.pipelineModelLoadGet(ctx, accountIDs, model)
|
||||
return c.parseModelLoadResults(accountIDs, loadCmds, lastUsedCmds), nil
|
||||
}
|
||||
|
||||
// pipelineModelLoadGet 批量获取模型负载的 Pipeline 操作
|
||||
func (c *gatewayCache) pipelineModelLoadGet(
|
||||
ctx context.Context,
|
||||
accountIDs []int64,
|
||||
model string,
|
||||
) (map[int64]*redis.StringCmd, map[int64]*redis.StringCmd) {
|
||||
pipe := c.rdb.Pipeline()
|
||||
loadCmds := make(map[int64]*redis.StringCmd, len(accountIDs))
|
||||
lastUsedCmds := make(map[int64]*redis.StringCmd, len(accountIDs))
|
||||
|
||||
for _, id := range accountIDs {
|
||||
loadCmds[id] = pipe.Get(ctx, modelLoadKey(id, model))
|
||||
lastUsedCmds[id] = pipe.Get(ctx, modelLastUsedKey(id, model))
|
||||
}
|
||||
_, _ = pipe.Exec(ctx) // 忽略错误,key 不存在是正常的
|
||||
return loadCmds, lastUsedCmds
|
||||
}
|
||||
|
||||
// parseModelLoadResults 解析 Pipeline 结果
|
||||
func (c *gatewayCache) parseModelLoadResults(
|
||||
accountIDs []int64,
|
||||
loadCmds map[int64]*redis.StringCmd,
|
||||
lastUsedCmds map[int64]*redis.StringCmd,
|
||||
) map[int64]*service.ModelLoadInfo {
|
||||
result := make(map[int64]*service.ModelLoadInfo, len(accountIDs))
|
||||
for _, id := range accountIDs {
|
||||
result[id] = &service.ModelLoadInfo{
|
||||
CallCount: getInt64OrZero(loadCmds[id]),
|
||||
LastUsedAt: getTimeOrZero(lastUsedCmds[id]),
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// getInt64OrZero 从 StringCmd 获取 int64 值,失败返回 0
|
||||
func getInt64OrZero(cmd *redis.StringCmd) int64 {
|
||||
val, _ := cmd.Int64()
|
||||
return val
|
||||
}
|
||||
|
||||
// getTimeOrZero 从 StringCmd 获取 time.Time,失败返回零值
|
||||
func getTimeOrZero(cmd *redis.StringCmd) time.Time {
|
||||
val, err := cmd.Int64()
|
||||
if err != nil {
|
||||
return time.Time{}
|
||||
}
|
||||
return time.Unix(val, 0)
|
||||
}
|
||||
|
||||
// ============ Gemini 会话 Fallback 方法 (Trie 实现) ============
|
||||
|
||||
// FindGeminiSession 查找 Gemini 会话(使用 Trie + Lua 脚本实现 O(L) 查询)
|
||||
// 返回最长匹配的会话信息,匹配成功时自动刷新 TTL
|
||||
func (c *gatewayCache) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
||||
if digestChain == "" {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
trieKey := service.BuildGeminiTrieKey(groupID, prefixHash)
|
||||
ttlSeconds := int(service.GeminiSessionTTL().Seconds())
|
||||
|
||||
// 使用 Lua 脚本在 Redis 端执行 Trie 查找,O(L) 次 HGET,1 次网络往返
|
||||
// 查找成功时自动刷新 TTL,防止活跃会话意外过期
|
||||
result, err := c.rdb.Eval(ctx, geminiTrieFindScript, []string{trieKey}, digestChain, ttlSeconds).Result()
|
||||
if err != nil || result == nil {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
value, ok := result.(string)
|
||||
if !ok || value == "" {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
uuid, accountID, ok = service.ParseGeminiSessionValue(value)
|
||||
return uuid, accountID, ok
|
||||
}
|
||||
|
||||
// SaveGeminiSession 保存 Gemini 会话(使用 Trie + Lua 脚本)
|
||||
func (c *gatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
||||
if digestChain == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
trieKey := service.BuildGeminiTrieKey(groupID, prefixHash)
|
||||
value := service.FormatGeminiSessionValue(uuid, accountID)
|
||||
ttlSeconds := int(service.GeminiSessionTTL().Seconds())
|
||||
|
||||
return c.rdb.Eval(ctx, geminiTrieSaveScript, []string{trieKey}, digestChain, value, ttlSeconds).Err()
|
||||
}
|
||||
|
||||
// ============ Anthropic 会话 Fallback 方法 (复用 Trie 实现) ============
|
||||
|
||||
// FindAnthropicSession 查找 Anthropic 会话(复用 Gemini Trie Lua 脚本)
|
||||
func (c *gatewayCache) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
||||
if digestChain == "" {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
trieKey := service.BuildAnthropicTrieKey(groupID, prefixHash)
|
||||
ttlSeconds := int(service.AnthropicSessionTTL().Seconds())
|
||||
|
||||
result, err := c.rdb.Eval(ctx, geminiTrieFindScript, []string{trieKey}, digestChain, ttlSeconds).Result()
|
||||
if err != nil || result == nil {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
value, ok := result.(string)
|
||||
if !ok || value == "" {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
uuid, accountID, ok = service.ParseGeminiSessionValue(value)
|
||||
return uuid, accountID, ok
|
||||
}
|
||||
|
||||
// SaveAnthropicSession 保存 Anthropic 会话(复用 Gemini Trie Lua 脚本)
|
||||
func (c *gatewayCache) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
||||
if digestChain == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
trieKey := service.BuildAnthropicTrieKey(groupID, prefixHash)
|
||||
value := service.FormatGeminiSessionValue(uuid, accountID)
|
||||
ttlSeconds := int(service.AnthropicSessionTTL().Seconds())
|
||||
|
||||
return c.rdb.Eval(ctx, geminiTrieSaveScript, []string{trieKey}, digestChain, value, ttlSeconds).Err()
|
||||
}
|
||||
|
||||
@@ -104,157 +104,6 @@ func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
|
||||
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
|
||||
}
|
||||
|
||||
// ============ Gemini Trie 会话测试 ============
|
||||
|
||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_SaveAndFind() {
|
||||
groupID := int64(1)
|
||||
prefixHash := "testprefix"
|
||||
digestChain := "u:hash1-m:hash2-u:hash3"
|
||||
uuid := "test-uuid-123"
|
||||
accountID := int64(42)
|
||||
|
||||
// 保存会话
|
||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, digestChain, uuid, accountID)
|
||||
require.NoError(s.T(), err, "SaveGeminiSession")
|
||||
|
||||
// 精确匹配查找
|
||||
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, digestChain)
|
||||
require.True(s.T(), found, "should find exact match")
|
||||
require.Equal(s.T(), uuid, foundUUID)
|
||||
require.Equal(s.T(), accountID, foundAccountID)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_PrefixMatch() {
|
||||
groupID := int64(1)
|
||||
prefixHash := "prefixmatch"
|
||||
shortChain := "u:a-m:b"
|
||||
longChain := "u:a-m:b-u:c-m:d"
|
||||
uuid := "uuid-prefix"
|
||||
accountID := int64(100)
|
||||
|
||||
// 保存短链
|
||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, shortChain, uuid, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// 用长链查找,应该匹配到短链(前缀匹配)
|
||||
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, longChain)
|
||||
require.True(s.T(), found, "should find prefix match")
|
||||
require.Equal(s.T(), uuid, foundUUID)
|
||||
require.Equal(s.T(), accountID, foundAccountID)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_LongestPrefixMatch() {
|
||||
groupID := int64(1)
|
||||
prefixHash := "longestmatch"
|
||||
|
||||
// 保存多个不同长度的链
|
||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a", "uuid-short", 1)
|
||||
require.NoError(s.T(), err)
|
||||
err = s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b", "uuid-medium", 2)
|
||||
require.NoError(s.T(), err)
|
||||
err = s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:c", "uuid-long", 3)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// 查找更长的链,应该匹配到最长的前缀
|
||||
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:c-m:d-u:e")
|
||||
require.True(s.T(), found, "should find longest prefix match")
|
||||
require.Equal(s.T(), "uuid-long", foundUUID)
|
||||
require.Equal(s.T(), int64(3), foundAccountID)
|
||||
|
||||
// 查找中等长度的链
|
||||
foundUUID, foundAccountID, found = s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:x")
|
||||
require.True(s.T(), found)
|
||||
require.Equal(s.T(), "uuid-medium", foundUUID)
|
||||
require.Equal(s.T(), int64(2), foundAccountID)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_NoMatch() {
|
||||
groupID := int64(1)
|
||||
prefixHash := "nomatch"
|
||||
digestChain := "u:a-m:b"
|
||||
|
||||
// 保存一个会话
|
||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, digestChain, "uuid", 1)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// 用不同的链查找,应该找不到
|
||||
_, _, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:x-m:y")
|
||||
require.False(s.T(), found, "should not find non-matching chain")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_DifferentPrefixHash() {
|
||||
groupID := int64(1)
|
||||
digestChain := "u:a-m:b"
|
||||
|
||||
// 保存到 prefixHash1
|
||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, "prefix1", digestChain, "uuid1", 1)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// 用 prefixHash2 查找,应该找不到(不同用户/客户端隔离)
|
||||
_, _, found := s.cache.FindGeminiSession(s.ctx, groupID, "prefix2", digestChain)
|
||||
require.False(s.T(), found, "different prefixHash should be isolated")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_DifferentGroupID() {
|
||||
prefixHash := "sameprefix"
|
||||
digestChain := "u:a-m:b"
|
||||
|
||||
// 保存到 groupID 1
|
||||
err := s.cache.SaveGeminiSession(s.ctx, 1, prefixHash, digestChain, "uuid1", 1)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// 用 groupID 2 查找,应该找不到(分组隔离)
|
||||
_, _, found := s.cache.FindGeminiSession(s.ctx, 2, prefixHash, digestChain)
|
||||
require.False(s.T(), found, "different groupID should be isolated")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_EmptyDigestChain() {
|
||||
groupID := int64(1)
|
||||
prefixHash := "emptytest"
|
||||
|
||||
// 空链不应该保存
|
||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "", "uuid", 1)
|
||||
require.NoError(s.T(), err, "empty chain should not error")
|
||||
|
||||
// 空链查找应该返回 false
|
||||
_, _, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "")
|
||||
require.False(s.T(), found, "empty chain should not match")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_MultipleSessions() {
|
||||
groupID := int64(1)
|
||||
prefixHash := "multisession"
|
||||
|
||||
// 保存多个不同会话(模拟 1000 个并发会话的场景)
|
||||
sessions := []struct {
|
||||
chain string
|
||||
uuid string
|
||||
accountID int64
|
||||
}{
|
||||
{"u:session1", "uuid-1", 1},
|
||||
{"u:session2-m:reply2", "uuid-2", 2},
|
||||
{"u:session3-m:reply3-u:msg3", "uuid-3", 3},
|
||||
}
|
||||
|
||||
for _, sess := range sessions {
|
||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, sess.chain, sess.uuid, sess.accountID)
|
||||
require.NoError(s.T(), err)
|
||||
}
|
||||
|
||||
// 验证每个会话都能正确查找
|
||||
for _, sess := range sessions {
|
||||
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, sess.chain)
|
||||
require.True(s.T(), found, "should find session: %s", sess.chain)
|
||||
require.Equal(s.T(), sess.uuid, foundUUID)
|
||||
require.Equal(s.T(), sess.accountID, foundAccountID)
|
||||
}
|
||||
|
||||
// 验证继续对话的场景
|
||||
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:session2-m:reply2-u:newmsg")
|
||||
require.True(s.T(), found)
|
||||
require.Equal(s.T(), "uuid-2", foundUUID)
|
||||
require.Equal(s.T(), int64(2), foundAccountID)
|
||||
}
|
||||
|
||||
func TestGatewayCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(GatewayCacheSuite))
|
||||
|
||||
@@ -1,234 +0,0 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// ============ Gateway Cache 模型负载统计集成测试 ============
|
||||
|
||||
type GatewayCacheModelLoadSuite struct {
|
||||
suite.Suite
|
||||
}
|
||||
|
||||
func TestGatewayCacheModelLoadSuite(t *testing.T) {
|
||||
suite.Run(t, new(GatewayCacheModelLoadSuite))
|
||||
}
|
||||
|
||||
func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_Basic() {
|
||||
t := s.T()
|
||||
rdb := testRedis(t)
|
||||
cache := &gatewayCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
accountID := int64(123)
|
||||
model := "claude-sonnet-4-20250514"
|
||||
|
||||
// 首次调用应返回 1
|
||||
count1, err := cache.IncrModelCallCount(ctx, accountID, model)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), count1)
|
||||
|
||||
// 第二次调用应返回 2
|
||||
count2, err := cache.IncrModelCallCount(ctx, accountID, model)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), count2)
|
||||
|
||||
// 第三次调用应返回 3
|
||||
count3, err := cache.IncrModelCallCount(ctx, accountID, model)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(3), count3)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_DifferentModels() {
|
||||
t := s.T()
|
||||
rdb := testRedis(t)
|
||||
cache := &gatewayCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
accountID := int64(456)
|
||||
model1 := "claude-sonnet-4-20250514"
|
||||
model2 := "claude-opus-4-5-20251101"
|
||||
|
||||
// 不同模型应该独立计数
|
||||
count1, err := cache.IncrModelCallCount(ctx, accountID, model1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), count1)
|
||||
|
||||
count2, err := cache.IncrModelCallCount(ctx, accountID, model2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), count2)
|
||||
|
||||
count1Again, err := cache.IncrModelCallCount(ctx, accountID, model1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), count1Again)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_DifferentAccounts() {
|
||||
t := s.T()
|
||||
rdb := testRedis(t)
|
||||
cache := &gatewayCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
account1 := int64(111)
|
||||
account2 := int64(222)
|
||||
model := "gemini-2.5-pro"
|
||||
|
||||
// 不同账号应该独立计数
|
||||
count1, err := cache.IncrModelCallCount(ctx, account1, model)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), count1)
|
||||
|
||||
count2, err := cache.IncrModelCallCount(ctx, account2, model)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), count2)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_Empty() {
|
||||
t := s.T()
|
||||
rdb := testRedis(t)
|
||||
cache := &gatewayCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := cache.GetModelLoadBatch(ctx, []int64{}, "any-model")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Empty(t, result)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_NonExistent() {
|
||||
t := s.T()
|
||||
rdb := testRedis(t)
|
||||
cache := &gatewayCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
// 查询不存在的账号应返回零值
|
||||
result, err := cache.GetModelLoadBatch(ctx, []int64{9999, 9998}, "claude-sonnet-4-20250514")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result, 2)
|
||||
|
||||
require.Equal(t, int64(0), result[9999].CallCount)
|
||||
require.True(t, result[9999].LastUsedAt.IsZero())
|
||||
require.Equal(t, int64(0), result[9998].CallCount)
|
||||
require.True(t, result[9998].LastUsedAt.IsZero())
|
||||
}
|
||||
|
||||
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_AfterIncrement() {
|
||||
t := s.T()
|
||||
rdb := testRedis(t)
|
||||
cache := &gatewayCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
accountID := int64(789)
|
||||
model := "claude-sonnet-4-20250514"
|
||||
|
||||
// 先增加调用次数
|
||||
beforeIncr := time.Now()
|
||||
_, err := cache.IncrModelCallCount(ctx, accountID, model)
|
||||
require.NoError(t, err)
|
||||
_, err = cache.IncrModelCallCount(ctx, accountID, model)
|
||||
require.NoError(t, err)
|
||||
_, err = cache.IncrModelCallCount(ctx, accountID, model)
|
||||
require.NoError(t, err)
|
||||
afterIncr := time.Now()
|
||||
|
||||
// 获取负载信息
|
||||
result, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result, 1)
|
||||
|
||||
loadInfo := result[accountID]
|
||||
require.NotNil(t, loadInfo)
|
||||
require.Equal(t, int64(3), loadInfo.CallCount)
|
||||
require.False(t, loadInfo.LastUsedAt.IsZero())
|
||||
// LastUsedAt 应该在 beforeIncr 和 afterIncr 之间
|
||||
require.True(t, loadInfo.LastUsedAt.After(beforeIncr.Add(-time.Second)) || loadInfo.LastUsedAt.Equal(beforeIncr))
|
||||
require.True(t, loadInfo.LastUsedAt.Before(afterIncr.Add(time.Second)) || loadInfo.LastUsedAt.Equal(afterIncr))
|
||||
}
|
||||
|
||||
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_MultipleAccounts() {
|
||||
t := s.T()
|
||||
rdb := testRedis(t)
|
||||
cache := &gatewayCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
model := "claude-opus-4-5-20251101"
|
||||
account1 := int64(1001)
|
||||
account2 := int64(1002)
|
||||
account3 := int64(1003) // 不调用
|
||||
|
||||
// account1 调用 2 次
|
||||
_, err := cache.IncrModelCallCount(ctx, account1, model)
|
||||
require.NoError(t, err)
|
||||
_, err = cache.IncrModelCallCount(ctx, account1, model)
|
||||
require.NoError(t, err)
|
||||
|
||||
// account2 调用 5 次
|
||||
for i := 0; i < 5; i++ {
|
||||
_, err = cache.IncrModelCallCount(ctx, account2, model)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// 批量获取
|
||||
result, err := cache.GetModelLoadBatch(ctx, []int64{account1, account2, account3}, model)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result, 3)
|
||||
|
||||
require.Equal(t, int64(2), result[account1].CallCount)
|
||||
require.False(t, result[account1].LastUsedAt.IsZero())
|
||||
|
||||
require.Equal(t, int64(5), result[account2].CallCount)
|
||||
require.False(t, result[account2].LastUsedAt.IsZero())
|
||||
|
||||
require.Equal(t, int64(0), result[account3].CallCount)
|
||||
require.True(t, result[account3].LastUsedAt.IsZero())
|
||||
}
|
||||
|
||||
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_ModelIsolation() {
|
||||
t := s.T()
|
||||
rdb := testRedis(t)
|
||||
cache := &gatewayCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
accountID := int64(2001)
|
||||
model1 := "claude-sonnet-4-20250514"
|
||||
model2 := "gemini-2.5-pro"
|
||||
|
||||
// 对 model1 调用 3 次
|
||||
for i := 0; i < 3; i++ {
|
||||
_, err := cache.IncrModelCallCount(ctx, accountID, model1)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// 获取 model1 的负载
|
||||
result1, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(3), result1[accountID].CallCount)
|
||||
|
||||
// 获取 model2 的负载(应该为 0)
|
||||
result2, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(0), result2[accountID].CallCount)
|
||||
}
|
||||
|
||||
// ============ 辅助函数测试 ============
|
||||
|
||||
func (s *GatewayCacheModelLoadSuite) TestModelLoadKey_Format() {
|
||||
t := s.T()
|
||||
|
||||
key := modelLoadKey(123, "claude-sonnet-4")
|
||||
require.Equal(t, "ag:model_load:123:claude-sonnet-4", key)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheModelLoadSuite) TestModelLastUsedKey_Format() {
|
||||
t := s.T()
|
||||
|
||||
key := modelLastUsedKey(456, "gemini-2.5-pro")
|
||||
require.Equal(t, "ag:model_last_used:456:gemini-2.5-pro", key)
|
||||
}
|
||||
@@ -1004,10 +1004,6 @@ func (s *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -50,7 +50,6 @@ type AccountRepository interface {
|
||||
ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
|
||||
|
||||
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
|
||||
SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error
|
||||
SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error
|
||||
SetOverloaded(ctx context.Context, id int64, until time.Time) error
|
||||
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
|
||||
|
||||
@@ -143,10 +143,6 @@ func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt
|
||||
panic("unexpected SetRateLimited call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
|
||||
panic("unexpected SetAntigravityQuotaScopeLimit call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||
panic("unexpected SetModelRateLimit call")
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -12,9 +11,6 @@ const (
|
||||
// anthropicSessionTTLSeconds Anthropic 会话缓存 TTL(5 分钟)
|
||||
anthropicSessionTTLSeconds = 300
|
||||
|
||||
// anthropicTrieKeyPrefix Anthropic Trie 会话 key 前缀
|
||||
anthropicTrieKeyPrefix = "anthropic:trie:"
|
||||
|
||||
// anthropicDigestSessionKeyPrefix Anthropic 摘要 fallback 会话 key 前缀
|
||||
anthropicDigestSessionKeyPrefix = "anthropic:digest:"
|
||||
)
|
||||
@@ -68,12 +64,6 @@ func rolePrefix(role string) string {
|
||||
}
|
||||
}
|
||||
|
||||
// BuildAnthropicTrieKey 构建 Anthropic Trie Redis key
|
||||
// 格式: anthropic:trie:{groupID}:{prefixHash}
|
||||
func BuildAnthropicTrieKey(groupID int64, prefixHash string) string {
|
||||
return anthropicTrieKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash
|
||||
}
|
||||
|
||||
// GenerateAnthropicDigestSessionKey 生成 Anthropic 摘要 fallback 的 sessionKey
|
||||
// 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey
|
||||
func GenerateAnthropicDigestSessionKey(prefixHash, uuid string) string {
|
||||
|
||||
@@ -236,43 +236,6 @@ func TestBuildAnthropicDigestChain_Deterministic(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAnthropicTrieKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
groupID int64
|
||||
prefixHash string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "normal",
|
||||
groupID: 123,
|
||||
prefixHash: "abcdef12",
|
||||
want: "anthropic:trie:123:abcdef12",
|
||||
},
|
||||
{
|
||||
name: "zero group",
|
||||
groupID: 0,
|
||||
prefixHash: "xyz",
|
||||
want: "anthropic:trie:0:xyz",
|
||||
},
|
||||
{
|
||||
name: "empty prefix",
|
||||
groupID: 1,
|
||||
prefixHash: "",
|
||||
want: "anthropic:trie:1:",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := BuildAnthropicTrieKey(tt.groupID, tt.prefixHash)
|
||||
if got != tt.want {
|
||||
t.Errorf("BuildAnthropicTrieKey(%d, %q) = %q, want %q", tt.groupID, tt.prefixHash, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateAnthropicDigestSessionKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,17 +4,42 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// antigravityFailingWriter 模拟客户端断开连接的 gin.ResponseWriter
|
||||
type antigravityFailingWriter struct {
|
||||
gin.ResponseWriter
|
||||
failAfter int // 允许成功写入的次数,之后所有写入返回错误
|
||||
writes int
|
||||
}
|
||||
|
||||
func (w *antigravityFailingWriter) Write(p []byte) (int, error) {
|
||||
if w.writes >= w.failAfter {
|
||||
return 0, errors.New("write failed: client disconnected")
|
||||
}
|
||||
w.writes++
|
||||
return w.ResponseWriter.Write(p)
|
||||
}
|
||||
|
||||
// newAntigravityTestService 创建用于流式测试的 AntigravityGatewayService
|
||||
func newAntigravityTestService(cfg *config.Config) *AntigravityGatewayService {
|
||||
return &AntigravityGatewayService{
|
||||
settingService: &SettingService{cfg: cfg},
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripSignatureSensitiveBlocksFromClaudeRequest(t *testing.T) {
|
||||
req := &antigravity.ClaudeRequest{
|
||||
Model: "claude-sonnet-4-5",
|
||||
@@ -337,8 +362,8 @@ func TestAntigravityGatewayService_Forward_StickySessionForceCacheBilling(t *tes
|
||||
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
|
||||
}
|
||||
|
||||
// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling
|
||||
// 验证:ForwardGemini 粘性会话切换时,UpstreamFailoverError.ForceCacheBilling 应为 true
|
||||
// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling verifies
|
||||
// that ForwardGemini sets ForceCacheBilling=true for sticky session switch.
|
||||
func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
writer := httptest.NewRecorder()
|
||||
@@ -391,3 +416,438 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(
|
||||
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
|
||||
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
|
||||
}
|
||||
|
||||
// --- 流式 happy path 测试 ---
|
||||
|
||||
// TestStreamUpstreamResponse_NormalComplete
|
||||
// 验证:正常流式转发完成时,数据正确透传、usage 正确收集、clientDisconnect=false
|
||||
func TestStreamUpstreamResponse_NormalComplete(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newAntigravityTestService(&config.Config{
|
||||
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
fmt.Fprintln(pw, `event: message_start`)
|
||||
fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":10}}}`)
|
||||
fmt.Fprintln(pw, "")
|
||||
fmt.Fprintln(pw, `event: content_block_delta`)
|
||||
fmt.Fprintln(pw, `data: {"type":"content_block_delta","delta":{"text":"hello"}}`)
|
||||
fmt.Fprintln(pw, "")
|
||||
fmt.Fprintln(pw, `event: message_delta`)
|
||||
fmt.Fprintln(pw, `data: {"type":"message_delta","usage":{"output_tokens":5}}`)
|
||||
fmt.Fprintln(pw, "")
|
||||
}()
|
||||
|
||||
result := svc.streamUpstreamResponse(c, resp, time.Now())
|
||||
_ = pr.Close()
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect")
|
||||
require.NotNil(t, result.usage)
|
||||
require.Equal(t, 5, result.usage.OutputTokens, "should collect output_tokens from message_delta")
|
||||
require.NotNil(t, result.firstTokenMs, "should record first token time")
|
||||
|
||||
// 验证数据被透传到客户端
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, "event: message_start")
|
||||
require.Contains(t, body, "content_block_delta")
|
||||
require.Contains(t, body, "message_delta")
|
||||
}
|
||||
|
||||
// TestHandleGeminiStreamingResponse_NormalComplete
|
||||
// 验证:正常 Gemini 流式转发,数据正确透传、usage 正确收集
|
||||
func TestHandleGeminiStreamingResponse_NormalComplete(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newAntigravityTestService(&config.Config{
|
||||
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
// 第一个 chunk(部分内容)
|
||||
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":3}}`)
|
||||
fmt.Fprintln(pw, "")
|
||||
// 第二个 chunk(最终内容+完整 usage)
|
||||
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":" world"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":8,"cachedContentTokenCount":2}}`)
|
||||
fmt.Fprintln(pw, "")
|
||||
}()
|
||||
|
||||
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
|
||||
_ = pr.Close()
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect")
|
||||
require.NotNil(t, result.usage)
|
||||
// Gemini usage: promptTokenCount=10, candidatesTokenCount=8, cachedContentTokenCount=2
|
||||
// → InputTokens=10-2=8, OutputTokens=8, CacheReadInputTokens=2
|
||||
require.Equal(t, 8, result.usage.InputTokens)
|
||||
require.Equal(t, 8, result.usage.OutputTokens)
|
||||
require.Equal(t, 2, result.usage.CacheReadInputTokens)
|
||||
require.NotNil(t, result.firstTokenMs, "should record first token time")
|
||||
|
||||
// 验证数据被透传到客户端
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, "Hello")
|
||||
require.Contains(t, body, "world")
|
||||
// 不应包含错误事件
|
||||
require.NotContains(t, body, "event: error")
|
||||
}
|
||||
|
||||
// TestHandleClaudeStreamingResponse_NormalComplete
|
||||
// 验证:正常 Claude 流式转发(Gemini→Claude 转换),数据正确转换并输出
|
||||
func TestHandleClaudeStreamingResponse_NormalComplete(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newAntigravityTestService(&config.Config{
|
||||
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
// v1internal 包装格式:Gemini 数据嵌套在 "response" 字段下
|
||||
// ProcessLine 先尝试反序列化为 V1InternalResponse,裸格式会导致 Response.UsageMetadata 为空
|
||||
fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"Hi there"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":3}}}`)
|
||||
fmt.Fprintln(pw, "")
|
||||
}()
|
||||
|
||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
||||
_ = pr.Close()
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect")
|
||||
require.NotNil(t, result.usage)
|
||||
// Gemini→Claude 转换的 usage:promptTokenCount=5→InputTokens=5, candidatesTokenCount=3→OutputTokens=3
|
||||
require.Equal(t, 5, result.usage.InputTokens)
|
||||
require.Equal(t, 3, result.usage.OutputTokens)
|
||||
require.NotNil(t, result.firstTokenMs, "should record first token time")
|
||||
|
||||
// 验证输出是 Claude SSE 格式(processor 会转换)
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, "event: message_start", "should contain Claude message_start event")
|
||||
require.Contains(t, body, "event: message_stop", "should contain Claude message_stop event")
|
||||
// 不应包含错误事件
|
||||
require.NotContains(t, body, "event: error")
|
||||
}
|
||||
|
||||
// --- 流式客户端断开检测测试 ---
|
||||
|
||||
// TestStreamUpstreamResponse_ClientDisconnectDrainsUsage
|
||||
// 验证:客户端写入失败后,streamUpstreamResponse 继续读取上游以收集 usage
|
||||
func TestStreamUpstreamResponse_ClientDisconnectDrainsUsage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newAntigravityTestService(&config.Config{
|
||||
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
fmt.Fprintln(pw, `event: message_start`)
|
||||
fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":10}}}`)
|
||||
fmt.Fprintln(pw, "")
|
||||
fmt.Fprintln(pw, `event: message_delta`)
|
||||
fmt.Fprintln(pw, `data: {"type":"message_delta","usage":{"output_tokens":20}}`)
|
||||
fmt.Fprintln(pw, "")
|
||||
}()
|
||||
|
||||
result := svc.streamUpstreamResponse(c, resp, time.Now())
|
||||
_ = pr.Close()
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.clientDisconnect)
|
||||
require.NotNil(t, result.usage)
|
||||
require.Equal(t, 20, result.usage.OutputTokens)
|
||||
}
|
||||
|
||||
// TestStreamUpstreamResponse_ContextCanceled
|
||||
// 验证:context 取消时返回 usage 且标记 clientDisconnect
|
||||
func TestStreamUpstreamResponse_ContextCanceled(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newAntigravityTestService(&config.Config{
|
||||
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
|
||||
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
|
||||
|
||||
result := svc.streamUpstreamResponse(c, resp, time.Now())
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.clientDisconnect)
|
||||
require.NotContains(t, rec.Body.String(), "event: error")
|
||||
}
|
||||
|
||||
// TestStreamUpstreamResponse_Timeout
|
||||
// 验证:上游超时时返回已收集的 usage
|
||||
func TestStreamUpstreamResponse_Timeout(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newAntigravityTestService(&config.Config{
|
||||
Gateway: config.GatewayConfig{StreamDataIntervalTimeout: 1, MaxLineSize: defaultMaxLineSize},
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||
|
||||
result := svc.streamUpstreamResponse(c, resp, time.Now())
|
||||
_ = pw.Close()
|
||||
_ = pr.Close()
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.False(t, result.clientDisconnect)
|
||||
}
|
||||
|
||||
// TestStreamUpstreamResponse_TimeoutAfterClientDisconnect
|
||||
// 验证:客户端断开后上游超时,返回 usage 并标记 clientDisconnect
|
||||
func TestStreamUpstreamResponse_TimeoutAfterClientDisconnect(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newAntigravityTestService(&config.Config{
|
||||
Gateway: config.GatewayConfig{StreamDataIntervalTimeout: 1, MaxLineSize: defaultMaxLineSize},
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||
|
||||
go func() {
|
||||
fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":5}}}`)
|
||||
fmt.Fprintln(pw, "")
|
||||
// 不关闭 pw → 等待超时
|
||||
}()
|
||||
|
||||
result := svc.streamUpstreamResponse(c, resp, time.Now())
|
||||
_ = pw.Close()
|
||||
_ = pr.Close()
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.clientDisconnect)
|
||||
}
|
||||
|
||||
// TestHandleGeminiStreamingResponse_ClientDisconnect
|
||||
// 验证:Gemini 流式转发中客户端断开后继续 drain 上游
|
||||
func TestHandleGeminiStreamingResponse_ClientDisconnect(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newAntigravityTestService(&config.Config{
|
||||
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"hi"}]}}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":10}}`)
|
||||
fmt.Fprintln(pw, "")
|
||||
}()
|
||||
|
||||
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
|
||||
_ = pr.Close()
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.clientDisconnect)
|
||||
require.NotContains(t, rec.Body.String(), "write_failed")
|
||||
}
|
||||
|
||||
// TestHandleGeminiStreamingResponse_ContextCanceled
|
||||
// 验证:context 取消时不注入错误事件
|
||||
func TestHandleGeminiStreamingResponse_ContextCanceled(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newAntigravityTestService(&config.Config{
|
||||
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
|
||||
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
|
||||
|
||||
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.clientDisconnect)
|
||||
require.NotContains(t, rec.Body.String(), "event: error")
|
||||
}
|
||||
|
||||
// TestHandleClaudeStreamingResponse_ClientDisconnect
|
||||
// 验证:Claude 流式转发中客户端断开后继续 drain 上游
|
||||
func TestHandleClaudeStreamingResponse_ClientDisconnect(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newAntigravityTestService(&config.Config{
|
||||
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
// v1internal 包装格式
|
||||
fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"hello"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":8,"candidatesTokenCount":15}}}`)
|
||||
fmt.Fprintln(pw, "")
|
||||
}()
|
||||
|
||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
||||
_ = pr.Close()
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.clientDisconnect)
|
||||
}
|
||||
|
||||
// TestHandleClaudeStreamingResponse_ContextCanceled
|
||||
// 验证:context 取消时不注入错误事件
|
||||
func TestHandleClaudeStreamingResponse_ContextCanceled(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newAntigravityTestService(&config.Config{
|
||||
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
|
||||
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
|
||||
|
||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.clientDisconnect)
|
||||
require.NotContains(t, rec.Body.String(), "event: error")
|
||||
}
|
||||
|
||||
// TestExtractSSEUsage 验证 extractSSEUsage 从 SSE data 行正确提取 usage
|
||||
func TestExtractSSEUsage(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
tests := []struct {
|
||||
name string
|
||||
line string
|
||||
expected ClaudeUsage
|
||||
}{
|
||||
{
|
||||
name: "message_delta with output_tokens",
|
||||
line: `data: {"type":"message_delta","usage":{"output_tokens":42}}`,
|
||||
expected: ClaudeUsage{OutputTokens: 42},
|
||||
},
|
||||
{
|
||||
name: "non-data line ignored",
|
||||
line: `event: message_start`,
|
||||
expected: ClaudeUsage{},
|
||||
},
|
||||
{
|
||||
name: "top-level usage with all fields",
|
||||
line: `data: {"usage":{"input_tokens":10,"output_tokens":20,"cache_read_input_tokens":5,"cache_creation_input_tokens":3}}`,
|
||||
expected: ClaudeUsage{InputTokens: 10, OutputTokens: 20, CacheReadInputTokens: 5, CacheCreationInputTokens: 3},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
usage := &ClaudeUsage{}
|
||||
svc.extractSSEUsage(tt.line, usage)
|
||||
require.Equal(t, tt.expected, *usage)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAntigravityClientWriter 验证 antigravityClientWriter 的断开检测
|
||||
func TestAntigravityClientWriter(t *testing.T) {
|
||||
t.Run("normal write succeeds", func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
flusher, _ := c.Writer.(http.Flusher)
|
||||
cw := newAntigravityClientWriter(c.Writer, flusher, "test")
|
||||
|
||||
ok := cw.Write([]byte("hello"))
|
||||
require.True(t, ok)
|
||||
require.False(t, cw.Disconnected())
|
||||
require.Contains(t, rec.Body.String(), "hello")
|
||||
})
|
||||
|
||||
t.Run("write failure marks disconnected", func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
fw := &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||
flusher, _ := c.Writer.(http.Flusher)
|
||||
cw := newAntigravityClientWriter(fw, flusher, "test")
|
||||
|
||||
ok := cw.Write([]byte("hello"))
|
||||
require.False(t, ok)
|
||||
require.True(t, cw.Disconnected())
|
||||
})
|
||||
|
||||
t.Run("subsequent writes are no-op", func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
fw := &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||
flusher, _ := c.Writer.(http.Flusher)
|
||||
cw := newAntigravityClientWriter(fw, flusher, "test")
|
||||
|
||||
cw.Write([]byte("first"))
|
||||
ok := cw.Fprintf("second %d", 2)
|
||||
require.False(t, ok)
|
||||
require.True(t, cw.Disconnected())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2,63 +2,23 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const antigravityQuotaScopesKey = "antigravity_quota_scopes"
|
||||
|
||||
// AntigravityQuotaScope 表示 Antigravity 的配额域
|
||||
type AntigravityQuotaScope string
|
||||
|
||||
const (
|
||||
AntigravityQuotaScopeClaude AntigravityQuotaScope = "claude"
|
||||
AntigravityQuotaScopeGeminiText AntigravityQuotaScope = "gemini_text"
|
||||
AntigravityQuotaScopeGeminiImage AntigravityQuotaScope = "gemini_image"
|
||||
)
|
||||
|
||||
// IsScopeSupported 检查给定的 scope 是否在分组支持的 scope 列表中
|
||||
func IsScopeSupported(supportedScopes []string, scope AntigravityQuotaScope) bool {
|
||||
if len(supportedScopes) == 0 {
|
||||
// 未配置时默认全部支持
|
||||
return true
|
||||
}
|
||||
supported := slices.Contains(supportedScopes, string(scope))
|
||||
return supported
|
||||
}
|
||||
|
||||
// ResolveAntigravityQuotaScope 根据模型名称解析配额域(导出版本)
|
||||
func ResolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) {
|
||||
return resolveAntigravityQuotaScope(requestedModel)
|
||||
}
|
||||
|
||||
// resolveAntigravityQuotaScope 根据模型名称解析配额域
|
||||
func resolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) {
|
||||
model := normalizeAntigravityModelName(requestedModel)
|
||||
if model == "" {
|
||||
return "", false
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(model, "claude-"):
|
||||
return AntigravityQuotaScopeClaude, true
|
||||
case strings.HasPrefix(model, "gemini-"):
|
||||
if isImageGenerationModel(model) {
|
||||
return AntigravityQuotaScopeGeminiImage, true
|
||||
}
|
||||
return AntigravityQuotaScopeGeminiText, true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeAntigravityModelName(model string) string {
|
||||
normalized := strings.ToLower(strings.TrimSpace(model))
|
||||
normalized = strings.TrimPrefix(normalized, "models/")
|
||||
return normalized
|
||||
}
|
||||
|
||||
// IsSchedulableForModel 结合 Antigravity 配额域限流判断是否可调度。
|
||||
// resolveAntigravityModelKey 根据请求的模型名解析限流 key
|
||||
// 返回空字符串表示无法解析
|
||||
func resolveAntigravityModelKey(requestedModel string) string {
|
||||
return normalizeAntigravityModelName(requestedModel)
|
||||
}
|
||||
|
||||
// IsSchedulableForModel 结合模型级限流判断是否可调度。
|
||||
// 保持旧签名以兼容既有调用方;默认使用 context.Background()。
|
||||
func (a *Account) IsSchedulableForModel(requestedModel string) bool {
|
||||
return a.IsSchedulableForModelWithContext(context.Background(), requestedModel)
|
||||
@@ -74,107 +34,20 @@ func (a *Account) IsSchedulableForModelWithContext(ctx context.Context, requeste
|
||||
if a.isModelRateLimitedWithContext(ctx, requestedModel) {
|
||||
return false
|
||||
}
|
||||
if a.Platform != PlatformAntigravity {
|
||||
return true
|
||||
}
|
||||
scope, ok := resolveAntigravityQuotaScope(requestedModel)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
resetAt := a.antigravityQuotaScopeResetAt(scope)
|
||||
if resetAt == nil {
|
||||
return true
|
||||
}
|
||||
now := time.Now()
|
||||
return !now.Before(*resetAt)
|
||||
return true
|
||||
}
|
||||
|
||||
func (a *Account) antigravityQuotaScopeResetAt(scope AntigravityQuotaScope) *time.Time {
|
||||
if a == nil || a.Extra == nil || scope == "" {
|
||||
return nil
|
||||
}
|
||||
rawScopes, ok := a.Extra[antigravityQuotaScopesKey].(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
rawScope, ok := rawScopes[string(scope)].(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
resetAtRaw, ok := rawScope["rate_limit_reset_at"].(string)
|
||||
if !ok || strings.TrimSpace(resetAtRaw) == "" {
|
||||
return nil
|
||||
}
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtRaw)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return &resetAt
|
||||
}
|
||||
|
||||
var antigravityAllScopes = []AntigravityQuotaScope{
|
||||
AntigravityQuotaScopeClaude,
|
||||
AntigravityQuotaScopeGeminiText,
|
||||
AntigravityQuotaScopeGeminiImage,
|
||||
}
|
||||
|
||||
func (a *Account) GetAntigravityScopeRateLimits() map[string]int64 {
|
||||
if a == nil || a.Platform != PlatformAntigravity {
|
||||
return nil
|
||||
}
|
||||
now := time.Now()
|
||||
result := make(map[string]int64)
|
||||
for _, scope := range antigravityAllScopes {
|
||||
resetAt := a.antigravityQuotaScopeResetAt(scope)
|
||||
if resetAt != nil && now.Before(*resetAt) {
|
||||
remainingSec := int64(time.Until(*resetAt).Seconds())
|
||||
if remainingSec > 0 {
|
||||
result[string(scope)] = remainingSec
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetQuotaScopeRateLimitRemainingTime 获取模型域限流剩余时间
|
||||
// 返回 0 表示未限流或已过期
|
||||
func (a *Account) GetQuotaScopeRateLimitRemainingTime(requestedModel string) time.Duration {
|
||||
if a == nil || a.Platform != PlatformAntigravity {
|
||||
return 0
|
||||
}
|
||||
scope, ok := resolveAntigravityQuotaScope(requestedModel)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
resetAt := a.antigravityQuotaScopeResetAt(scope)
|
||||
if resetAt == nil {
|
||||
return 0
|
||||
}
|
||||
if remaining := time.Until(*resetAt); remaining > 0 {
|
||||
return remaining
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetRateLimitRemainingTime 获取限流剩余时间(模型限流和模型域限流取最大值)
|
||||
// GetRateLimitRemainingTime 获取限流剩余时间(模型级限流)
|
||||
// 返回 0 表示未限流或已过期
|
||||
func (a *Account) GetRateLimitRemainingTime(requestedModel string) time.Duration {
|
||||
return a.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel)
|
||||
}
|
||||
|
||||
// GetRateLimitRemainingTimeWithContext 获取限流剩余时间(模型限流和模型域限流取最大值)
|
||||
// GetRateLimitRemainingTimeWithContext 获取限流剩余时间(模型级限流)
|
||||
// 返回 0 表示未限流或已过期
|
||||
func (a *Account) GetRateLimitRemainingTimeWithContext(ctx context.Context, requestedModel string) time.Duration {
|
||||
if a == nil {
|
||||
return 0
|
||||
}
|
||||
modelRemaining := a.GetModelRateLimitRemainingTimeWithContext(ctx, requestedModel)
|
||||
scopeRemaining := a.GetQuotaScopeRateLimitRemainingTime(requestedModel)
|
||||
if modelRemaining > scopeRemaining {
|
||||
return modelRemaining
|
||||
}
|
||||
return scopeRemaining
|
||||
return a.GetModelRateLimitRemainingTimeWithContext(ctx, requestedModel)
|
||||
}
|
||||
|
||||
@@ -59,12 +59,6 @@ func (s *stubAntigravityUpstream) DoWithTLS(req *http.Request, proxyURL string,
|
||||
return s.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
}
|
||||
|
||||
type scopeLimitCall struct {
|
||||
accountID int64
|
||||
scope AntigravityQuotaScope
|
||||
resetAt time.Time
|
||||
}
|
||||
|
||||
type rateLimitCall struct {
|
||||
accountID int64
|
||||
resetAt time.Time
|
||||
@@ -78,16 +72,10 @@ type modelRateLimitCall struct {
|
||||
|
||||
type stubAntigravityAccountRepo struct {
|
||||
AccountRepository
|
||||
scopeCalls []scopeLimitCall
|
||||
rateCalls []rateLimitCall
|
||||
modelRateLimitCalls []modelRateLimitCall
|
||||
}
|
||||
|
||||
func (s *stubAntigravityAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
|
||||
s.scopeCalls = append(s.scopeCalls, scopeLimitCall{accountID: id, scope: scope, resetAt: resetAt})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
s.rateCalls = append(s.rateCalls, rateLimitCall{accountID: id, resetAt: resetAt})
|
||||
return nil
|
||||
@@ -131,10 +119,9 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
quotaScope: AntigravityQuotaScopeClaude,
|
||||
httpUpstream: upstream,
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleErrorCalled = true
|
||||
return nil
|
||||
},
|
||||
@@ -155,23 +142,6 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
|
||||
require.Equal(t, base2, available[0])
|
||||
}
|
||||
|
||||
func TestAntigravityHandleUpstreamError_UsesScopeLimit(t *testing.T) {
|
||||
// 分区限流始终开启,不再支持通过环境变量关闭
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
account := &Account{ID: 9, Name: "acc-9", Platform: PlatformAntigravity}
|
||||
|
||||
body := buildGeminiRateLimitBody("3s")
|
||||
svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false)
|
||||
|
||||
require.Len(t, repo.scopeCalls, 1)
|
||||
require.Empty(t, repo.rateCalls)
|
||||
call := repo.scopeCalls[0]
|
||||
require.Equal(t, account.ID, call.accountID)
|
||||
require.Equal(t, AntigravityQuotaScopeClaude, call.scope)
|
||||
require.WithinDuration(t, time.Now().Add(3*time.Second), call.resetAt, 2*time.Second)
|
||||
}
|
||||
|
||||
// TestHandleUpstreamError_429_ModelRateLimit 测试 429 模型限流场景
|
||||
func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) {
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
@@ -189,7 +159,7 @@ func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) {
|
||||
}
|
||||
}`)
|
||||
|
||||
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false)
|
||||
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, "claude-sonnet-4-5", 0, "", false)
|
||||
|
||||
// 应该触发模型限流
|
||||
require.NotNil(t, result)
|
||||
@@ -200,22 +170,22 @@ func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) {
|
||||
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
|
||||
}
|
||||
|
||||
// TestHandleUpstreamError_429_NonModelRateLimit 测试 429 非模型限流场景(走 scope 限流)
|
||||
// TestHandleUpstreamError_429_NonModelRateLimit 测试 429 非模型限流场景(走模型级限流兜底)
|
||||
func TestHandleUpstreamError_429_NonModelRateLimit(t *testing.T) {
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
account := &Account{ID: 2, Name: "acc-2", Platform: PlatformAntigravity}
|
||||
|
||||
// 429 + 普通限流响应(无 RATE_LIMIT_EXCEEDED reason)→ scope 限流
|
||||
// 429 + 普通限流响应(无 RATE_LIMIT_EXCEEDED reason)→ 走模型级限流兜底
|
||||
body := buildGeminiRateLimitBody("5s")
|
||||
|
||||
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false)
|
||||
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, "claude-sonnet-4-5", 0, "", false)
|
||||
|
||||
// 不应该触发模型限流,应该走 scope 限流
|
||||
// handleModelRateLimit 不会处理(因为没有 RATE_LIMIT_EXCEEDED),
|
||||
// 但 429 兜底逻辑会使用 requestedModel 设置模型级限流
|
||||
require.Nil(t, result)
|
||||
require.Empty(t, repo.modelRateLimitCalls)
|
||||
require.Len(t, repo.scopeCalls, 1)
|
||||
require.Equal(t, AntigravityQuotaScopeClaude, repo.scopeCalls[0].scope)
|
||||
require.Len(t, repo.modelRateLimitCalls, 1)
|
||||
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
|
||||
}
|
||||
|
||||
// TestHandleUpstreamError_503_ModelRateLimit 测试 503 模型限流场景
|
||||
@@ -235,7 +205,7 @@ func TestHandleUpstreamError_503_ModelRateLimit(t *testing.T) {
|
||||
}
|
||||
}`)
|
||||
|
||||
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false)
|
||||
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false)
|
||||
|
||||
// 应该触发模型限流
|
||||
require.NotNil(t, result)
|
||||
@@ -263,12 +233,11 @@ func TestHandleUpstreamError_503_NonModelRateLimit(t *testing.T) {
|
||||
}
|
||||
}`)
|
||||
|
||||
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false)
|
||||
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false)
|
||||
|
||||
// 503 非模型限流不应该做任何处理
|
||||
require.Nil(t, result)
|
||||
require.Empty(t, repo.modelRateLimitCalls, "503 non-model rate limit should not trigger model rate limit")
|
||||
require.Empty(t, repo.scopeCalls, "503 non-model rate limit should not trigger scope rate limit")
|
||||
require.Empty(t, repo.rateCalls, "503 non-model rate limit should not trigger account rate limit")
|
||||
}
|
||||
|
||||
@@ -281,12 +250,11 @@ func TestHandleUpstreamError_503_EmptyBody(t *testing.T) {
|
||||
// 503 + 空响应体 → 不做任何处理
|
||||
body := []byte(`{}`)
|
||||
|
||||
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false)
|
||||
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false)
|
||||
|
||||
// 503 空响应不应该做任何处理
|
||||
require.Nil(t, result)
|
||||
require.Empty(t, repo.modelRateLimitCalls)
|
||||
require.Empty(t, repo.scopeCalls)
|
||||
require.Empty(t, repo.rateCalls)
|
||||
}
|
||||
|
||||
@@ -307,15 +275,7 @@ func TestAccountIsSchedulableForModel_AntigravityRateLimits(t *testing.T) {
|
||||
require.False(t, account.IsSchedulableForModel("gemini-3-flash"))
|
||||
|
||||
account.RateLimitResetAt = nil
|
||||
account.Extra = map[string]any{
|
||||
antigravityQuotaScopesKey: map[string]any{
|
||||
"claude": map[string]any{
|
||||
"rate_limit_reset_at": future.Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
require.False(t, account.IsSchedulableForModel("claude-sonnet-4-5"))
|
||||
require.True(t, account.IsSchedulableForModel("claude-sonnet-4-5"))
|
||||
require.True(t, account.IsSchedulableForModel("gemini-3-flash"))
|
||||
}
|
||||
|
||||
@@ -635,6 +595,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
|
||||
}`,
|
||||
expectedShouldRetry: false,
|
||||
expectedShouldRateLimit: true,
|
||||
minWait: 7 * time.Second,
|
||||
modelName: "gemini-pro",
|
||||
},
|
||||
{
|
||||
@@ -652,6 +613,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
|
||||
}`,
|
||||
expectedShouldRetry: false,
|
||||
expectedShouldRateLimit: true,
|
||||
minWait: 39 * time.Second,
|
||||
modelName: "gemini-3-pro-high",
|
||||
},
|
||||
{
|
||||
@@ -669,6 +631,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
|
||||
}`,
|
||||
expectedShouldRetry: false,
|
||||
expectedShouldRateLimit: true,
|
||||
minWait: 30 * time.Second,
|
||||
modelName: "gemini-2.5-flash",
|
||||
},
|
||||
{
|
||||
@@ -686,6 +649,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
|
||||
}`,
|
||||
expectedShouldRetry: false,
|
||||
expectedShouldRateLimit: true,
|
||||
minWait: 30 * time.Second,
|
||||
modelName: "claude-sonnet-4-5",
|
||||
},
|
||||
}
|
||||
@@ -704,6 +668,11 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
|
||||
t.Errorf("wait = %v, want >= %v", wait, tt.minWait)
|
||||
}
|
||||
}
|
||||
if shouldRateLimit && tt.minWait > 0 {
|
||||
if wait < tt.minWait {
|
||||
t.Errorf("rate limit wait = %v, want >= %v", wait, tt.minWait)
|
||||
}
|
||||
}
|
||||
if (shouldRetry || shouldRateLimit) && model != tt.modelName {
|
||||
t.Errorf("modelName = %q, want %q", model, tt.modelName)
|
||||
}
|
||||
@@ -832,7 +801,7 @@ func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRateLimited(t *testing.T) {
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
httpUpstream: upstream,
|
||||
isStickySession: true,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
@@ -875,7 +844,7 @@ func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemainingLong(t *testing.T) {
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
httpUpstream: upstream,
|
||||
isStickySession: true,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
@@ -75,7 +75,7 @@ func TestHandleSmartRetry_URLLevelRateLimit(t *testing.T) {
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
@@ -127,7 +127,7 @@ func TestHandleSmartRetry_LongDelay_ReturnsSwitchError(t *testing.T) {
|
||||
body: []byte(`{"input":"test"}`),
|
||||
accountRepo: repo,
|
||||
isStickySession: true,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
@@ -194,7 +194,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetrySuccess(t *testing.T) {
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
@@ -269,7 +269,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test
|
||||
httpUpstream: upstream,
|
||||
accountRepo: repo,
|
||||
isStickySession: false,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
@@ -331,7 +331,7 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi
|
||||
body: []byte(`{"input":"test"}`),
|
||||
accountRepo: repo,
|
||||
isStickySession: true,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
@@ -387,7 +387,7 @@ func TestHandleSmartRetry_NonAntigravityAccount_ContinuesDefaultLogic(t *testing
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
@@ -436,7 +436,7 @@ func TestHandleSmartRetry_NonModelRateLimit_ContinuesDefaultLogic(t *testing.T)
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
@@ -487,7 +487,7 @@ func TestHandleSmartRetry_ExactlyAtThreshold_ReturnsSwitchError(t *testing.T) {
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
accountRepo: repo,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
@@ -548,7 +548,7 @@ func TestAntigravityRetryLoop_HandleSmartRetry_SwitchError_Propagates(t *testing
|
||||
httpUpstream: upstream,
|
||||
accountRepo: repo,
|
||||
isStickySession: true,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
@@ -604,7 +604,7 @@ func TestHandleSmartRetry_NetworkError_ExhaustsRetry(t *testing.T) {
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
accountRepo: repo,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
@@ -662,7 +662,7 @@ func TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit(t *testing.T) {
|
||||
body: []byte(`{"input":"test"}`),
|
||||
accountRepo: repo,
|
||||
isStickySession: true,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
@@ -754,7 +754,7 @@ func TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_ClearsSession(t *
|
||||
isStickySession: true,
|
||||
groupID: 42,
|
||||
sessionHash: "sticky-hash-abc",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
@@ -842,7 +842,7 @@ func TestHandleSmartRetry_ShortDelay_NonStickySession_FailedRetry_NoDeleteSessio
|
||||
isStickySession: false,
|
||||
groupID: 42,
|
||||
sessionHash: "", // 非粘性会话,sessionHash 为空
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
@@ -918,7 +918,7 @@ func TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_NilCache_NoPanic(
|
||||
isStickySession: true,
|
||||
groupID: 42,
|
||||
sessionHash: "sticky-hash-nil-cache",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
@@ -983,7 +983,7 @@ func TestHandleSmartRetry_ShortDelay_StickySession_SuccessRetry_NoDeleteSession(
|
||||
isStickySession: true,
|
||||
groupID: 42,
|
||||
sessionHash: "sticky-hash-success",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
@@ -1043,7 +1043,7 @@ func TestHandleSmartRetry_LongDelay_StickySession_NoDeleteInHandleSmartRetry(t *
|
||||
isStickySession: true,
|
||||
groupID: 42,
|
||||
sessionHash: "sticky-hash-long-delay",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
@@ -1108,7 +1108,7 @@ func TestHandleSmartRetry_ShortDelay_NetworkError_StickySession_ClearsSession(t
|
||||
isStickySession: true,
|
||||
groupID: 99,
|
||||
sessionHash: "sticky-net-error",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
@@ -1188,7 +1188,7 @@ func TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession
|
||||
isStickySession: true,
|
||||
groupID: 77,
|
||||
sessionHash: "sticky-503-short",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
@@ -1278,7 +1278,7 @@ func TestAntigravityRetryLoop_SmartRetryFailed_StickySession_SwitchErrorPropagat
|
||||
isStickySession: true,
|
||||
groupID: 55,
|
||||
sessionHash: "sticky-loop-test",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
@@ -1296,4 +1296,4 @@ func TestAntigravityRetryLoop_SmartRetryFailed_StickySession_SwitchErrorPropagat
|
||||
require.Len(t, cache.deleteCalls, 1, "should clear sticky session in handleSmartRetry")
|
||||
require.Equal(t, int64(55), cache.deleteCalls[0].groupID)
|
||||
require.Equal(t, "sticky-loop-test", cache.deleteCalls[0].sessionHash)
|
||||
}
|
||||
}
|
||||
|
||||
69
backend/internal/service/digest_session_store.go
Normal file
69
backend/internal/service/digest_session_store.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
)
|
||||
|
||||
// digestSessionTTL 摘要会话默认 TTL
|
||||
const digestSessionTTL = 5 * time.Minute
|
||||
|
||||
// sessionEntry flat cache 条目
|
||||
type sessionEntry struct {
|
||||
uuid string
|
||||
accountID int64
|
||||
}
|
||||
|
||||
// DigestSessionStore 内存摘要会话存储(flat cache 实现)
|
||||
// key: "{groupID}:{prefixHash}|{digestChain}" → *sessionEntry
|
||||
type DigestSessionStore struct {
|
||||
cache *gocache.Cache
|
||||
}
|
||||
|
||||
// NewDigestSessionStore 创建内存摘要会话存储
|
||||
func NewDigestSessionStore() *DigestSessionStore {
|
||||
return &DigestSessionStore{
|
||||
cache: gocache.New(digestSessionTTL, time.Minute),
|
||||
}
|
||||
}
|
||||
|
||||
// Save 保存摘要会话。oldDigestChain 为 Find 返回的 matchedChain,用于删旧 key。
|
||||
func (s *DigestSessionStore) Save(groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) {
|
||||
if digestChain == "" {
|
||||
return
|
||||
}
|
||||
ns := buildNS(groupID, prefixHash)
|
||||
s.cache.Set(ns+digestChain, &sessionEntry{uuid: uuid, accountID: accountID}, gocache.DefaultExpiration)
|
||||
if oldDigestChain != "" && oldDigestChain != digestChain {
|
||||
s.cache.Delete(ns + oldDigestChain)
|
||||
}
|
||||
}
|
||||
|
||||
// Find 查找摘要会话,从完整 chain 逐段截断,返回最长匹配及对应 matchedChain。
|
||||
func (s *DigestSessionStore) Find(groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) {
|
||||
if digestChain == "" {
|
||||
return "", 0, "", false
|
||||
}
|
||||
ns := buildNS(groupID, prefixHash)
|
||||
chain := digestChain
|
||||
for {
|
||||
if val, ok := s.cache.Get(ns + chain); ok {
|
||||
if e, ok := val.(*sessionEntry); ok {
|
||||
return e.uuid, e.accountID, chain, true
|
||||
}
|
||||
}
|
||||
i := strings.LastIndex(chain, "-")
|
||||
if i < 0 {
|
||||
return "", 0, "", false
|
||||
}
|
||||
chain = chain[:i]
|
||||
}
|
||||
}
|
||||
|
||||
// buildNS 构建 namespace 前缀
|
||||
func buildNS(groupID int64, prefixHash string) string {
|
||||
return strconv.FormatInt(groupID, 10) + ":" + prefixHash + "|"
|
||||
}
|
||||
312
backend/internal/service/digest_session_store_test.go
Normal file
312
backend/internal/service/digest_session_store_test.go
Normal file
@@ -0,0 +1,312 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDigestSessionStore_SaveAndFind(t *testing.T) {
|
||||
store := NewDigestSessionStore()
|
||||
|
||||
store.Save(1, "prefix", "s:a1-u:b2-m:c3", "uuid-1", 100, "")
|
||||
|
||||
uuid, accountID, _, found := store.Find(1, "prefix", "s:a1-u:b2-m:c3")
|
||||
require.True(t, found)
|
||||
assert.Equal(t, "uuid-1", uuid)
|
||||
assert.Equal(t, int64(100), accountID)
|
||||
}
|
||||
|
||||
func TestDigestSessionStore_PrefixMatch(t *testing.T) {
|
||||
store := NewDigestSessionStore()
|
||||
|
||||
// 保存短链
|
||||
store.Save(1, "prefix", "u:a-m:b", "uuid-short", 10, "")
|
||||
|
||||
// 用长链查找,应前缀匹配到短链
|
||||
uuid, accountID, matchedChain, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d")
|
||||
require.True(t, found)
|
||||
assert.Equal(t, "uuid-short", uuid)
|
||||
assert.Equal(t, int64(10), accountID)
|
||||
assert.Equal(t, "u:a-m:b", matchedChain)
|
||||
}
|
||||
|
||||
func TestDigestSessionStore_LongestPrefixMatch(t *testing.T) {
|
||||
store := NewDigestSessionStore()
|
||||
|
||||
store.Save(1, "prefix", "u:a", "uuid-1", 1, "")
|
||||
store.Save(1, "prefix", "u:a-m:b", "uuid-2", 2, "")
|
||||
store.Save(1, "prefix", "u:a-m:b-u:c", "uuid-3", 3, "")
|
||||
|
||||
// 应匹配最深的 "u:a-m:b-u:c"(从完整 chain 逐段截断,先命中最长的)
|
||||
uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d-u:e")
|
||||
require.True(t, found)
|
||||
assert.Equal(t, "uuid-3", uuid)
|
||||
assert.Equal(t, int64(3), accountID)
|
||||
|
||||
// 查找中等长度,应匹配到 "u:a-m:b"
|
||||
uuid, accountID, _, found = store.Find(1, "prefix", "u:a-m:b-u:x")
|
||||
require.True(t, found)
|
||||
assert.Equal(t, "uuid-2", uuid)
|
||||
assert.Equal(t, int64(2), accountID)
|
||||
}
|
||||
|
||||
func TestDigestSessionStore_SaveDeletesOldChain(t *testing.T) {
|
||||
store := NewDigestSessionStore()
|
||||
|
||||
// 第一轮:保存 "u:a-m:b"
|
||||
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
|
||||
|
||||
// 第二轮:同一 uuid 保存更长的链,传入旧 chain
|
||||
store.Save(1, "prefix", "u:a-m:b-u:c-m:d", "uuid-1", 100, "u:a-m:b")
|
||||
|
||||
// 旧链 "u:a-m:b" 应已被删除
|
||||
_, _, _, found := store.Find(1, "prefix", "u:a-m:b")
|
||||
assert.False(t, found, "old chain should be deleted")
|
||||
|
||||
// 新链应能找到
|
||||
uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d")
|
||||
require.True(t, found)
|
||||
assert.Equal(t, "uuid-1", uuid)
|
||||
assert.Equal(t, int64(100), accountID)
|
||||
}
|
||||
|
||||
func TestDigestSessionStore_DifferentSessionsNoInterference(t *testing.T) {
|
||||
store := NewDigestSessionStore()
|
||||
|
||||
// 相同系统提示词,不同用户提示词
|
||||
store.Save(1, "prefix", "s:sys-u:user1", "uuid-1", 100, "")
|
||||
store.Save(1, "prefix", "s:sys-u:user2", "uuid-2", 200, "")
|
||||
|
||||
uuid, accountID, _, found := store.Find(1, "prefix", "s:sys-u:user1-m:reply1")
|
||||
require.True(t, found)
|
||||
assert.Equal(t, "uuid-1", uuid)
|
||||
assert.Equal(t, int64(100), accountID)
|
||||
|
||||
uuid, accountID, _, found = store.Find(1, "prefix", "s:sys-u:user2-m:reply2")
|
||||
require.True(t, found)
|
||||
assert.Equal(t, "uuid-2", uuid)
|
||||
assert.Equal(t, int64(200), accountID)
|
||||
}
|
||||
|
||||
func TestDigestSessionStore_NoMatch(t *testing.T) {
|
||||
store := NewDigestSessionStore()
|
||||
|
||||
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
|
||||
|
||||
// 完全不同的 chain
|
||||
_, _, _, found := store.Find(1, "prefix", "u:x-m:y")
|
||||
assert.False(t, found)
|
||||
}
|
||||
|
||||
func TestDigestSessionStore_DifferentPrefixHash(t *testing.T) {
|
||||
store := NewDigestSessionStore()
|
||||
|
||||
store.Save(1, "prefix1", "u:a-m:b", "uuid-1", 100, "")
|
||||
|
||||
// 不同 prefixHash 应隔离
|
||||
_, _, _, found := store.Find(1, "prefix2", "u:a-m:b")
|
||||
assert.False(t, found)
|
||||
}
|
||||
|
||||
func TestDigestSessionStore_DifferentGroupID(t *testing.T) {
|
||||
store := NewDigestSessionStore()
|
||||
|
||||
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
|
||||
|
||||
// 不同 groupID 应隔离
|
||||
_, _, _, found := store.Find(2, "prefix", "u:a-m:b")
|
||||
assert.False(t, found)
|
||||
}
|
||||
|
||||
func TestDigestSessionStore_EmptyDigestChain(t *testing.T) {
|
||||
store := NewDigestSessionStore()
|
||||
|
||||
// 空链不应保存
|
||||
store.Save(1, "prefix", "", "uuid-1", 100, "")
|
||||
_, _, _, found := store.Find(1, "prefix", "")
|
||||
assert.False(t, found)
|
||||
}
|
||||
|
||||
func TestDigestSessionStore_TTLExpiration(t *testing.T) {
|
||||
store := &DigestSessionStore{
|
||||
cache: gocache.New(100*time.Millisecond, 50*time.Millisecond),
|
||||
}
|
||||
|
||||
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
|
||||
|
||||
// 立即应该能找到
|
||||
_, _, _, found := store.Find(1, "prefix", "u:a-m:b")
|
||||
require.True(t, found)
|
||||
|
||||
// 等待过期 + 清理周期
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
|
||||
// 过期后应找不到
|
||||
_, _, _, found = store.Find(1, "prefix", "u:a-m:b")
|
||||
assert.False(t, found)
|
||||
}
|
||||
|
||||
func TestDigestSessionStore_ConcurrentSafety(t *testing.T) {
|
||||
store := NewDigestSessionStore()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
const goroutines = 50
|
||||
const operations = 100
|
||||
|
||||
wg.Add(goroutines)
|
||||
for g := 0; g < goroutines; g++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
prefix := fmt.Sprintf("prefix-%d", id%5)
|
||||
for i := 0; i < operations; i++ {
|
||||
chain := fmt.Sprintf("u:%d-m:%d", id, i)
|
||||
uuid := fmt.Sprintf("uuid-%d-%d", id, i)
|
||||
store.Save(1, prefix, chain, uuid, int64(id), "")
|
||||
store.Find(1, prefix, chain)
|
||||
}
|
||||
}(g)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestDigestSessionStore_MultipleSessions(t *testing.T) {
|
||||
store := NewDigestSessionStore()
|
||||
|
||||
sessions := []struct {
|
||||
chain string
|
||||
uuid string
|
||||
accountID int64
|
||||
}{
|
||||
{"u:session1", "uuid-1", 1},
|
||||
{"u:session2-m:reply2", "uuid-2", 2},
|
||||
{"u:session3-m:reply3-u:msg3", "uuid-3", 3},
|
||||
}
|
||||
|
||||
for _, sess := range sessions {
|
||||
store.Save(1, "prefix", sess.chain, sess.uuid, sess.accountID, "")
|
||||
}
|
||||
|
||||
// 验证每个会话都能正确查找
|
||||
for _, sess := range sessions {
|
||||
uuid, accountID, _, found := store.Find(1, "prefix", sess.chain)
|
||||
require.True(t, found, "should find session: %s", sess.chain)
|
||||
assert.Equal(t, sess.uuid, uuid)
|
||||
assert.Equal(t, sess.accountID, accountID)
|
||||
}
|
||||
|
||||
// 验证继续对话的场景
|
||||
uuid, accountID, _, found := store.Find(1, "prefix", "u:session2-m:reply2-u:newmsg")
|
||||
require.True(t, found)
|
||||
assert.Equal(t, "uuid-2", uuid)
|
||||
assert.Equal(t, int64(2), accountID)
|
||||
}
|
||||
|
||||
func TestDigestSessionStore_Performance1000Sessions(t *testing.T) {
|
||||
store := NewDigestSessionStore()
|
||||
|
||||
// 插入 1000 个会话
|
||||
for i := 0; i < 1000; i++ {
|
||||
chain := fmt.Sprintf("s:sys-u:user%d-m:reply%d", i, i)
|
||||
store.Save(1, "prefix", chain, fmt.Sprintf("uuid-%d", i), int64(i), "")
|
||||
}
|
||||
|
||||
// 查找性能测试
|
||||
start := time.Now()
|
||||
const lookups = 10000
|
||||
for i := 0; i < lookups; i++ {
|
||||
idx := i % 1000
|
||||
chain := fmt.Sprintf("s:sys-u:user%d-m:reply%d-u:newmsg", idx, idx)
|
||||
_, _, _, found := store.Find(1, "prefix", chain)
|
||||
assert.True(t, found)
|
||||
}
|
||||
elapsed := time.Since(start)
|
||||
t.Logf("%d lookups in %v (%.0f ns/op)", lookups, elapsed, float64(elapsed.Nanoseconds())/lookups)
|
||||
}
|
||||
|
||||
func TestDigestSessionStore_FindReturnsMatchedChain(t *testing.T) {
|
||||
store := NewDigestSessionStore()
|
||||
|
||||
store.Save(1, "prefix", "u:a-m:b-u:c", "uuid-1", 100, "")
|
||||
|
||||
// 精确匹配
|
||||
_, _, matchedChain, found := store.Find(1, "prefix", "u:a-m:b-u:c")
|
||||
require.True(t, found)
|
||||
assert.Equal(t, "u:a-m:b-u:c", matchedChain)
|
||||
|
||||
// 前缀匹配(截断后命中)
|
||||
_, _, matchedChain, found = store.Find(1, "prefix", "u:a-m:b-u:c-m:d-u:e")
|
||||
require.True(t, found)
|
||||
assert.Equal(t, "u:a-m:b-u:c", matchedChain)
|
||||
}
|
||||
|
||||
func TestDigestSessionStore_CacheItemCountStable(t *testing.T) {
|
||||
store := NewDigestSessionStore()
|
||||
|
||||
// 模拟 100 个独立会话,每个进行 10 轮对话
|
||||
// 正确传递 oldDigestChain 时,每个会话始终只保留 1 个 key
|
||||
for conv := 0; conv < 100; conv++ {
|
||||
var prevMatchedChain string
|
||||
for round := 0; round < 10; round++ {
|
||||
chain := fmt.Sprintf("s:sys-u:user%d", conv)
|
||||
for r := 0; r < round; r++ {
|
||||
chain += fmt.Sprintf("-m:a%d-u:q%d", r, r+1)
|
||||
}
|
||||
uuid := fmt.Sprintf("uuid-conv%d", conv)
|
||||
|
||||
_, _, matched, _ := store.Find(1, "prefix", chain)
|
||||
store.Save(1, "prefix", chain, uuid, int64(conv), matched)
|
||||
prevMatchedChain = matched
|
||||
_ = prevMatchedChain
|
||||
}
|
||||
}
|
||||
|
||||
// 100 个会话 × 1 key/会话 = 应该 ≤ 100 个 key
|
||||
// 允许少量并发残留,但绝不能接近 100×10=1000
|
||||
itemCount := store.cache.ItemCount()
|
||||
assert.LessOrEqual(t, itemCount, 100, "cache should have at most 100 items (1 per conversation), got %d", itemCount)
|
||||
t.Logf("Cache item count after 100 conversations × 10 rounds: %d", itemCount)
|
||||
}
|
||||
|
||||
func TestDigestSessionStore_TTLPreventsUnboundedGrowth(t *testing.T) {
|
||||
// 使用极短 TTL 验证大量写入后 cache 能被清理
|
||||
store := &DigestSessionStore{
|
||||
cache: gocache.New(100*time.Millisecond, 50*time.Millisecond),
|
||||
}
|
||||
|
||||
// 插入 500 个不同的 key(无 oldDigestChain,模拟最坏场景:全是新会话首轮)
|
||||
for i := 0; i < 500; i++ {
|
||||
chain := fmt.Sprintf("u:user%d", i)
|
||||
store.Save(1, "prefix", chain, fmt.Sprintf("uuid-%d", i), int64(i), "")
|
||||
}
|
||||
|
||||
assert.Equal(t, 500, store.cache.ItemCount())
|
||||
|
||||
// 等待 TTL + 清理周期
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
|
||||
assert.Equal(t, 0, store.cache.ItemCount(), "all items should be expired and cleaned up")
|
||||
}
|
||||
|
||||
func TestDigestSessionStore_SaveSameChainNoDelete(t *testing.T) {
|
||||
store := NewDigestSessionStore()
|
||||
|
||||
// 保存 chain
|
||||
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
|
||||
|
||||
// 用户重发相同消息:oldDigestChain == digestChain,不应删掉刚设置的 key
|
||||
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "u:a-m:b")
|
||||
|
||||
// 仍然能找到
|
||||
uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b")
|
||||
require.True(t, found)
|
||||
assert.Equal(t, "uuid-1", uuid)
|
||||
assert.Equal(t, int64(100), accountID)
|
||||
}
|
||||
366
backend/internal/service/error_policy_integration_test.go
Normal file
366
backend/internal/service/error_policy_integration_test.go
Normal file
@@ -0,0 +1,366 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Mocks (scoped to this file by naming convention)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// epFixedUpstream returns a fixed response for every request.
|
||||
type epFixedUpstream struct {
|
||||
statusCode int
|
||||
body string
|
||||
calls int
|
||||
}
|
||||
|
||||
func (u *epFixedUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
|
||||
u.calls++
|
||||
return &http.Response{
|
||||
StatusCode: u.statusCode,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(u.body)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (u *epFixedUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
|
||||
return u.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
}
|
||||
|
||||
// epAccountRepo records SetTempUnschedulable / SetError calls.
|
||||
type epAccountRepo struct {
|
||||
mockAccountRepoForGemini
|
||||
tempCalls int
|
||||
setErrCalls int
|
||||
}
|
||||
|
||||
func (r *epAccountRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error {
|
||||
r.tempCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *epAccountRepo) SetError(_ context.Context, _ int64, _ string) error {
|
||||
r.setErrCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func saveAndSetBaseURLs(t *testing.T) {
|
||||
t.Helper()
|
||||
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
|
||||
oldAvail := antigravity.DefaultURLAvailability
|
||||
antigravity.BaseURLs = []string{"https://ep-test.example"}
|
||||
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
|
||||
t.Cleanup(func() {
|
||||
antigravity.BaseURLs = oldBaseURLs
|
||||
antigravity.DefaultURLAvailability = oldAvail
|
||||
})
|
||||
}
|
||||
|
||||
func newRetryParams(account *Account, upstream HTTPUpstream, handleError func(context.Context, string, *Account, int, http.Header, []byte, string, int64, string, bool) *handleModelRateLimitResult) antigravityRetryLoopParams {
|
||||
return antigravityRetryLoopParams{
|
||||
ctx: context.Background(),
|
||||
prefix: "[ep-test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
handleError: handleError,
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestRetryLoop_ErrorPolicy_CustomErrorCodes
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestRetryLoop_ErrorPolicy_CustomErrorCodes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
upstreamStatus int
|
||||
upstreamBody string
|
||||
customCodes []any
|
||||
expectHandleError int
|
||||
expectUpstream int
|
||||
expectStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "429_in_custom_codes_matched",
|
||||
upstreamStatus: 429,
|
||||
upstreamBody: `{"error":"rate limited"}`,
|
||||
customCodes: []any{float64(429)},
|
||||
expectHandleError: 1,
|
||||
expectUpstream: 1,
|
||||
expectStatusCode: 429,
|
||||
},
|
||||
{
|
||||
name: "429_not_in_custom_codes_skipped",
|
||||
upstreamStatus: 429,
|
||||
upstreamBody: `{"error":"rate limited"}`,
|
||||
customCodes: []any{float64(500)},
|
||||
expectHandleError: 0,
|
||||
expectUpstream: 1,
|
||||
expectStatusCode: 429,
|
||||
},
|
||||
{
|
||||
name: "500_in_custom_codes_matched",
|
||||
upstreamStatus: 500,
|
||||
upstreamBody: `{"error":"internal"}`,
|
||||
customCodes: []any{float64(500)},
|
||||
expectHandleError: 1,
|
||||
expectUpstream: 1,
|
||||
expectStatusCode: 500,
|
||||
},
|
||||
{
|
||||
name: "500_not_in_custom_codes_skipped",
|
||||
upstreamStatus: 500,
|
||||
upstreamBody: `{"error":"internal"}`,
|
||||
customCodes: []any{float64(429)},
|
||||
expectHandleError: 0,
|
||||
expectUpstream: 1,
|
||||
expectStatusCode: 500,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
saveAndSetBaseURLs(t)
|
||||
|
||||
upstream := &epFixedUpstream{statusCode: tt.upstreamStatus, body: tt.upstreamBody}
|
||||
repo := &epAccountRepo{}
|
||||
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
|
||||
account := &Account{
|
||||
ID: 100,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformAntigravity,
|
||||
Schedulable: true,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": tt.customCodes,
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
|
||||
|
||||
var handleErrorCount int
|
||||
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
||||
handleErrorCount++
|
||||
return nil
|
||||
})
|
||||
|
||||
result, err := svc.antigravityRetryLoop(p)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.resp)
|
||||
defer func() { _ = result.resp.Body.Close() }()
|
||||
|
||||
require.Equal(t, tt.expectStatusCode, result.resp.StatusCode)
|
||||
require.Equal(t, tt.expectHandleError, handleErrorCount, "handleError call count")
|
||||
require.Equal(t, tt.expectUpstream, upstream.calls, "upstream call count")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestRetryLoop_ErrorPolicy_TempUnschedulable
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestRetryLoop_ErrorPolicy_TempUnschedulable(t *testing.T) {
|
||||
tempRulesAccount := func(rules []any) *Account {
|
||||
return &Account{
|
||||
ID: 200,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Schedulable: true,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": rules,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
overloadedRule := map[string]any{
|
||||
"error_code": float64(503),
|
||||
"keywords": []any{"overloaded"},
|
||||
"duration_minutes": float64(10),
|
||||
}
|
||||
|
||||
rateLimitRule := map[string]any{
|
||||
"error_code": float64(429),
|
||||
"keywords": []any{"rate limited keyword"},
|
||||
"duration_minutes": float64(5),
|
||||
}
|
||||
|
||||
t.Run("503_overloaded_matches_rule", func(t *testing.T) {
|
||||
saveAndSetBaseURLs(t)
|
||||
|
||||
upstream := &epFixedUpstream{statusCode: 503, body: `overloaded`}
|
||||
repo := &epAccountRepo{}
|
||||
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
|
||||
|
||||
account := tempRulesAccount([]any{overloadedRule})
|
||||
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
||||
t.Error("handleError should not be called for temp unschedulable")
|
||||
return nil
|
||||
})
|
||||
|
||||
result, err := svc.antigravityRetryLoop(p)
|
||||
|
||||
require.Nil(t, result)
|
||||
var switchErr *AntigravityAccountSwitchError
|
||||
require.ErrorAs(t, err, &switchErr)
|
||||
require.Equal(t, account.ID, switchErr.OriginalAccountID)
|
||||
require.Equal(t, 1, upstream.calls, "should not retry")
|
||||
})
|
||||
|
||||
t.Run("429_rate_limited_keyword_matches_rule", func(t *testing.T) {
|
||||
saveAndSetBaseURLs(t)
|
||||
|
||||
upstream := &epFixedUpstream{statusCode: 429, body: `rate limited keyword`}
|
||||
repo := &epAccountRepo{}
|
||||
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
|
||||
|
||||
account := tempRulesAccount([]any{rateLimitRule})
|
||||
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
||||
t.Error("handleError should not be called for temp unschedulable")
|
||||
return nil
|
||||
})
|
||||
|
||||
result, err := svc.antigravityRetryLoop(p)
|
||||
|
||||
require.Nil(t, result)
|
||||
var switchErr *AntigravityAccountSwitchError
|
||||
require.ErrorAs(t, err, &switchErr)
|
||||
require.Equal(t, account.ID, switchErr.OriginalAccountID)
|
||||
require.Equal(t, 1, upstream.calls, "should not retry")
|
||||
})
|
||||
|
||||
t.Run("503_body_no_match_continues_default_retry", func(t *testing.T) {
|
||||
saveAndSetBaseURLs(t)
|
||||
|
||||
upstream := &epFixedUpstream{statusCode: 503, body: `random`}
|
||||
repo := &epAccountRepo{}
|
||||
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
|
||||
|
||||
account := tempRulesAccount([]any{overloadedRule})
|
||||
|
||||
// Use a short-lived context: the backoff sleep (~1s) will be
|
||||
// interrupted, proving the code entered the default retry path
|
||||
// instead of breaking early via error policy.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
})
|
||||
p.ctx = ctx
|
||||
|
||||
result, err := svc.antigravityRetryLoop(p)
|
||||
|
||||
// Context cancellation during backoff proves default retry was entered
|
||||
require.Nil(t, result)
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
require.GreaterOrEqual(t, upstream.calls, 1, "should have called upstream at least once")
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestRetryLoop_ErrorPolicy_NilRateLimitService
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestRetryLoop_ErrorPolicy_NilRateLimitService(t *testing.T) {
|
||||
saveAndSetBaseURLs(t)
|
||||
|
||||
upstream := &epFixedUpstream{statusCode: 429, body: `{"error":"rate limited"}`}
|
||||
// rateLimitService is nil — must not panic
|
||||
svc := &AntigravityGatewayService{rateLimitService: nil}
|
||||
|
||||
account := &Account{
|
||||
ID: 300,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Schedulable: true,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
})
|
||||
p.ctx = ctx
|
||||
|
||||
// Should not panic; enters the default retry path (eventually times out)
|
||||
result, err := svc.antigravityRetryLoop(p)
|
||||
|
||||
require.Nil(t, result)
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
require.GreaterOrEqual(t, upstream.calls, 1)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior(t *testing.T) {
|
||||
saveAndSetBaseURLs(t)
|
||||
|
||||
upstream := &epFixedUpstream{statusCode: 429, body: `{"error":"rate limited"}`}
|
||||
repo := &epAccountRepo{}
|
||||
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
|
||||
|
||||
// Plain OAuth account with no error policy configured
|
||||
account := &Account{
|
||||
ID: 400,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Schedulable: true,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
}
|
||||
|
||||
var handleErrorCount int
|
||||
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
||||
handleErrorCount++
|
||||
return nil
|
||||
})
|
||||
|
||||
result, err := svc.antigravityRetryLoop(p)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.resp)
|
||||
defer func() { _ = result.resp.Body.Close() }()
|
||||
|
||||
require.Equal(t, http.StatusTooManyRequests, result.resp.StatusCode)
|
||||
require.Equal(t, antigravityMaxRetries, upstream.calls, "should exhaust all retries")
|
||||
require.Equal(t, 1, handleErrorCount, "handleError should be called once after retries exhausted")
|
||||
}
|
||||
289
backend/internal/service/error_policy_test.go
Normal file
289
backend/internal/service/error_policy_test.go
Normal file
@@ -0,0 +1,289 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestCheckErrorPolicy — 6 table-driven cases for the pure logic function
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestCheckErrorPolicy(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
statusCode int
|
||||
body []byte
|
||||
expected ErrorPolicyResult
|
||||
}{
|
||||
{
|
||||
name: "no_policy_oauth_returns_none",
|
||||
account: &Account{
|
||||
ID: 1,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
// no custom error codes, no temp rules
|
||||
},
|
||||
statusCode: 500,
|
||||
body: []byte(`"error"`),
|
||||
expected: ErrorPolicyNone,
|
||||
},
|
||||
{
|
||||
name: "custom_error_codes_hit_returns_matched",
|
||||
account: &Account{
|
||||
ID: 2,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(429), float64(500)},
|
||||
},
|
||||
},
|
||||
statusCode: 500,
|
||||
body: []byte(`"error"`),
|
||||
expected: ErrorPolicyMatched,
|
||||
},
|
||||
{
|
||||
name: "custom_error_codes_miss_returns_skipped",
|
||||
account: &Account{
|
||||
ID: 3,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(429), float64(500)},
|
||||
},
|
||||
},
|
||||
statusCode: 503,
|
||||
body: []byte(`"error"`),
|
||||
expected: ErrorPolicySkipped,
|
||||
},
|
||||
{
|
||||
name: "temp_unschedulable_hit_returns_temp_unscheduled",
|
||||
account: &Account{
|
||||
ID: 4,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(503),
|
||||
"keywords": []any{"overloaded"},
|
||||
"duration_minutes": float64(10),
|
||||
"description": "overloaded rule",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
statusCode: 503,
|
||||
body: []byte(`overloaded service`),
|
||||
expected: ErrorPolicyTempUnscheduled,
|
||||
},
|
||||
{
|
||||
name: "temp_unschedulable_body_miss_returns_none",
|
||||
account: &Account{
|
||||
ID: 5,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(503),
|
||||
"keywords": []any{"overloaded"},
|
||||
"duration_minutes": float64(10),
|
||||
"description": "overloaded rule",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
statusCode: 503,
|
||||
body: []byte(`random msg`),
|
||||
expected: ErrorPolicyNone,
|
||||
},
|
||||
{
|
||||
name: "custom_error_codes_override_temp_unschedulable",
|
||||
account: &Account{
|
||||
ID: 6,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(503)},
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(503),
|
||||
"keywords": []any{"overloaded"},
|
||||
"duration_minutes": float64(10),
|
||||
"description": "overloaded rule",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
statusCode: 503,
|
||||
body: []byte(`overloaded`),
|
||||
expected: ErrorPolicyMatched, // custom codes take precedence
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := &errorPolicyRepoStub{}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
|
||||
result := svc.CheckErrorPolicy(context.Background(), tt.account, tt.statusCode, tt.body)
|
||||
require.Equal(t, tt.expected, result, "unexpected ErrorPolicyResult")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestApplyErrorPolicy — 4 table-driven cases for the wrapper method
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestApplyErrorPolicy(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
statusCode int
|
||||
body []byte
|
||||
expectedHandled bool
|
||||
expectedSwitchErr bool // expect *AntigravityAccountSwitchError
|
||||
handleErrorCalls int
|
||||
}{
|
||||
{
|
||||
name: "none_not_handled",
|
||||
account: &Account{
|
||||
ID: 10,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
},
|
||||
statusCode: 500,
|
||||
body: []byte(`"error"`),
|
||||
expectedHandled: false,
|
||||
handleErrorCalls: 0,
|
||||
},
|
||||
{
|
||||
name: "skipped_handled_no_handleError",
|
||||
account: &Account{
|
||||
ID: 11,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(429)},
|
||||
},
|
||||
},
|
||||
statusCode: 500, // not in custom codes
|
||||
body: []byte(`"error"`),
|
||||
expectedHandled: true,
|
||||
handleErrorCalls: 0,
|
||||
},
|
||||
{
|
||||
name: "matched_handled_calls_handleError",
|
||||
account: &Account{
|
||||
ID: 12,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(500)},
|
||||
},
|
||||
},
|
||||
statusCode: 500,
|
||||
body: []byte(`"error"`),
|
||||
expectedHandled: true,
|
||||
handleErrorCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "temp_unscheduled_returns_switch_error",
|
||||
account: &Account{
|
||||
ID: 13,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(503),
|
||||
"keywords": []any{"overloaded"},
|
||||
"duration_minutes": float64(10),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
statusCode: 503,
|
||||
body: []byte(`overloaded`),
|
||||
expectedHandled: true,
|
||||
expectedSwitchErr: true,
|
||||
handleErrorCalls: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := &errorPolicyRepoStub{}
|
||||
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
svc := &AntigravityGatewayService{
|
||||
rateLimitService: rlSvc,
|
||||
}
|
||||
|
||||
var handleErrorCount int
|
||||
p := antigravityRetryLoopParams{
|
||||
ctx: context.Background(),
|
||||
prefix: "[test]",
|
||||
account: tt.account,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleErrorCount++
|
||||
return nil
|
||||
},
|
||||
isStickySession: true,
|
||||
}
|
||||
|
||||
handled, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body)
|
||||
|
||||
require.Equal(t, tt.expectedHandled, handled, "handled mismatch")
|
||||
require.Equal(t, tt.handleErrorCalls, handleErrorCount, "handleError call count mismatch")
|
||||
|
||||
if tt.expectedSwitchErr {
|
||||
var switchErr *AntigravityAccountSwitchError
|
||||
require.ErrorAs(t, retErr, &switchErr)
|
||||
require.Equal(t, tt.account.ID, switchErr.OriginalAccountID)
|
||||
} else {
|
||||
require.NoError(t, retErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// errorPolicyRepoStub — minimal AccountRepository stub for error policy tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type errorPolicyRepoStub struct {
|
||||
mockAccountRepoForGemini
|
||||
tempCalls int
|
||||
setErrCalls int
|
||||
lastErrorMsg string
|
||||
}
|
||||
|
||||
func (r *errorPolicyRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||
r.tempCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *errorPolicyRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||
r.setErrCalls++
|
||||
r.lastErrorMsg = errorMsg
|
||||
return nil
|
||||
}
|
||||
@@ -142,9 +142,6 @@ func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx co
|
||||
func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
@@ -216,29 +213,6 @@ func (m *mockGatewayCacheForPlatform) DeleteSessionAccountID(ctx context.Context
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForPlatform) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForPlatform) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForPlatform) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForPlatform) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForPlatform) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForPlatform) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockGroupRepoForGateway struct {
|
||||
groups map[int64]*Group
|
||||
|
||||
@@ -6,9 +6,19 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
)
|
||||
|
||||
// SessionContext 粘性会话上下文,用于区分不同来源的请求。
|
||||
// 仅在 GenerateSessionHash 第 3 级 fallback(消息内容 hash)时混入,
|
||||
// 避免不同用户发送相同消息产生相同 hash 导致账号集中。
|
||||
type SessionContext struct {
|
||||
ClientIP string
|
||||
UserAgent string
|
||||
APIKeyID int64
|
||||
}
|
||||
|
||||
// ParsedRequest 保存网关请求的预解析结果
|
||||
//
|
||||
// 性能优化说明:
|
||||
@@ -22,20 +32,22 @@ import (
|
||||
// 2. 将解析结果 ParsedRequest 传递给 Service 层
|
||||
// 3. 避免重复 json.Unmarshal,减少 CPU 和内存开销
|
||||
type ParsedRequest struct {
|
||||
Body []byte // 原始请求体(保留用于转发)
|
||||
Model string // 请求的模型名称
|
||||
Stream bool // 是否为流式请求
|
||||
MetadataUserID string // metadata.user_id(用于会话亲和)
|
||||
System any // system 字段内容
|
||||
Messages []any // messages 数组
|
||||
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
|
||||
ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名)
|
||||
MaxTokens int // max_tokens 值(用于探测请求拦截)
|
||||
Body []byte // 原始请求体(保留用于转发)
|
||||
Model string // 请求的模型名称
|
||||
Stream bool // 是否为流式请求
|
||||
MetadataUserID string // metadata.user_id(用于会话亲和)
|
||||
System any // system 字段内容
|
||||
Messages []any // messages 数组
|
||||
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
|
||||
ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名)
|
||||
MaxTokens int // max_tokens 值(用于探测请求拦截)
|
||||
SessionContext *SessionContext // 可选:请求上下文区分因子(nil 时行为不变)
|
||||
}
|
||||
|
||||
// ParseGatewayRequest 解析网关请求体并返回结构化结果
|
||||
// 性能优化:一次解析提取所有需要的字段,避免重复 Unmarshal
|
||||
func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
|
||||
// ParseGatewayRequest 解析网关请求体并返回结构化结果。
|
||||
// protocol 指定请求协议格式(domain.PlatformAnthropic / domain.PlatformGemini),
|
||||
// 不同协议使用不同的 system/messages 字段名。
|
||||
func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return nil, err
|
||||
@@ -64,14 +76,29 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
|
||||
parsed.MetadataUserID = userID
|
||||
}
|
||||
}
|
||||
// system 字段只要存在就视为显式提供(即使为 null),
|
||||
// 以避免客户端传 null 时被默认 system 误注入。
|
||||
if system, ok := req["system"]; ok {
|
||||
parsed.HasSystem = true
|
||||
parsed.System = system
|
||||
}
|
||||
if messages, ok := req["messages"].([]any); ok {
|
||||
parsed.Messages = messages
|
||||
|
||||
switch protocol {
|
||||
case domain.PlatformGemini:
|
||||
// Gemini 原生格式: systemInstruction.parts / contents
|
||||
if sysInst, ok := req["systemInstruction"].(map[string]any); ok {
|
||||
if parts, ok := sysInst["parts"].([]any); ok {
|
||||
parsed.System = parts
|
||||
}
|
||||
}
|
||||
if contents, ok := req["contents"].([]any); ok {
|
||||
parsed.Messages = contents
|
||||
}
|
||||
default:
|
||||
// Anthropic / OpenAI 格式: system / messages
|
||||
// system 字段只要存在就视为显式提供(即使为 null),
|
||||
// 以避免客户端传 null 时被默认 system 误注入。
|
||||
if system, ok := req["system"]; ok {
|
||||
parsed.HasSystem = true
|
||||
parsed.System = system
|
||||
}
|
||||
if messages, ok := req["messages"].([]any); ok {
|
||||
parsed.Messages = messages
|
||||
}
|
||||
}
|
||||
|
||||
// thinking: {type: "enabled"}
|
||||
|
||||
@@ -4,12 +4,13 @@ import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseGatewayRequest(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-3-7-sonnet","stream":true,"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"system":[{"type":"text","text":"hello","cache_control":{"type":"ephemeral"}}],"messages":[{"content":"hi"}]}`)
|
||||
parsed, err := ParseGatewayRequest(body)
|
||||
parsed, err := ParseGatewayRequest(body, "")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "claude-3-7-sonnet", parsed.Model)
|
||||
require.True(t, parsed.Stream)
|
||||
@@ -22,7 +23,7 @@ func TestParseGatewayRequest(t *testing.T) {
|
||||
|
||||
func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-sonnet-4-5","thinking":{"type":"enabled"},"messages":[{"content":"hi"}]}`)
|
||||
parsed, err := ParseGatewayRequest(body)
|
||||
parsed, err := ParseGatewayRequest(body, "")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "claude-sonnet-4-5", parsed.Model)
|
||||
require.True(t, parsed.ThinkingEnabled)
|
||||
@@ -30,21 +31,21 @@ func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) {
|
||||
|
||||
func TestParseGatewayRequest_MaxTokens(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1}`)
|
||||
parsed, err := ParseGatewayRequest(body)
|
||||
parsed, err := ParseGatewayRequest(body, "")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, parsed.MaxTokens)
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_MaxTokensNonIntegralIgnored(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1.5}`)
|
||||
parsed, err := ParseGatewayRequest(body)
|
||||
parsed, err := ParseGatewayRequest(body, "")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, parsed.MaxTokens)
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_SystemNull(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-3","system":null}`)
|
||||
parsed, err := ParseGatewayRequest(body)
|
||||
parsed, err := ParseGatewayRequest(body, "")
|
||||
require.NoError(t, err)
|
||||
// 显式传入 system:null 也应视为“字段已存在”,避免默认 system 被注入。
|
||||
require.True(t, parsed.HasSystem)
|
||||
@@ -53,16 +54,112 @@ func TestParseGatewayRequest_SystemNull(t *testing.T) {
|
||||
|
||||
func TestParseGatewayRequest_InvalidModelType(t *testing.T) {
|
||||
body := []byte(`{"model":123}`)
|
||||
_, err := ParseGatewayRequest(body)
|
||||
_, err := ParseGatewayRequest(body, "")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_InvalidStreamType(t *testing.T) {
|
||||
body := []byte(`{"stream":"true"}`)
|
||||
_, err := ParseGatewayRequest(body)
|
||||
_, err := ParseGatewayRequest(body, "")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// ============ Gemini 原生格式解析测试 ============
|
||||
|
||||
func TestParseGatewayRequest_GeminiContents(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"contents": [
|
||||
{"role": "user", "parts": [{"text": "Hello"}]},
|
||||
{"role": "model", "parts": [{"text": "Hi there"}]},
|
||||
{"role": "user", "parts": [{"text": "How are you?"}]}
|
||||
]
|
||||
}`)
|
||||
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, parsed.Messages, 3, "should parse contents as Messages")
|
||||
require.False(t, parsed.HasSystem, "Gemini format should not set HasSystem")
|
||||
require.Nil(t, parsed.System, "no systemInstruction means nil System")
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_GeminiSystemInstruction(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"systemInstruction": {
|
||||
"parts": [{"text": "You are a helpful assistant."}]
|
||||
},
|
||||
"contents": [
|
||||
{"role": "user", "parts": [{"text": "Hello"}]}
|
||||
]
|
||||
}`)
|
||||
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, parsed.System, "should parse systemInstruction.parts as System")
|
||||
parts, ok := parsed.System.([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, parts, 1)
|
||||
partMap, ok := parts[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "You are a helpful assistant.", partMap["text"])
|
||||
require.Len(t, parsed.Messages, 1)
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_GeminiWithModel(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"model": "gemini-2.5-pro",
|
||||
"contents": [{"role": "user", "parts": [{"text": "test"}]}]
|
||||
}`)
|
||||
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "gemini-2.5-pro", parsed.Model)
|
||||
require.Len(t, parsed.Messages, 1)
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_GeminiIgnoresAnthropicFields(t *testing.T) {
|
||||
// Gemini 格式下 system/messages 字段应被忽略
|
||||
body := []byte(`{
|
||||
"system": "should be ignored",
|
||||
"messages": [{"role": "user", "content": "ignored"}],
|
||||
"contents": [{"role": "user", "parts": [{"text": "real content"}]}]
|
||||
}`)
|
||||
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
|
||||
require.NoError(t, err)
|
||||
require.False(t, parsed.HasSystem, "Gemini protocol should not parse Anthropic system field")
|
||||
require.Nil(t, parsed.System, "no systemInstruction = nil System")
|
||||
require.Len(t, parsed.Messages, 1, "should use contents, not messages")
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_GeminiEmptyContents(t *testing.T) {
|
||||
body := []byte(`{"contents": []}`)
|
||||
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, parsed.Messages)
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_GeminiNoContents(t *testing.T) {
|
||||
body := []byte(`{"model": "gemini-2.5-flash"}`)
|
||||
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, parsed.Messages)
|
||||
require.Equal(t, "gemini-2.5-flash", parsed.Model)
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_AnthropicIgnoresGeminiFields(t *testing.T) {
|
||||
// Anthropic 格式下 contents/systemInstruction 字段应被忽略
|
||||
body := []byte(`{
|
||||
"system": "real system",
|
||||
"messages": [{"role": "user", "content": "real content"}],
|
||||
"contents": [{"role": "user", "parts": [{"text": "ignored"}]}],
|
||||
"systemInstruction": {"parts": [{"text": "ignored"}]}
|
||||
}`)
|
||||
parsed, err := ParseGatewayRequest(body, domain.PlatformAnthropic)
|
||||
require.NoError(t, err)
|
||||
require.True(t, parsed.HasSystem)
|
||||
require.Equal(t, "real system", parsed.System)
|
||||
require.Len(t, parsed.Messages, 1)
|
||||
msg, ok := parsed.Messages[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "real content", msg["content"])
|
||||
}
|
||||
|
||||
func TestFilterThinkingBlocks(t *testing.T) {
|
||||
containsThinkingBlock := func(body []byte) bool {
|
||||
var req map[string]any
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -17,6 +16,7 @@ import (
|
||||
"os"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
"github.com/cespare/xxhash/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -245,9 +246,6 @@ var (
|
||||
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
|
||||
var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients")
|
||||
|
||||
// ErrModelScopeNotSupported 表示请求的模型系列不在分组支持的范围内
|
||||
var ErrModelScopeNotSupported = errors.New("model scope not supported by this group")
|
||||
|
||||
// allowedHeaders 白名单headers(参考CRS项目)
|
||||
var allowedHeaders = map[string]bool{
|
||||
"accept": true,
|
||||
@@ -273,13 +271,6 @@ var allowedHeaders = map[string]bool{
|
||||
// GatewayCache 定义网关服务的缓存操作接口。
|
||||
// 提供粘性会话(Sticky Session)的存储、查询、刷新和删除功能。
|
||||
//
|
||||
// ModelLoadInfo 模型负载信息(用于 Antigravity 调度)
|
||||
// Model load info for Antigravity scheduling
|
||||
type ModelLoadInfo struct {
|
||||
CallCount int64 // 当前分钟调用次数 / Call count in current minute
|
||||
LastUsedAt time.Time // 最后调度时间(零值表示未调度过)/ Last scheduling time (zero means never scheduled)
|
||||
}
|
||||
|
||||
// GatewayCache defines cache operations for gateway service.
|
||||
// Provides sticky session storage, retrieval, refresh and deletion capabilities.
|
||||
type GatewayCache interface {
|
||||
@@ -295,32 +286,6 @@ type GatewayCache interface {
|
||||
// DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理
|
||||
// Delete sticky session binding, used to proactively clean up when account becomes unavailable
|
||||
DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error
|
||||
|
||||
// IncrModelCallCount 增加模型调用次数并更新最后调度时间(Antigravity 专用)
|
||||
// Increment model call count and update last scheduling time (Antigravity only)
|
||||
// 返回更新后的调用次数
|
||||
IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error)
|
||||
|
||||
// GetModelLoadBatch 批量获取账号的模型负载信息(Antigravity 专用)
|
||||
// Batch get model load info for accounts (Antigravity only)
|
||||
GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error)
|
||||
|
||||
// FindGeminiSession 查找 Gemini 会话(MGET 倒序匹配)
|
||||
// Find Gemini session using MGET reverse order matching
|
||||
// 返回最长匹配的会话信息(uuid, accountID)
|
||||
FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool)
|
||||
|
||||
// SaveGeminiSession 保存 Gemini 会话
|
||||
// Save Gemini session binding
|
||||
SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error
|
||||
|
||||
// FindAnthropicSession 查找 Anthropic 会话(Trie 匹配)
|
||||
// Find Anthropic session using Trie matching
|
||||
FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool)
|
||||
|
||||
// SaveAnthropicSession 保存 Anthropic 会话
|
||||
// Save Anthropic session binding
|
||||
SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error
|
||||
}
|
||||
|
||||
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
|
||||
@@ -415,6 +380,7 @@ type GatewayService struct {
|
||||
userSubRepo UserSubscriptionRepository
|
||||
userGroupRateRepo UserGroupRateRepository
|
||||
cache GatewayCache
|
||||
digestStore *DigestSessionStore
|
||||
cfg *config.Config
|
||||
schedulerSnapshot *SchedulerSnapshotService
|
||||
billingService *BillingService
|
||||
@@ -448,6 +414,7 @@ func NewGatewayService(
|
||||
deferredService *DeferredService,
|
||||
claudeTokenProvider *ClaudeTokenProvider,
|
||||
sessionLimitCache SessionLimitCache,
|
||||
digestStore *DigestSessionStore,
|
||||
) *GatewayService {
|
||||
return &GatewayService{
|
||||
accountRepo: accountRepo,
|
||||
@@ -457,6 +424,7 @@ func NewGatewayService(
|
||||
userSubRepo: userSubRepo,
|
||||
userGroupRateRepo: userGroupRateRepo,
|
||||
cache: cache,
|
||||
digestStore: digestStore,
|
||||
cfg: cfg,
|
||||
schedulerSnapshot: schedulerSnapshot,
|
||||
concurrencyService: concurrencyService,
|
||||
@@ -490,8 +458,17 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
|
||||
return s.hashContent(cacheableContent)
|
||||
}
|
||||
|
||||
// 3. 最后 fallback: 使用 system + 所有消息的完整摘要串
|
||||
// 3. 最后 fallback: 使用 session上下文 + system + 所有消息的完整摘要串
|
||||
var combined strings.Builder
|
||||
// 混入请求上下文区分因子,避免不同用户相同消息产生相同 hash
|
||||
if parsed.SessionContext != nil {
|
||||
_, _ = combined.WriteString(parsed.SessionContext.ClientIP)
|
||||
_, _ = combined.WriteString(":")
|
||||
_, _ = combined.WriteString(parsed.SessionContext.UserAgent)
|
||||
_, _ = combined.WriteString(":")
|
||||
_, _ = combined.WriteString(strconv.FormatInt(parsed.SessionContext.APIKeyID, 10))
|
||||
_, _ = combined.WriteString("|")
|
||||
}
|
||||
if parsed.System != nil {
|
||||
systemText := s.extractTextFromSystem(parsed.System)
|
||||
if systemText != "" {
|
||||
@@ -500,9 +477,20 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
|
||||
}
|
||||
for _, msg := range parsed.Messages {
|
||||
if m, ok := msg.(map[string]any); ok {
|
||||
msgText := s.extractTextFromContent(m["content"])
|
||||
if msgText != "" {
|
||||
_, _ = combined.WriteString(msgText)
|
||||
if content, exists := m["content"]; exists {
|
||||
// Anthropic: messages[].content
|
||||
if msgText := s.extractTextFromContent(content); msgText != "" {
|
||||
_, _ = combined.WriteString(msgText)
|
||||
}
|
||||
} else if parts, ok := m["parts"].([]any); ok {
|
||||
// Gemini: contents[].parts[].text
|
||||
for _, part := range parts {
|
||||
if partMap, ok := part.(map[string]any); ok {
|
||||
if text, ok := partMap["text"].(string); ok {
|
||||
_, _ = combined.WriteString(text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -536,35 +524,37 @@ func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID
|
||||
|
||||
// FindGeminiSession 查找 Gemini 会话(基于内容摘要链的 Fallback 匹配)
|
||||
// 返回最长匹配的会话信息(uuid, accountID)
|
||||
func (s *GatewayService) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
||||
if digestChain == "" || s.cache == nil {
|
||||
return "", 0, false
|
||||
func (s *GatewayService) FindGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) {
|
||||
if digestChain == "" || s.digestStore == nil {
|
||||
return "", 0, "", false
|
||||
}
|
||||
return s.cache.FindGeminiSession(ctx, groupID, prefixHash, digestChain)
|
||||
return s.digestStore.Find(groupID, prefixHash, digestChain)
|
||||
}
|
||||
|
||||
// SaveGeminiSession 保存 Gemini 会话
|
||||
func (s *GatewayService) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
||||
if digestChain == "" || s.cache == nil {
|
||||
// SaveGeminiSession 保存 Gemini 会话。oldDigestChain 为 Find 返回的 matchedChain,用于删旧 key。
|
||||
func (s *GatewayService) SaveGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error {
|
||||
if digestChain == "" || s.digestStore == nil {
|
||||
return nil
|
||||
}
|
||||
return s.cache.SaveGeminiSession(ctx, groupID, prefixHash, digestChain, uuid, accountID)
|
||||
s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain)
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindAnthropicSession 查找 Anthropic 会话(基于内容摘要链的 Fallback 匹配)
|
||||
func (s *GatewayService) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
||||
if digestChain == "" || s.cache == nil {
|
||||
return "", 0, false
|
||||
func (s *GatewayService) FindAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) {
|
||||
if digestChain == "" || s.digestStore == nil {
|
||||
return "", 0, "", false
|
||||
}
|
||||
return s.cache.FindAnthropicSession(ctx, groupID, prefixHash, digestChain)
|
||||
return s.digestStore.Find(groupID, prefixHash, digestChain)
|
||||
}
|
||||
|
||||
// SaveAnthropicSession 保存 Anthropic 会话
|
||||
func (s *GatewayService) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
||||
if digestChain == "" || s.cache == nil {
|
||||
func (s *GatewayService) SaveAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error {
|
||||
if digestChain == "" || s.digestStore == nil {
|
||||
return nil
|
||||
}
|
||||
return s.cache.SaveAnthropicSession(ctx, groupID, prefixHash, digestChain, uuid, accountID)
|
||||
s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
|
||||
@@ -649,8 +639,8 @@ func (s *GatewayService) extractTextFromContent(content any) string {
|
||||
}
|
||||
|
||||
func (s *GatewayService) hashContent(content string) string {
|
||||
hash := sha256.Sum256([]byte(content))
|
||||
return hex.EncodeToString(hash[:16]) // 32字符
|
||||
h := xxhash.Sum64String(content)
|
||||
return strconv.FormatUint(h, 36)
|
||||
}
|
||||
|
||||
// replaceModelInBody 替换请求体中的model字段
|
||||
@@ -1009,13 +999,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
log.Printf("[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform)
|
||||
}
|
||||
|
||||
// Antigravity 模型系列检查(在账号选择前检查,确保所有代码路径都经过此检查)
|
||||
if platform == PlatformAntigravity && groupID != nil && requestedModel != "" {
|
||||
if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1209,6 +1192,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
|
||||
}
|
||||
})
|
||||
shuffleWithinSortGroups(routingAvailable)
|
||||
|
||||
// 4. 尝试获取槽位
|
||||
for _, item := range routingAvailable {
|
||||
@@ -1362,10 +1346,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
return result, nil
|
||||
}
|
||||
} else {
|
||||
// Antigravity 平台:获取模型负载信息
|
||||
var modelLoadMap map[int64]*ModelLoadInfo
|
||||
isAntigravity := platform == PlatformAntigravity
|
||||
|
||||
var available []accountWithLoad
|
||||
for _, acc := range candidates {
|
||||
loadInfo := loadMap[acc.ID]
|
||||
@@ -1380,109 +1360,44 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
}
|
||||
|
||||
// Antigravity 平台:按账号实际映射后的模型名获取模型负载(与 Forward 的统计保持一致)
|
||||
if isAntigravity && requestedModel != "" && s.cache != nil && len(available) > 0 {
|
||||
modelLoadMap = make(map[int64]*ModelLoadInfo, len(available))
|
||||
modelToAccountIDs := make(map[string][]int64)
|
||||
for _, item := range available {
|
||||
mappedModel := mapAntigravityModel(item.account, requestedModel)
|
||||
if mappedModel == "" {
|
||||
continue
|
||||
}
|
||||
modelToAccountIDs[mappedModel] = append(modelToAccountIDs[mappedModel], item.account.ID)
|
||||
// 分层过滤选择:优先级 → 负载率 → LRU
|
||||
for len(available) > 0 {
|
||||
// 1. 取优先级最小的集合
|
||||
candidates := filterByMinPriority(available)
|
||||
// 2. 取负载率最低的集合
|
||||
candidates = filterByMinLoadRate(candidates)
|
||||
// 3. LRU 选择最久未用的账号
|
||||
selected := selectByLRU(candidates, preferOAuth)
|
||||
if selected == nil {
|
||||
break
|
||||
}
|
||||
for model, ids := range modelToAccountIDs {
|
||||
batch, err := s.cache.GetModelLoadBatch(ctx, ids, model)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for id, info := range batch {
|
||||
modelLoadMap[id] = info
|
||||
}
|
||||
}
|
||||
if len(modelLoadMap) == 0 {
|
||||
modelLoadMap = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Antigravity 平台:优先级硬过滤 →(同优先级内)按调用次数选择(最少优先,新账号用平均值)
|
||||
// 其他平台:分层过滤选择:优先级 → 负载率 → LRU
|
||||
if isAntigravity {
|
||||
for len(available) > 0 {
|
||||
// 1. 取优先级最小的集合(硬过滤)
|
||||
candidates := filterByMinPriority(available)
|
||||
// 2. 同优先级内按调用次数选择(调用次数最少优先,新账号使用平均值)
|
||||
selected := selectByCallCount(candidates, modelLoadMap, preferOAuth)
|
||||
if selected == nil {
|
||||
break
|
||||
}
|
||||
|
||||
result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
|
||||
} else {
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: selected.account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
|
||||
} else {
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: selected.account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 移除已尝试的账号,重新选择
|
||||
selectedID := selected.account.ID
|
||||
newAvailable := make([]accountWithLoad, 0, len(available)-1)
|
||||
for _, acc := range available {
|
||||
if acc.account.ID != selectedID {
|
||||
newAvailable = append(newAvailable, acc)
|
||||
}
|
||||
}
|
||||
available = newAvailable
|
||||
}
|
||||
} else {
|
||||
for len(available) > 0 {
|
||||
// 1. 取优先级最小的集合
|
||||
candidates := filterByMinPriority(available)
|
||||
// 2. 取负载率最低的集合
|
||||
candidates = filterByMinLoadRate(candidates)
|
||||
// 3. LRU 选择最久未用的账号
|
||||
selected := selectByLRU(candidates, preferOAuth)
|
||||
if selected == nil {
|
||||
break
|
||||
}
|
||||
|
||||
result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
|
||||
} else {
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: selected.account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
// 移除已尝试的账号,重新进行分层过滤
|
||||
selectedID := selected.account.ID
|
||||
newAvailable := make([]accountWithLoad, 0, len(available)-1)
|
||||
for _, acc := range available {
|
||||
if acc.account.ID != selectedID {
|
||||
newAvailable = append(newAvailable, acc)
|
||||
}
|
||||
|
||||
// 移除已尝试的账号,重新进行分层过滤
|
||||
selectedID := selected.account.ID
|
||||
newAvailable := make([]accountWithLoad, 0, len(available)-1)
|
||||
for _, acc := range available {
|
||||
if acc.account.ID != selectedID {
|
||||
newAvailable = append(newAvailable, acc)
|
||||
}
|
||||
}
|
||||
available = newAvailable
|
||||
}
|
||||
available = newAvailable
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2018,87 +1933,79 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
|
||||
return a.LastUsedAt.Before(*b.LastUsedAt)
|
||||
}
|
||||
})
|
||||
shuffleWithinPriorityAndLastUsed(accounts)
|
||||
}
|
||||
|
||||
// selectByCallCount 从候选账号中选择调用次数最少的账号(Antigravity 专用)
|
||||
// 新账号(CallCount=0)使用平均调用次数作为虚拟值,避免冷启动被猛调
|
||||
// 如果有多个账号具有相同的最小调用次数,则随机选择一个
|
||||
func selectByCallCount(accounts []accountWithLoad, modelLoadMap map[int64]*ModelLoadInfo, preferOAuth bool) *accountWithLoad {
|
||||
if len(accounts) == 0 {
|
||||
return nil
|
||||
// shuffleWithinSortGroups 对排序后的 accountWithLoad 切片,按 (Priority, LoadRate, LastUsedAt) 分组后组内随机打乱。
|
||||
// 防止并发请求读取同一快照时,确定性排序导致所有请求命中相同账号。
|
||||
func shuffleWithinSortGroups(accounts []accountWithLoad) {
|
||||
if len(accounts) <= 1 {
|
||||
return
|
||||
}
|
||||
if len(accounts) == 1 {
|
||||
return &accounts[0]
|
||||
}
|
||||
|
||||
// 如果没有负载信息,回退到 LRU
|
||||
if modelLoadMap == nil {
|
||||
return selectByLRU(accounts, preferOAuth)
|
||||
}
|
||||
|
||||
// 1. 计算平均调用次数(用于新账号冷启动)
|
||||
var totalCallCount int64
|
||||
var countWithCalls int
|
||||
for _, acc := range accounts {
|
||||
if info := modelLoadMap[acc.account.ID]; info != nil && info.CallCount > 0 {
|
||||
totalCallCount += info.CallCount
|
||||
countWithCalls++
|
||||
i := 0
|
||||
for i < len(accounts) {
|
||||
j := i + 1
|
||||
for j < len(accounts) && sameAccountWithLoadGroup(accounts[i], accounts[j]) {
|
||||
j++
|
||||
}
|
||||
}
|
||||
|
||||
var avgCallCount int64
|
||||
if countWithCalls > 0 {
|
||||
avgCallCount = totalCallCount / int64(countWithCalls)
|
||||
}
|
||||
|
||||
// 2. 获取每个账号的有效调用次数
|
||||
getEffectiveCallCount := func(acc accountWithLoad) int64 {
|
||||
if acc.account == nil {
|
||||
return 0
|
||||
if j-i > 1 {
|
||||
mathrand.Shuffle(j-i, func(a, b int) {
|
||||
accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a]
|
||||
})
|
||||
}
|
||||
info := modelLoadMap[acc.account.ID]
|
||||
if info == nil || info.CallCount == 0 {
|
||||
return avgCallCount // 新账号使用平均值
|
||||
}
|
||||
return info.CallCount
|
||||
i = j
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 找到最小调用次数
|
||||
minCount := getEffectiveCallCount(accounts[0])
|
||||
for _, acc := range accounts[1:] {
|
||||
if c := getEffectiveCallCount(acc); c < minCount {
|
||||
minCount = c
|
||||
}
|
||||
// sameAccountWithLoadGroup 判断两个 accountWithLoad 是否属于同一排序组
|
||||
func sameAccountWithLoadGroup(a, b accountWithLoad) bool {
|
||||
if a.account.Priority != b.account.Priority {
|
||||
return false
|
||||
}
|
||||
|
||||
// 4. 收集所有具有最小调用次数的账号
|
||||
var candidateIdxs []int
|
||||
for i, acc := range accounts {
|
||||
if getEffectiveCallCount(acc) == minCount {
|
||||
candidateIdxs = append(candidateIdxs, i)
|
||||
}
|
||||
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
|
||||
return false
|
||||
}
|
||||
return sameLastUsedAt(a.account.LastUsedAt, b.account.LastUsedAt)
|
||||
}
|
||||
|
||||
// 5. 如果只有一个候选,直接返回
|
||||
if len(candidateIdxs) == 1 {
|
||||
return &accounts[candidateIdxs[0]]
|
||||
// shuffleWithinPriorityAndLastUsed 对排序后的 []*Account 切片,按 (Priority, LastUsedAt) 分组后组内随机打乱。
|
||||
func shuffleWithinPriorityAndLastUsed(accounts []*Account) {
|
||||
if len(accounts) <= 1 {
|
||||
return
|
||||
}
|
||||
|
||||
// 6. preferOAuth 处理
|
||||
if preferOAuth {
|
||||
var oauthIdxs []int
|
||||
for _, idx := range candidateIdxs {
|
||||
if accounts[idx].account.Type == AccountTypeOAuth {
|
||||
oauthIdxs = append(oauthIdxs, idx)
|
||||
}
|
||||
i := 0
|
||||
for i < len(accounts) {
|
||||
j := i + 1
|
||||
for j < len(accounts) && sameAccountGroup(accounts[i], accounts[j]) {
|
||||
j++
|
||||
}
|
||||
if len(oauthIdxs) > 0 {
|
||||
candidateIdxs = oauthIdxs
|
||||
if j-i > 1 {
|
||||
mathrand.Shuffle(j-i, func(a, b int) {
|
||||
accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a]
|
||||
})
|
||||
}
|
||||
i = j
|
||||
}
|
||||
}
|
||||
|
||||
// 7. 随机选择
|
||||
return &accounts[candidateIdxs[mathrand.Intn(len(candidateIdxs))]]
|
||||
// sameAccountGroup 判断两个 Account 是否属于同一排序组(Priority + LastUsedAt)
|
||||
func sameAccountGroup(a, b *Account) bool {
|
||||
if a.Priority != b.Priority {
|
||||
return false
|
||||
}
|
||||
return sameLastUsedAt(a.LastUsedAt, b.LastUsedAt)
|
||||
}
|
||||
|
||||
// sameLastUsedAt 判断两个 LastUsedAt 是否相同(精度到秒)
|
||||
func sameLastUsedAt(a, b *time.Time) bool {
|
||||
switch {
|
||||
case a == nil && b == nil:
|
||||
return true
|
||||
case a == nil || b == nil:
|
||||
return false
|
||||
default:
|
||||
return a.Unix() == b.Unix()
|
||||
}
|
||||
}
|
||||
|
||||
// sortCandidatesForFallback 根据配置选择排序策略
|
||||
@@ -2153,13 +2060,6 @@ func shuffleWithinPriority(accounts []*Account) {
|
||||
|
||||
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
|
||||
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
|
||||
// 对 Antigravity 平台,检查请求的模型系列是否在分组支持范围内
|
||||
if platform == PlatformAntigravity && groupID != nil && requestedModel != "" {
|
||||
if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
preferOAuth := platform == PlatformGemini
|
||||
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
|
||||
|
||||
@@ -5171,27 +5071,6 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
// checkAntigravityModelScope 检查 Antigravity 平台的模型系列是否在分组支持范围内
|
||||
func (s *GatewayService) checkAntigravityModelScope(ctx context.Context, groupID int64, requestedModel string) error {
|
||||
scope, ok := ResolveAntigravityQuotaScope(requestedModel)
|
||||
if !ok {
|
||||
return nil // 无法解析 scope,跳过检查
|
||||
}
|
||||
|
||||
group, err := s.resolveGroupByID(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil // 查询失败时放行
|
||||
}
|
||||
if group == nil {
|
||||
return nil // 分组不存在时放行
|
||||
}
|
||||
|
||||
if !IsScopeSupported(group.SupportedModelScopes, scope) {
|
||||
return ErrModelScopeNotSupported
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAvailableModels returns the list of models available for a group
|
||||
// It aggregates model_mapping keys from all schedulable accounts in the group
|
||||
func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string {
|
||||
|
||||
@@ -14,7 +14,7 @@ func BenchmarkGenerateSessionHash_Metadata(b *testing.B) {
|
||||
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
parsed, err := ParseGatewayRequest(body)
|
||||
parsed, err := ParseGatewayRequest(body, "")
|
||||
if err != nil {
|
||||
b.Fatalf("解析请求失败: %v", err)
|
||||
}
|
||||
|
||||
384
backend/internal/service/gemini_error_policy_test.go
Normal file
384
backend/internal/service/gemini_error_policy_test.go
Normal file
@@ -0,0 +1,384 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestShouldFailoverGeminiUpstreamError — verifies the failover decision
|
||||
// for the ErrorPolicyNone path (original logic preserved).
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestShouldFailoverGeminiUpstreamError(t *testing.T) {
|
||||
svc := &GeminiMessagesCompatService{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
expected bool
|
||||
}{
|
||||
{"401_failover", 401, true},
|
||||
{"403_failover", 403, true},
|
||||
{"429_failover", 429, true},
|
||||
{"529_failover", 529, true},
|
||||
{"500_failover", 500, true},
|
||||
{"502_failover", 502, true},
|
||||
{"503_failover", 503, true},
|
||||
{"400_no_failover", 400, false},
|
||||
{"404_no_failover", 404, false},
|
||||
{"422_no_failover", 422, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := svc.shouldFailoverGeminiUpstreamError(tt.statusCode)
|
||||
require.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestCheckErrorPolicy_GeminiAccounts — verifies CheckErrorPolicy works
|
||||
// correctly for Gemini platform accounts (API Key type).
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestCheckErrorPolicy_GeminiAccounts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
statusCode int
|
||||
body []byte
|
||||
expected ErrorPolicyResult
|
||||
}{
|
||||
{
|
||||
name: "gemini_apikey_custom_codes_hit",
|
||||
account: &Account{
|
||||
ID: 100,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformGemini,
|
||||
Credentials: map[string]any{
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(429), float64(500)},
|
||||
},
|
||||
},
|
||||
statusCode: 429,
|
||||
body: []byte(`{"error":"rate limited"}`),
|
||||
expected: ErrorPolicyMatched,
|
||||
},
|
||||
{
|
||||
name: "gemini_apikey_custom_codes_miss",
|
||||
account: &Account{
|
||||
ID: 101,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformGemini,
|
||||
Credentials: map[string]any{
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(429)},
|
||||
},
|
||||
},
|
||||
statusCode: 500,
|
||||
body: []byte(`{"error":"internal"}`),
|
||||
expected: ErrorPolicySkipped,
|
||||
},
|
||||
{
|
||||
name: "gemini_apikey_no_custom_codes_returns_none",
|
||||
account: &Account{
|
||||
ID: 102,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformGemini,
|
||||
},
|
||||
statusCode: 500,
|
||||
body: []byte(`{"error":"internal"}`),
|
||||
expected: ErrorPolicyNone,
|
||||
},
|
||||
{
|
||||
name: "gemini_apikey_temp_unschedulable_hit",
|
||||
account: &Account{
|
||||
ID: 103,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformGemini,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(503),
|
||||
"keywords": []any{"overloaded"},
|
||||
"duration_minutes": float64(10),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
statusCode: 503,
|
||||
body: []byte(`overloaded service`),
|
||||
expected: ErrorPolicyTempUnscheduled,
|
||||
},
|
||||
{
|
||||
name: "gemini_custom_codes_override_temp_unschedulable",
|
||||
account: &Account{
|
||||
ID: 104,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformGemini,
|
||||
Credentials: map[string]any{
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(503)},
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(503),
|
||||
"keywords": []any{"overloaded"},
|
||||
"duration_minutes": float64(10),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
statusCode: 503,
|
||||
body: []byte(`overloaded`),
|
||||
expected: ErrorPolicyMatched, // custom codes take precedence
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := &errorPolicyRepoStub{}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
|
||||
result := svc.CheckErrorPolicy(context.Background(), tt.account, tt.statusCode, tt.body)
|
||||
require.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestGeminiErrorPolicyIntegration — verifies the Gemini error handling
|
||||
// paths produce the correct behavior for each ErrorPolicyResult.
|
||||
//
|
||||
// These tests simulate the inline error policy switch in handleClaudeCompat
|
||||
// and forwardNativeGemini by calling the same methods in the same order.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGeminiErrorPolicyIntegration(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
statusCode int
|
||||
respBody []byte
|
||||
expectFailover bool // expect UpstreamFailoverError
|
||||
expectHandleError bool // expect handleGeminiUpstreamError to be called
|
||||
expectShouldFailover bool // for None path, whether shouldFailover triggers
|
||||
}{
|
||||
{
|
||||
name: "custom_codes_matched_429_failover",
|
||||
account: &Account{
|
||||
ID: 200,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformGemini,
|
||||
Credentials: map[string]any{
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(429)},
|
||||
},
|
||||
},
|
||||
statusCode: 429,
|
||||
respBody: []byte(`{"error":"rate limited"}`),
|
||||
expectFailover: true,
|
||||
expectHandleError: true,
|
||||
},
|
||||
{
|
||||
name: "custom_codes_skipped_500_no_failover",
|
||||
account: &Account{
|
||||
ID: 201,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformGemini,
|
||||
Credentials: map[string]any{
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(429)},
|
||||
},
|
||||
},
|
||||
statusCode: 500,
|
||||
respBody: []byte(`{"error":"internal"}`),
|
||||
expectFailover: false,
|
||||
expectHandleError: false,
|
||||
},
|
||||
{
|
||||
name: "temp_unschedulable_matched_failover",
|
||||
account: &Account{
|
||||
ID: 202,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformGemini,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(503),
|
||||
"keywords": []any{"overloaded"},
|
||||
"duration_minutes": float64(10),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
statusCode: 503,
|
||||
respBody: []byte(`overloaded`),
|
||||
expectFailover: true,
|
||||
expectHandleError: true,
|
||||
},
|
||||
{
|
||||
name: "no_policy_429_failover_via_shouldFailover",
|
||||
account: &Account{
|
||||
ID: 203,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformGemini,
|
||||
},
|
||||
statusCode: 429,
|
||||
respBody: []byte(`{"error":"rate limited"}`),
|
||||
expectFailover: true,
|
||||
expectHandleError: true,
|
||||
expectShouldFailover: true,
|
||||
},
|
||||
{
|
||||
name: "no_policy_400_no_failover",
|
||||
account: &Account{
|
||||
ID: 204,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformGemini,
|
||||
},
|
||||
statusCode: 400,
|
||||
respBody: []byte(`{"error":"bad request"}`),
|
||||
expectFailover: false,
|
||||
expectHandleError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := &geminiErrorPolicyRepo{}
|
||||
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
rateLimitService: rlSvc,
|
||||
}
|
||||
|
||||
writer := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(writer)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
// Simulate the Claude compat error handling path (same logic as native).
|
||||
// This mirrors the inline switch in handleClaudeCompat.
|
||||
var handleErrorCalled bool
|
||||
var gotFailover bool
|
||||
|
||||
ctx := context.Background()
|
||||
statusCode := tt.statusCode
|
||||
respBody := tt.respBody
|
||||
account := tt.account
|
||||
headers := http.Header{}
|
||||
|
||||
if svc.rateLimitService != nil {
|
||||
switch svc.rateLimitService.CheckErrorPolicy(ctx, account, statusCode, respBody) {
|
||||
case ErrorPolicySkipped:
|
||||
// Skipped → return error directly (no handleGeminiUpstreamError, no failover)
|
||||
gotFailover = false
|
||||
handleErrorCalled = false
|
||||
goto verify
|
||||
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
|
||||
svc.handleGeminiUpstreamError(ctx, account, statusCode, headers, respBody)
|
||||
handleErrorCalled = true
|
||||
gotFailover = true
|
||||
goto verify
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorPolicyNone → original logic
|
||||
svc.handleGeminiUpstreamError(ctx, account, statusCode, headers, respBody)
|
||||
handleErrorCalled = true
|
||||
if svc.shouldFailoverGeminiUpstreamError(statusCode) {
|
||||
gotFailover = true
|
||||
}
|
||||
|
||||
verify:
|
||||
require.Equal(t, tt.expectFailover, gotFailover, "failover mismatch")
|
||||
require.Equal(t, tt.expectHandleError, handleErrorCalled, "handleGeminiUpstreamError call mismatch")
|
||||
|
||||
if tt.expectShouldFailover {
|
||||
require.True(t, svc.shouldFailoverGeminiUpstreamError(statusCode),
|
||||
"shouldFailoverGeminiUpstreamError should return true for status %d", statusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestGeminiErrorPolicy_NilRateLimitService — verifies nil safety
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGeminiErrorPolicy_NilRateLimitService(t *testing.T) {
|
||||
svc := &GeminiMessagesCompatService{
|
||||
rateLimitService: nil,
|
||||
}
|
||||
|
||||
// When rateLimitService is nil, error policy is skipped → falls through to
|
||||
// shouldFailoverGeminiUpstreamError (original logic).
|
||||
// Verify this doesn't panic and follows expected behavior.
|
||||
|
||||
ctx := context.Background()
|
||||
account := &Account{
|
||||
ID: 300,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformGemini,
|
||||
Credentials: map[string]any{
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(429)},
|
||||
},
|
||||
}
|
||||
|
||||
// The nil check should prevent CheckErrorPolicy from being called
|
||||
if svc.rateLimitService != nil {
|
||||
t.Fatal("rateLimitService should be nil for this test")
|
||||
}
|
||||
|
||||
// shouldFailoverGeminiUpstreamError still works
|
||||
require.True(t, svc.shouldFailoverGeminiUpstreamError(429))
|
||||
require.False(t, svc.shouldFailoverGeminiUpstreamError(400))
|
||||
|
||||
// handleGeminiUpstreamError should not panic with nil rateLimitService
|
||||
require.NotPanics(t, func() {
|
||||
svc.handleGeminiUpstreamError(ctx, account, 500, http.Header{}, []byte(`error`))
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// geminiErrorPolicyRepo — minimal AccountRepository stub for Gemini error
|
||||
// policy tests. Embeds mockAccountRepoForGemini and adds tracking.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type geminiErrorPolicyRepo struct {
|
||||
mockAccountRepoForGemini
|
||||
setErrorCalls int
|
||||
setRateLimitedCalls int
|
||||
setTempCalls int
|
||||
}
|
||||
|
||||
func (r *geminiErrorPolicyRepo) SetError(_ context.Context, _ int64, _ string) error {
|
||||
r.setErrorCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *geminiErrorPolicyRepo) SetRateLimited(_ context.Context, _ int64, _ time.Time) error {
|
||||
r.setRateLimitedCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *geminiErrorPolicyRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error {
|
||||
r.setTempCalls++
|
||||
return nil
|
||||
}
|
||||
@@ -831,38 +831,47 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
tempMatched := false
|
||||
// 统一错误策略:自定义错误码 + 临时不可调度
|
||||
if s.rateLimitService != nil {
|
||||
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
|
||||
}
|
||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
if tempMatched {
|
||||
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||
if upstreamReqID == "" {
|
||||
upstreamReqID = resp.Header.Get("x-goog-request-id")
|
||||
}
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
switch s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, respBody) {
|
||||
case ErrorPolicySkipped:
|
||||
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||
if upstreamReqID == "" {
|
||||
upstreamReqID = resp.Header.Get("x-goog-request-id")
|
||||
}
|
||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||
return nil, s.writeGeminiMappedError(c, account, resp.StatusCode, upstreamReqID, respBody)
|
||||
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
|
||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||
if upstreamReqID == "" {
|
||||
upstreamReqID = resp.Header.Get("x-goog-request-id")
|
||||
}
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: upstreamReqID,
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: upstreamReqID,
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
|
||||
// ErrorPolicyNone → 原有逻辑
|
||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||||
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||
if upstreamReqID == "" {
|
||||
@@ -1249,14 +1258,9 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
tempMatched := false
|
||||
if s.rateLimitService != nil {
|
||||
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
|
||||
}
|
||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
|
||||
// Best-effort fallback for OAuth tokens missing AI Studio scopes when calling countTokens.
|
||||
// This avoids Gemini SDKs failing hard during preflight token counting.
|
||||
// Checked before error policy so it always works regardless of custom error codes.
|
||||
if action == "countTokens" && isOAuth && isGeminiInsufficientScope(resp.Header, respBody) {
|
||||
estimated := estimateGeminiCountTokens(body)
|
||||
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
|
||||
@@ -1270,30 +1274,46 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
}, nil
|
||||
}
|
||||
|
||||
if tempMatched {
|
||||
evBody := unwrapIfNeeded(isOAuth, respBody)
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
// 统一错误策略:自定义错误码 + 临时不可调度
|
||||
if s.rateLimitService != nil {
|
||||
switch s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, respBody) {
|
||||
case ErrorPolicySkipped:
|
||||
respBody = unwrapIfNeeded(isOAuth, respBody)
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
}
|
||||
upstreamDetail = truncateString(string(evBody), maxBytes)
|
||||
c.Data(resp.StatusCode, contentType, respBody)
|
||||
return nil, fmt.Errorf("gemini upstream error: %d (skipped by error policy)", resp.StatusCode)
|
||||
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
|
||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
evBody := unwrapIfNeeded(isOAuth, respBody)
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(evBody), maxBytes)
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: requestID,
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: requestID,
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
|
||||
// ErrorPolicyNone → 原有逻辑
|
||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||||
evBody := unwrapIfNeeded(isOAuth, respBody)
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))
|
||||
|
||||
@@ -133,9 +133,6 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx cont
|
||||
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
@@ -265,29 +262,6 @@ func (m *mockGatewayCacheForGemini) DeleteSessionAccountID(ctx context.Context,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForGemini) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForGemini) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForGemini) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForGemini) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForGemini) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForGemini) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) {
|
||||
|
||||
@@ -6,26 +6,11 @@ import (
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/cespare/xxhash/v2"
|
||||
)
|
||||
|
||||
// Gemini 会话 ID Fallback 相关常量
|
||||
const (
|
||||
// geminiSessionTTLSeconds Gemini 会话缓存 TTL(5 分钟)
|
||||
geminiSessionTTLSeconds = 300
|
||||
|
||||
// geminiSessionKeyPrefix Gemini 会话 Redis key 前缀
|
||||
geminiSessionKeyPrefix = "gemini:sess:"
|
||||
)
|
||||
|
||||
// GeminiSessionTTL 返回 Gemini 会话缓存 TTL
|
||||
func GeminiSessionTTL() time.Duration {
|
||||
return geminiSessionTTLSeconds * time.Second
|
||||
}
|
||||
|
||||
// shortHash 使用 XXHash64 + Base36 生成短 hash(16 字符)
|
||||
// XXHash64 比 SHA256 快约 10 倍,Base36 比 Hex 短约 20%
|
||||
func shortHash(data []byte) string {
|
||||
@@ -79,35 +64,6 @@ func GenerateGeminiPrefixHash(userID, apiKeyID int64, ip, userAgent, platform, m
|
||||
return base64.RawURLEncoding.EncodeToString(hash[:12])
|
||||
}
|
||||
|
||||
// BuildGeminiSessionKey 构建 Gemini 会话 Redis key
|
||||
// 格式: gemini:sess:{groupID}:{prefixHash}:{digestChain}
|
||||
func BuildGeminiSessionKey(groupID int64, prefixHash, digestChain string) string {
|
||||
return geminiSessionKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash + ":" + digestChain
|
||||
}
|
||||
|
||||
// GenerateDigestChainPrefixes 生成摘要链的所有前缀(从长到短)
|
||||
// 用于 MGET 批量查询最长匹配
|
||||
func GenerateDigestChainPrefixes(chain string) []string {
|
||||
if chain == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var prefixes []string
|
||||
c := chain
|
||||
|
||||
for c != "" {
|
||||
prefixes = append(prefixes, c)
|
||||
// 找到最后一个 "-" 的位置
|
||||
if i := strings.LastIndex(c, "-"); i > 0 {
|
||||
c = c[:i]
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return prefixes
|
||||
}
|
||||
|
||||
// ParseGeminiSessionValue 解析 Gemini 会话缓存值
|
||||
// 格式: {uuid}:{accountID}
|
||||
func ParseGeminiSessionValue(value string) (uuid string, accountID int64, ok bool) {
|
||||
@@ -139,15 +95,6 @@ func FormatGeminiSessionValue(uuid string, accountID int64) string {
|
||||
// geminiDigestSessionKeyPrefix Gemini 摘要 fallback 会话 key 前缀
|
||||
const geminiDigestSessionKeyPrefix = "gemini:digest:"
|
||||
|
||||
// geminiTrieKeyPrefix Gemini Trie 会话 key 前缀
|
||||
const geminiTrieKeyPrefix = "gemini:trie:"
|
||||
|
||||
// BuildGeminiTrieKey 构建 Gemini Trie Redis key
|
||||
// 格式: gemini:trie:{groupID}:{prefixHash}
|
||||
func BuildGeminiTrieKey(groupID int64, prefixHash string) string {
|
||||
return geminiTrieKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash
|
||||
}
|
||||
|
||||
// GenerateGeminiDigestSessionKey 生成 Gemini 摘要 fallback 的 sessionKey
|
||||
// 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey
|
||||
// 用于在 SelectAccountWithLoadAwareness 中保持粘性会话
|
||||
|
||||
@@ -1,41 +1,14 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
)
|
||||
|
||||
// mockGeminiSessionCache 模拟 Redis 缓存
|
||||
type mockGeminiSessionCache struct {
|
||||
sessions map[string]string // key -> value
|
||||
}
|
||||
|
||||
func newMockGeminiSessionCache() *mockGeminiSessionCache {
|
||||
return &mockGeminiSessionCache{sessions: make(map[string]string)}
|
||||
}
|
||||
|
||||
func (m *mockGeminiSessionCache) Save(groupID int64, prefixHash, digestChain, uuid string, accountID int64) {
|
||||
key := BuildGeminiSessionKey(groupID, prefixHash, digestChain)
|
||||
value := FormatGeminiSessionValue(uuid, accountID)
|
||||
m.sessions[key] = value
|
||||
}
|
||||
|
||||
func (m *mockGeminiSessionCache) Find(groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
||||
prefixes := GenerateDigestChainPrefixes(digestChain)
|
||||
for _, p := range prefixes {
|
||||
key := BuildGeminiSessionKey(groupID, prefixHash, p)
|
||||
if val, ok := m.sessions[key]; ok {
|
||||
return ParseGeminiSessionValue(val)
|
||||
}
|
||||
}
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
// TestGeminiSessionContinuousConversation 测试连续会话的摘要链匹配
|
||||
func TestGeminiSessionContinuousConversation(t *testing.T) {
|
||||
cache := newMockGeminiSessionCache()
|
||||
store := NewDigestSessionStore()
|
||||
groupID := int64(1)
|
||||
prefixHash := "test_prefix_hash"
|
||||
sessionUUID := "session-uuid-12345"
|
||||
@@ -54,13 +27,13 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
|
||||
t.Logf("Round 1 chain: %s", chain1)
|
||||
|
||||
// 第一轮:没有找到会话,创建新会话
|
||||
_, _, found := cache.Find(groupID, prefixHash, chain1)
|
||||
_, _, _, found := store.Find(groupID, prefixHash, chain1)
|
||||
if found {
|
||||
t.Error("Round 1: should not find existing session")
|
||||
}
|
||||
|
||||
// 保存第一轮会话
|
||||
cache.Save(groupID, prefixHash, chain1, sessionUUID, accountID)
|
||||
// 保存第一轮会话(首轮无旧 chain)
|
||||
store.Save(groupID, prefixHash, chain1, sessionUUID, accountID, "")
|
||||
|
||||
// 模拟第二轮对话(用户继续对话)
|
||||
req2 := &antigravity.GeminiRequest{
|
||||
@@ -77,7 +50,7 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
|
||||
t.Logf("Round 2 chain: %s", chain2)
|
||||
|
||||
// 第二轮:应该能找到会话(通过前缀匹配)
|
||||
foundUUID, foundAccID, found := cache.Find(groupID, prefixHash, chain2)
|
||||
foundUUID, foundAccID, matchedChain, found := store.Find(groupID, prefixHash, chain2)
|
||||
if !found {
|
||||
t.Error("Round 2: should find session via prefix matching")
|
||||
}
|
||||
@@ -88,8 +61,8 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
|
||||
t.Errorf("Round 2: expected accountID %d, got %d", accountID, foundAccID)
|
||||
}
|
||||
|
||||
// 保存第二轮会话
|
||||
cache.Save(groupID, prefixHash, chain2, sessionUUID, accountID)
|
||||
// 保存第二轮会话,传入 Find 返回的 matchedChain 以删旧 key
|
||||
store.Save(groupID, prefixHash, chain2, sessionUUID, accountID, matchedChain)
|
||||
|
||||
// 模拟第三轮对话
|
||||
req3 := &antigravity.GeminiRequest{
|
||||
@@ -108,7 +81,7 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
|
||||
t.Logf("Round 3 chain: %s", chain3)
|
||||
|
||||
// 第三轮:应该能找到会话(通过第二轮的前缀匹配)
|
||||
foundUUID, foundAccID, found = cache.Find(groupID, prefixHash, chain3)
|
||||
foundUUID, foundAccID, _, found = store.Find(groupID, prefixHash, chain3)
|
||||
if !found {
|
||||
t.Error("Round 3: should find session via prefix matching")
|
||||
}
|
||||
@@ -118,13 +91,11 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
|
||||
if foundAccID != accountID {
|
||||
t.Errorf("Round 3: expected accountID %d, got %d", accountID, foundAccID)
|
||||
}
|
||||
|
||||
t.Log("✓ Continuous conversation session matching works correctly!")
|
||||
}
|
||||
|
||||
// TestGeminiSessionDifferentConversations 测试不同会话不会错误匹配
|
||||
func TestGeminiSessionDifferentConversations(t *testing.T) {
|
||||
cache := newMockGeminiSessionCache()
|
||||
store := NewDigestSessionStore()
|
||||
groupID := int64(1)
|
||||
prefixHash := "test_prefix_hash"
|
||||
|
||||
@@ -135,7 +106,7 @@ func TestGeminiSessionDifferentConversations(t *testing.T) {
|
||||
},
|
||||
}
|
||||
chain1 := BuildGeminiDigestChain(req1)
|
||||
cache.Save(groupID, prefixHash, chain1, "session-1", 100)
|
||||
store.Save(groupID, prefixHash, chain1, "session-1", 100, "")
|
||||
|
||||
// 第二个完全不同的会话
|
||||
req2 := &antigravity.GeminiRequest{
|
||||
@@ -146,61 +117,29 @@ func TestGeminiSessionDifferentConversations(t *testing.T) {
|
||||
chain2 := BuildGeminiDigestChain(req2)
|
||||
|
||||
// 不同会话不应该匹配
|
||||
_, _, found := cache.Find(groupID, prefixHash, chain2)
|
||||
_, _, _, found := store.Find(groupID, prefixHash, chain2)
|
||||
if found {
|
||||
t.Error("Different conversations should not match")
|
||||
}
|
||||
|
||||
t.Log("✓ Different conversations are correctly isolated!")
|
||||
}
|
||||
|
||||
// TestGeminiSessionPrefixMatchingOrder 测试前缀匹配的优先级(最长匹配优先)
|
||||
func TestGeminiSessionPrefixMatchingOrder(t *testing.T) {
|
||||
cache := newMockGeminiSessionCache()
|
||||
store := NewDigestSessionStore()
|
||||
groupID := int64(1)
|
||||
prefixHash := "test_prefix_hash"
|
||||
|
||||
// 创建一个三轮对话
|
||||
req := &antigravity.GeminiRequest{
|
||||
SystemInstruction: &antigravity.GeminiContent{
|
||||
Parts: []antigravity.GeminiPart{{Text: "System prompt"}},
|
||||
},
|
||||
Contents: []antigravity.GeminiContent{
|
||||
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q1"}}},
|
||||
{Role: "model", Parts: []antigravity.GeminiPart{{Text: "A1"}}},
|
||||
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q2"}}},
|
||||
},
|
||||
}
|
||||
fullChain := BuildGeminiDigestChain(req)
|
||||
prefixes := GenerateDigestChainPrefixes(fullChain)
|
||||
|
||||
t.Logf("Full chain: %s", fullChain)
|
||||
t.Logf("Prefixes (longest first): %v", prefixes)
|
||||
|
||||
// 验证前缀生成顺序(从长到短)
|
||||
if len(prefixes) != 4 {
|
||||
t.Errorf("Expected 4 prefixes, got %d", len(prefixes))
|
||||
}
|
||||
|
||||
// 保存不同轮次的会话到不同账号
|
||||
// 第一轮(最短前缀)-> 账号 1
|
||||
cache.Save(groupID, prefixHash, prefixes[3], "session-round1", 1)
|
||||
// 第二轮 -> 账号 2
|
||||
cache.Save(groupID, prefixHash, prefixes[2], "session-round2", 2)
|
||||
// 第三轮(最长前缀,完整链)-> 账号 3
|
||||
cache.Save(groupID, prefixHash, prefixes[0], "session-round3", 3)
|
||||
store.Save(groupID, prefixHash, "s:sys-u:q1", "session-round1", 1, "")
|
||||
store.Save(groupID, prefixHash, "s:sys-u:q1-m:a1", "session-round2", 2, "")
|
||||
store.Save(groupID, prefixHash, "s:sys-u:q1-m:a1-u:q2", "session-round3", 3, "")
|
||||
|
||||
// 查找应该返回最长匹配(账号 3)
|
||||
_, accID, found := cache.Find(groupID, prefixHash, fullChain)
|
||||
// 查找更长的链,应该返回最长匹配(账号 3)
|
||||
_, accID, _, found := store.Find(groupID, prefixHash, "s:sys-u:q1-m:a1-u:q2-m:a2")
|
||||
if !found {
|
||||
t.Error("Should find session")
|
||||
}
|
||||
if accID != 3 {
|
||||
t.Errorf("Should match longest prefix (account 3), got account %d", accID)
|
||||
}
|
||||
|
||||
t.Log("✓ Longest prefix matching works correctly!")
|
||||
}
|
||||
|
||||
// 确保 context 包被使用(避免未使用的导入警告)
|
||||
var _ = context.Background
|
||||
|
||||
@@ -152,61 +152,6 @@ func TestGenerateGeminiPrefixHash(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateDigestChainPrefixes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
chain string
|
||||
want []string
|
||||
wantLen int
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
chain: "",
|
||||
wantLen: 0,
|
||||
},
|
||||
{
|
||||
name: "single part",
|
||||
chain: "u:abc123",
|
||||
want: []string{"u:abc123"},
|
||||
wantLen: 1,
|
||||
},
|
||||
{
|
||||
name: "two parts",
|
||||
chain: "s:xyz-u:abc",
|
||||
want: []string{"s:xyz-u:abc", "s:xyz"},
|
||||
wantLen: 2,
|
||||
},
|
||||
{
|
||||
name: "four parts",
|
||||
chain: "s:a-u:b-m:c-u:d",
|
||||
want: []string{"s:a-u:b-m:c-u:d", "s:a-u:b-m:c", "s:a-u:b", "s:a"},
|
||||
wantLen: 4,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GenerateDigestChainPrefixes(tt.chain)
|
||||
|
||||
if len(result) != tt.wantLen {
|
||||
t.Errorf("expected %d prefixes, got %d: %v", tt.wantLen, len(result), result)
|
||||
}
|
||||
|
||||
if tt.want != nil {
|
||||
for i, want := range tt.want {
|
||||
if i >= len(result) {
|
||||
t.Errorf("missing prefix at index %d", i)
|
||||
continue
|
||||
}
|
||||
if result[i] != want {
|
||||
t.Errorf("prefix[%d]: expected %s, got %s", i, want, result[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseGeminiSessionValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -442,40 +387,3 @@ func TestGenerateGeminiDigestSessionKey(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildGeminiTrieKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
groupID int64
|
||||
prefixHash string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "normal",
|
||||
groupID: 123,
|
||||
prefixHash: "abcdef12",
|
||||
want: "gemini:trie:123:abcdef12",
|
||||
},
|
||||
{
|
||||
name: "zero group",
|
||||
groupID: 0,
|
||||
prefixHash: "xyz",
|
||||
want: "gemini:trie:0:xyz",
|
||||
},
|
||||
{
|
||||
name: "empty prefix",
|
||||
groupID: 1,
|
||||
prefixHash: "",
|
||||
want: "gemini:trie:1:",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := BuildGeminiTrieKey(tt.groupID, tt.prefixHash)
|
||||
if got != tt.want {
|
||||
t.Errorf("BuildGeminiTrieKey(%d, %q) = %q, want %q", tt.groupID, tt.prefixHash, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
1213
backend/internal/service/generate_session_hash_test.go
Normal file
1213
backend/internal/service/generate_session_hash_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -318,110 +318,6 @@ func TestGetModelRateLimitRemainingTime(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetQuotaScopeRateLimitRemainingTime(t *testing.T) {
|
||||
now := time.Now()
|
||||
future10m := now.Add(10 * time.Minute).Format(time.RFC3339)
|
||||
past := now.Add(-10 * time.Minute).Format(time.RFC3339)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
requestedModel string
|
||||
minExpected time.Duration
|
||||
maxExpected time.Duration
|
||||
}{
|
||||
{
|
||||
name: "nil account",
|
||||
account: nil,
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
minExpected: 0,
|
||||
maxExpected: 0,
|
||||
},
|
||||
{
|
||||
name: "non-antigravity platform",
|
||||
account: &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Extra: map[string]any{
|
||||
antigravityQuotaScopesKey: map[string]any{
|
||||
"claude": map[string]any{
|
||||
"rate_limit_reset_at": future10m,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
minExpected: 0,
|
||||
maxExpected: 0,
|
||||
},
|
||||
{
|
||||
name: "claude scope rate limited",
|
||||
account: &Account{
|
||||
Platform: PlatformAntigravity,
|
||||
Extra: map[string]any{
|
||||
antigravityQuotaScopesKey: map[string]any{
|
||||
"claude": map[string]any{
|
||||
"rate_limit_reset_at": future10m,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
minExpected: 9 * time.Minute,
|
||||
maxExpected: 11 * time.Minute,
|
||||
},
|
||||
{
|
||||
name: "gemini_text scope rate limited",
|
||||
account: &Account{
|
||||
Platform: PlatformAntigravity,
|
||||
Extra: map[string]any{
|
||||
antigravityQuotaScopesKey: map[string]any{
|
||||
"gemini_text": map[string]any{
|
||||
"rate_limit_reset_at": future10m,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
requestedModel: "gemini-3-flash",
|
||||
minExpected: 9 * time.Minute,
|
||||
maxExpected: 11 * time.Minute,
|
||||
},
|
||||
{
|
||||
name: "expired scope rate limit",
|
||||
account: &Account{
|
||||
Platform: PlatformAntigravity,
|
||||
Extra: map[string]any{
|
||||
antigravityQuotaScopesKey: map[string]any{
|
||||
"claude": map[string]any{
|
||||
"rate_limit_reset_at": past,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
minExpected: 0,
|
||||
maxExpected: 0,
|
||||
},
|
||||
{
|
||||
name: "unsupported model",
|
||||
account: &Account{
|
||||
Platform: PlatformAntigravity,
|
||||
},
|
||||
requestedModel: "gpt-4",
|
||||
minExpected: 0,
|
||||
maxExpected: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.account.GetQuotaScopeRateLimitRemainingTime(tt.requestedModel)
|
||||
if result < tt.minExpected || result > tt.maxExpected {
|
||||
t.Errorf("GetQuotaScopeRateLimitRemainingTime() = %v, want between %v and %v", result, tt.minExpected, tt.maxExpected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRateLimitRemainingTime(t *testing.T) {
|
||||
now := time.Now()
|
||||
future15m := now.Add(15 * time.Minute).Format(time.RFC3339)
|
||||
@@ -442,45 +338,19 @@ func TestGetRateLimitRemainingTime(t *testing.T) {
|
||||
maxExpected: 0,
|
||||
},
|
||||
{
|
||||
name: "model remaining > scope remaining - returns model",
|
||||
name: "model rate limited - 15 minutes",
|
||||
account: &Account{
|
||||
Platform: PlatformAntigravity,
|
||||
Extra: map[string]any{
|
||||
modelRateLimitsKey: map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
"rate_limit_reset_at": future15m, // 15 分钟
|
||||
},
|
||||
},
|
||||
antigravityQuotaScopesKey: map[string]any{
|
||||
"claude": map[string]any{
|
||||
"rate_limit_reset_at": future5m, // 5 分钟
|
||||
"rate_limit_reset_at": future15m,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
minExpected: 14 * time.Minute, // 应返回较大的 15 分钟
|
||||
maxExpected: 16 * time.Minute,
|
||||
},
|
||||
{
|
||||
name: "scope remaining > model remaining - returns scope",
|
||||
account: &Account{
|
||||
Platform: PlatformAntigravity,
|
||||
Extra: map[string]any{
|
||||
modelRateLimitsKey: map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
"rate_limit_reset_at": future5m, // 5 分钟
|
||||
},
|
||||
},
|
||||
antigravityQuotaScopesKey: map[string]any{
|
||||
"claude": map[string]any{
|
||||
"rate_limit_reset_at": future15m, // 15 分钟
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
minExpected: 14 * time.Minute, // 应返回较大的 15 分钟
|
||||
minExpected: 14 * time.Minute,
|
||||
maxExpected: 16 * time.Minute,
|
||||
},
|
||||
{
|
||||
@@ -499,22 +369,6 @@ func TestGetRateLimitRemainingTime(t *testing.T) {
|
||||
minExpected: 4 * time.Minute,
|
||||
maxExpected: 6 * time.Minute,
|
||||
},
|
||||
{
|
||||
name: "only scope rate limited",
|
||||
account: &Account{
|
||||
Platform: PlatformAntigravity,
|
||||
Extra: map[string]any{
|
||||
antigravityQuotaScopesKey: map[string]any{
|
||||
"claude": map[string]any{
|
||||
"rate_limit_reset_at": future5m,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
minExpected: 4 * time.Minute,
|
||||
maxExpected: 6 * time.Minute,
|
||||
},
|
||||
{
|
||||
name: "neither rate limited",
|
||||
account: &Account{
|
||||
|
||||
@@ -580,10 +580,6 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
}
|
||||
}
|
||||
} else {
|
||||
type accountWithLoad struct {
|
||||
account *Account
|
||||
loadInfo *AccountLoadInfo
|
||||
}
|
||||
var available []accountWithLoad
|
||||
for _, acc := range candidates {
|
||||
loadInfo := loadMap[acc.ID]
|
||||
@@ -618,6 +614,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
|
||||
}
|
||||
})
|
||||
shuffleWithinSortGroups(available)
|
||||
|
||||
for _, item := range available {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
||||
|
||||
@@ -204,30 +204,6 @@ func (c *stubGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID i
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *stubGatewayCache) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (c *stubGatewayCache) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *stubGatewayCache) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
func (c *stubGatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *stubGatewayCache) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
func (c *stubGatewayCache) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
|
||||
now := time.Now()
|
||||
resetAt := now.Add(10 * time.Minute)
|
||||
|
||||
@@ -66,7 +66,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
|
||||
}
|
||||
|
||||
isAvailable := acc.Status == StatusActive && acc.Schedulable && !isRateLimited && !isOverloaded && !isTempUnsched
|
||||
scopeRateLimits := acc.GetAntigravityScopeRateLimits()
|
||||
|
||||
if acc.Platform != "" {
|
||||
if _, ok := platform[acc.Platform]; !ok {
|
||||
@@ -85,14 +84,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
|
||||
if hasError {
|
||||
p.ErrorCount++
|
||||
}
|
||||
if len(scopeRateLimits) > 0 {
|
||||
if p.ScopeRateLimitCount == nil {
|
||||
p.ScopeRateLimitCount = make(map[string]int64)
|
||||
}
|
||||
for scope := range scopeRateLimits {
|
||||
p.ScopeRateLimitCount[scope]++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, grp := range acc.Groups {
|
||||
@@ -117,14 +108,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
|
||||
if hasError {
|
||||
g.ErrorCount++
|
||||
}
|
||||
if len(scopeRateLimits) > 0 {
|
||||
if g.ScopeRateLimitCount == nil {
|
||||
g.ScopeRateLimitCount = make(map[string]int64)
|
||||
}
|
||||
for scope := range scopeRateLimits {
|
||||
g.ScopeRateLimitCount[scope]++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
displayGroupID := int64(0)
|
||||
@@ -157,9 +140,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
|
||||
item.RateLimitRemainingSec = &remainingSec
|
||||
}
|
||||
}
|
||||
if len(scopeRateLimits) > 0 {
|
||||
item.ScopeRateLimits = scopeRateLimits
|
||||
}
|
||||
if isOverloaded && acc.OverloadUntil != nil {
|
||||
item.OverloadUntil = acc.OverloadUntil
|
||||
remainingSec := int64(time.Until(*acc.OverloadUntil).Seconds())
|
||||
|
||||
@@ -50,24 +50,22 @@ type UserConcurrencyInfo struct {
|
||||
|
||||
// PlatformAvailability aggregates account availability by platform.
|
||||
type PlatformAvailability struct {
|
||||
Platform string `json:"platform"`
|
||||
TotalAccounts int64 `json:"total_accounts"`
|
||||
AvailableCount int64 `json:"available_count"`
|
||||
RateLimitCount int64 `json:"rate_limit_count"`
|
||||
ScopeRateLimitCount map[string]int64 `json:"scope_rate_limit_count,omitempty"`
|
||||
ErrorCount int64 `json:"error_count"`
|
||||
Platform string `json:"platform"`
|
||||
TotalAccounts int64 `json:"total_accounts"`
|
||||
AvailableCount int64 `json:"available_count"`
|
||||
RateLimitCount int64 `json:"rate_limit_count"`
|
||||
ErrorCount int64 `json:"error_count"`
|
||||
}
|
||||
|
||||
// GroupAvailability aggregates account availability by group.
|
||||
type GroupAvailability struct {
|
||||
GroupID int64 `json:"group_id"`
|
||||
GroupName string `json:"group_name"`
|
||||
Platform string `json:"platform"`
|
||||
TotalAccounts int64 `json:"total_accounts"`
|
||||
AvailableCount int64 `json:"available_count"`
|
||||
RateLimitCount int64 `json:"rate_limit_count"`
|
||||
ScopeRateLimitCount map[string]int64 `json:"scope_rate_limit_count,omitempty"`
|
||||
ErrorCount int64 `json:"error_count"`
|
||||
GroupID int64 `json:"group_id"`
|
||||
GroupName string `json:"group_name"`
|
||||
Platform string `json:"platform"`
|
||||
TotalAccounts int64 `json:"total_accounts"`
|
||||
AvailableCount int64 `json:"available_count"`
|
||||
RateLimitCount int64 `json:"rate_limit_count"`
|
||||
ErrorCount int64 `json:"error_count"`
|
||||
}
|
||||
|
||||
// AccountAvailability represents current availability for a single account.
|
||||
@@ -85,11 +83,10 @@ type AccountAvailability struct {
|
||||
IsOverloaded bool `json:"is_overloaded"`
|
||||
HasError bool `json:"has_error"`
|
||||
|
||||
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
|
||||
RateLimitRemainingSec *int64 `json:"rate_limit_remaining_sec"`
|
||||
ScopeRateLimits map[string]int64 `json:"scope_rate_limits,omitempty"`
|
||||
OverloadUntil *time.Time `json:"overload_until"`
|
||||
OverloadRemainingSec *int64 `json:"overload_remaining_sec"`
|
||||
ErrorMessage string `json:"error_message"`
|
||||
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until,omitempty"`
|
||||
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
|
||||
RateLimitRemainingSec *int64 `json:"rate_limit_remaining_sec"`
|
||||
OverloadUntil *time.Time `json:"overload_until"`
|
||||
OverloadRemainingSec *int64 `json:"overload_remaining_sec"`
|
||||
ErrorMessage string `json:"error_message"`
|
||||
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until,omitempty"`
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -528,7 +529,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry
|
||||
func extractRetryModelAndStream(reqType opsRetryRequestType, errorLog *OpsErrorLogDetail, body []byte) (model string, stream bool, err error) {
|
||||
switch reqType {
|
||||
case opsRetryTypeMessages:
|
||||
parsed, parseErr := ParseGatewayRequest(body)
|
||||
parsed, parseErr := ParseGatewayRequest(body, domain.PlatformAnthropic)
|
||||
if parseErr != nil {
|
||||
return "", false, fmt.Errorf("failed to parse messages request body: %w", parseErr)
|
||||
}
|
||||
@@ -596,7 +597,7 @@ func (s *OpsService) executeWithAccount(ctx context.Context, reqType opsRetryReq
|
||||
if s.gatewayService == nil {
|
||||
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "gateway service not available"}
|
||||
}
|
||||
parsedReq, parseErr := ParseGatewayRequest(body)
|
||||
parsedReq, parseErr := ParseGatewayRequest(body, domain.PlatformAnthropic)
|
||||
if parseErr != nil {
|
||||
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "failed to parse request body"}
|
||||
}
|
||||
|
||||
@@ -62,6 +62,32 @@ func (s *RateLimitService) SetTokenCacheInvalidator(invalidator TokenCacheInvali
|
||||
s.tokenCacheInvalidator = invalidator
|
||||
}
|
||||
|
||||
// ErrorPolicyResult 表示错误策略检查的结果
|
||||
type ErrorPolicyResult int
|
||||
|
||||
const (
|
||||
ErrorPolicyNone ErrorPolicyResult = iota // 未命中任何策略,继续默认逻辑
|
||||
ErrorPolicySkipped // 自定义错误码开启但未命中,跳过处理
|
||||
ErrorPolicyMatched // 自定义错误码命中,应停止调度
|
||||
ErrorPolicyTempUnscheduled // 临时不可调度规则命中
|
||||
)
|
||||
|
||||
// CheckErrorPolicy 检查自定义错误码和临时不可调度规则。
|
||||
// 自定义错误码开启时覆盖后续所有逻辑(包括临时不可调度)。
|
||||
func (s *RateLimitService) CheckErrorPolicy(ctx context.Context, account *Account, statusCode int, responseBody []byte) ErrorPolicyResult {
|
||||
if account.IsCustomErrorCodesEnabled() {
|
||||
if account.ShouldHandleErrorCode(statusCode) {
|
||||
return ErrorPolicyMatched
|
||||
}
|
||||
slog.Info("account_error_code_skipped", "account_id", account.ID, "status_code", statusCode)
|
||||
return ErrorPolicySkipped
|
||||
}
|
||||
if s.tryTempUnschedulable(ctx, account, statusCode, responseBody) {
|
||||
return ErrorPolicyTempUnscheduled
|
||||
}
|
||||
return ErrorPolicyNone
|
||||
}
|
||||
|
||||
// HandleUpstreamError 处理上游错误响应,标记账号状态
|
||||
// 返回是否应该停止该账号的调度
|
||||
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {
|
||||
|
||||
318
backend/internal/service/scheduler_shuffle_test.go
Normal file
318
backend/internal/service/scheduler_shuffle_test.go
Normal file
@@ -0,0 +1,318 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ============ shuffleWithinSortGroups 测试 ============
|
||||
|
||||
func TestShuffleWithinSortGroups_Empty(t *testing.T) {
|
||||
shuffleWithinSortGroups(nil)
|
||||
shuffleWithinSortGroups([]accountWithLoad{})
|
||||
}
|
||||
|
||||
func TestShuffleWithinSortGroups_SingleElement(t *testing.T) {
|
||||
accounts := []accountWithLoad{
|
||||
{account: &Account{ID: 1, Priority: 1}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
|
||||
}
|
||||
shuffleWithinSortGroups(accounts)
|
||||
require.Equal(t, int64(1), accounts[0].account.ID)
|
||||
}
|
||||
|
||||
func TestShuffleWithinSortGroups_DifferentGroups_OrderPreserved(t *testing.T) {
|
||||
now := time.Now()
|
||||
earlier := now.Add(-1 * time.Hour)
|
||||
|
||||
accounts := []accountWithLoad{
|
||||
{account: &Account{ID: 1, Priority: 1, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
|
||||
{account: &Account{ID: 2, Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 20}},
|
||||
{account: &Account{ID: 3, Priority: 2, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
|
||||
}
|
||||
|
||||
// 每个元素都属于不同组(Priority 或 LoadRate 或 LastUsedAt 不同),顺序不变
|
||||
for i := 0; i < 20; i++ {
|
||||
cpy := make([]accountWithLoad, len(accounts))
|
||||
copy(cpy, accounts)
|
||||
shuffleWithinSortGroups(cpy)
|
||||
require.Equal(t, int64(1), cpy[0].account.ID)
|
||||
require.Equal(t, int64(2), cpy[1].account.ID)
|
||||
require.Equal(t, int64(3), cpy[2].account.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShuffleWithinSortGroups_SameGroup_Shuffled(t *testing.T) {
|
||||
now := time.Now()
|
||||
// 同一秒的时间戳视为同一组
|
||||
sameSecond := time.Unix(now.Unix(), 0)
|
||||
sameSecond2 := time.Unix(now.Unix(), 500_000_000) // 同一秒但不同纳秒
|
||||
|
||||
accounts := []accountWithLoad{
|
||||
{account: &Account{ID: 1, Priority: 1, LastUsedAt: &sameSecond}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
|
||||
{account: &Account{ID: 2, Priority: 1, LastUsedAt: &sameSecond2}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
|
||||
{account: &Account{ID: 3, Priority: 1, LastUsedAt: &sameSecond}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
|
||||
}
|
||||
|
||||
// 多次执行,验证所有 ID 都出现在第一个位置(说明确实被打乱了)
|
||||
seen := map[int64]bool{}
|
||||
for i := 0; i < 100; i++ {
|
||||
cpy := make([]accountWithLoad, len(accounts))
|
||||
copy(cpy, accounts)
|
||||
shuffleWithinSortGroups(cpy)
|
||||
seen[cpy[0].account.ID] = true
|
||||
// 无论怎么打乱,所有 ID 都应在候选中
|
||||
ids := map[int64]bool{}
|
||||
for _, a := range cpy {
|
||||
ids[a.account.ID] = true
|
||||
}
|
||||
require.True(t, ids[1] && ids[2] && ids[3])
|
||||
}
|
||||
// 至少 2 个不同的 ID 出现在首位(随机性验证)
|
||||
require.GreaterOrEqual(t, len(seen), 2, "shuffle should produce different orderings")
|
||||
}
|
||||
|
||||
func TestShuffleWithinSortGroups_NilLastUsedAt_SameGroup(t *testing.T) {
|
||||
accounts := []accountWithLoad{
|
||||
{account: &Account{ID: 1, Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}},
|
||||
{account: &Account{ID: 2, Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}},
|
||||
{account: &Account{ID: 3, Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}},
|
||||
}
|
||||
|
||||
seen := map[int64]bool{}
|
||||
for i := 0; i < 100; i++ {
|
||||
cpy := make([]accountWithLoad, len(accounts))
|
||||
copy(cpy, accounts)
|
||||
shuffleWithinSortGroups(cpy)
|
||||
seen[cpy[0].account.ID] = true
|
||||
}
|
||||
require.GreaterOrEqual(t, len(seen), 2, "nil LastUsedAt accounts should be shuffled")
|
||||
}
|
||||
|
||||
func TestShuffleWithinSortGroups_MixedGroups(t *testing.T) {
|
||||
now := time.Now()
|
||||
earlier := now.Add(-1 * time.Hour)
|
||||
sameAsNow := time.Unix(now.Unix(), 0)
|
||||
|
||||
// 组1: Priority=1, LoadRate=10, LastUsedAt=earlier (ID 1) — 单元素组
|
||||
// 组2: Priority=1, LoadRate=20, LastUsedAt=now (ID 2, 3) — 双元素组
|
||||
// 组3: Priority=2, LoadRate=10, LastUsedAt=earlier (ID 4) — 单元素组
|
||||
accounts := []accountWithLoad{
|
||||
{account: &Account{ID: 1, Priority: 1, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
|
||||
{account: &Account{ID: 2, Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 20}},
|
||||
{account: &Account{ID: 3, Priority: 1, LastUsedAt: &sameAsNow}, loadInfo: &AccountLoadInfo{LoadRate: 20}},
|
||||
{account: &Account{ID: 4, Priority: 2, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
|
||||
}
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
cpy := make([]accountWithLoad, len(accounts))
|
||||
copy(cpy, accounts)
|
||||
shuffleWithinSortGroups(cpy)
|
||||
|
||||
// 组间顺序不变
|
||||
require.Equal(t, int64(1), cpy[0].account.ID, "group 1 position fixed")
|
||||
require.Equal(t, int64(4), cpy[3].account.ID, "group 3 position fixed")
|
||||
|
||||
// 组2 内部可以打乱,但仍在位置 1 和 2
|
||||
mid := map[int64]bool{cpy[1].account.ID: true, cpy[2].account.ID: true}
|
||||
require.True(t, mid[2] && mid[3], "group 2 elements should stay in positions 1-2")
|
||||
}
|
||||
}
|
||||
|
||||
// ============ shuffleWithinPriorityAndLastUsed 测试 ============
|
||||
|
||||
func TestShuffleWithinPriorityAndLastUsed_Empty(t *testing.T) {
|
||||
shuffleWithinPriorityAndLastUsed(nil)
|
||||
shuffleWithinPriorityAndLastUsed([]*Account{})
|
||||
}
|
||||
|
||||
func TestShuffleWithinPriorityAndLastUsed_SingleElement(t *testing.T) {
|
||||
accounts := []*Account{{ID: 1, Priority: 1}}
|
||||
shuffleWithinPriorityAndLastUsed(accounts)
|
||||
require.Equal(t, int64(1), accounts[0].ID)
|
||||
}
|
||||
|
||||
func TestShuffleWithinPriorityAndLastUsed_SameGroup_Shuffled(t *testing.T) {
|
||||
accounts := []*Account{
|
||||
{ID: 1, Priority: 1, LastUsedAt: nil},
|
||||
{ID: 2, Priority: 1, LastUsedAt: nil},
|
||||
{ID: 3, Priority: 1, LastUsedAt: nil},
|
||||
}
|
||||
|
||||
seen := map[int64]bool{}
|
||||
for i := 0; i < 100; i++ {
|
||||
cpy := make([]*Account, len(accounts))
|
||||
copy(cpy, accounts)
|
||||
shuffleWithinPriorityAndLastUsed(cpy)
|
||||
seen[cpy[0].ID] = true
|
||||
}
|
||||
require.GreaterOrEqual(t, len(seen), 2, "same group should be shuffled")
|
||||
}
|
||||
|
||||
func TestShuffleWithinPriorityAndLastUsed_DifferentPriority_OrderPreserved(t *testing.T) {
|
||||
accounts := []*Account{
|
||||
{ID: 1, Priority: 1, LastUsedAt: nil},
|
||||
{ID: 2, Priority: 2, LastUsedAt: nil},
|
||||
{ID: 3, Priority: 3, LastUsedAt: nil},
|
||||
}
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
cpy := make([]*Account, len(accounts))
|
||||
copy(cpy, accounts)
|
||||
shuffleWithinPriorityAndLastUsed(cpy)
|
||||
require.Equal(t, int64(1), cpy[0].ID)
|
||||
require.Equal(t, int64(2), cpy[1].ID)
|
||||
require.Equal(t, int64(3), cpy[2].ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShuffleWithinPriorityAndLastUsed_DifferentLastUsedAt_OrderPreserved(t *testing.T) {
|
||||
now := time.Now()
|
||||
earlier := now.Add(-1 * time.Hour)
|
||||
|
||||
accounts := []*Account{
|
||||
{ID: 1, Priority: 1, LastUsedAt: nil},
|
||||
{ID: 2, Priority: 1, LastUsedAt: &earlier},
|
||||
{ID: 3, Priority: 1, LastUsedAt: &now},
|
||||
}
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
cpy := make([]*Account, len(accounts))
|
||||
copy(cpy, accounts)
|
||||
shuffleWithinPriorityAndLastUsed(cpy)
|
||||
require.Equal(t, int64(1), cpy[0].ID)
|
||||
require.Equal(t, int64(2), cpy[1].ID)
|
||||
require.Equal(t, int64(3), cpy[2].ID)
|
||||
}
|
||||
}
|
||||
|
||||
// ============ sameLastUsedAt 测试 ============
|
||||
|
||||
func TestSameLastUsedAt(t *testing.T) {
|
||||
now := time.Now()
|
||||
sameSecond := time.Unix(now.Unix(), 0)
|
||||
sameSecondDiffNano := time.Unix(now.Unix(), 999_999_999)
|
||||
differentSecond := now.Add(1 * time.Second)
|
||||
|
||||
t.Run("both nil", func(t *testing.T) {
|
||||
require.True(t, sameLastUsedAt(nil, nil))
|
||||
})
|
||||
|
||||
t.Run("one nil one not", func(t *testing.T) {
|
||||
require.False(t, sameLastUsedAt(nil, &now))
|
||||
require.False(t, sameLastUsedAt(&now, nil))
|
||||
})
|
||||
|
||||
t.Run("same second different nanoseconds", func(t *testing.T) {
|
||||
require.True(t, sameLastUsedAt(&sameSecond, &sameSecondDiffNano))
|
||||
})
|
||||
|
||||
t.Run("different seconds", func(t *testing.T) {
|
||||
require.False(t, sameLastUsedAt(&now, &differentSecond))
|
||||
})
|
||||
|
||||
t.Run("exact same time", func(t *testing.T) {
|
||||
require.True(t, sameLastUsedAt(&now, &now))
|
||||
})
|
||||
}
|
||||
|
||||
// ============ sameAccountWithLoadGroup 测试 ============
|
||||
|
||||
func TestSameAccountWithLoadGroup(t *testing.T) {
|
||||
now := time.Now()
|
||||
sameSecond := time.Unix(now.Unix(), 0)
|
||||
|
||||
t.Run("same group", func(t *testing.T) {
|
||||
a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
|
||||
b := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &sameSecond}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
|
||||
require.True(t, sameAccountWithLoadGroup(a, b))
|
||||
})
|
||||
|
||||
t.Run("different priority", func(t *testing.T) {
|
||||
a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
|
||||
b := accountWithLoad{account: &Account{Priority: 2, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
|
||||
require.False(t, sameAccountWithLoadGroup(a, b))
|
||||
})
|
||||
|
||||
t.Run("different load rate", func(t *testing.T) {
|
||||
a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
|
||||
b := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 20}}
|
||||
require.False(t, sameAccountWithLoadGroup(a, b))
|
||||
})
|
||||
|
||||
t.Run("different last used at", func(t *testing.T) {
|
||||
later := now.Add(1 * time.Second)
|
||||
a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
|
||||
b := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &later}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
|
||||
require.False(t, sameAccountWithLoadGroup(a, b))
|
||||
})
|
||||
|
||||
t.Run("both nil LastUsedAt", func(t *testing.T) {
|
||||
a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}}
|
||||
b := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}}
|
||||
require.True(t, sameAccountWithLoadGroup(a, b))
|
||||
})
|
||||
}
|
||||
|
||||
// ============ sameAccountGroup 测试 ============
|
||||
|
||||
func TestSameAccountGroup(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
t.Run("same group", func(t *testing.T) {
|
||||
a := &Account{Priority: 1, LastUsedAt: nil}
|
||||
b := &Account{Priority: 1, LastUsedAt: nil}
|
||||
require.True(t, sameAccountGroup(a, b))
|
||||
})
|
||||
|
||||
t.Run("different priority", func(t *testing.T) {
|
||||
a := &Account{Priority: 1, LastUsedAt: nil}
|
||||
b := &Account{Priority: 2, LastUsedAt: nil}
|
||||
require.False(t, sameAccountGroup(a, b))
|
||||
})
|
||||
|
||||
t.Run("different LastUsedAt", func(t *testing.T) {
|
||||
later := now.Add(1 * time.Second)
|
||||
a := &Account{Priority: 1, LastUsedAt: &now}
|
||||
b := &Account{Priority: 1, LastUsedAt: &later}
|
||||
require.False(t, sameAccountGroup(a, b))
|
||||
})
|
||||
}
|
||||
|
||||
// ============ sortAccountsByPriorityAndLastUsed 集成随机化测试 ============
|
||||
|
||||
func TestSortAccountsByPriorityAndLastUsed_WithShuffle(t *testing.T) {
|
||||
t.Run("same priority and nil LastUsedAt are shuffled", func(t *testing.T) {
|
||||
accounts := []*Account{
|
||||
{ID: 1, Priority: 1, LastUsedAt: nil},
|
||||
{ID: 2, Priority: 1, LastUsedAt: nil},
|
||||
{ID: 3, Priority: 1, LastUsedAt: nil},
|
||||
}
|
||||
|
||||
seen := map[int64]bool{}
|
||||
for i := 0; i < 100; i++ {
|
||||
cpy := make([]*Account, len(accounts))
|
||||
copy(cpy, accounts)
|
||||
sortAccountsByPriorityAndLastUsed(cpy, false)
|
||||
seen[cpy[0].ID] = true
|
||||
}
|
||||
require.GreaterOrEqual(t, len(seen), 2, "identical sort keys should produce different orderings after shuffle")
|
||||
})
|
||||
|
||||
t.Run("different priorities still sorted correctly", func(t *testing.T) {
|
||||
now := time.Now()
|
||||
accounts := []*Account{
|
||||
{ID: 3, Priority: 3, LastUsedAt: &now},
|
||||
{ID: 1, Priority: 1, LastUsedAt: &now},
|
||||
{ID: 2, Priority: 2, LastUsedAt: &now},
|
||||
}
|
||||
|
||||
sortAccountsByPriorityAndLastUsed(accounts, false)
|
||||
require.Equal(t, int64(1), accounts[0].ID)
|
||||
require.Equal(t, int64(2), accounts[1].ID)
|
||||
require.Equal(t, int64(3), accounts[2].ID)
|
||||
})
|
||||
}
|
||||
@@ -275,4 +275,5 @@ var ProviderSet = wire.NewSet(
|
||||
NewUsageCache,
|
||||
NewTotpService,
|
||||
NewErrorPassthroughService,
|
||||
NewDigestSessionStore,
|
||||
)
|
||||
|
||||
@@ -47,13 +47,15 @@ services:
|
||||
|
||||
# =======================================================================
|
||||
# Database Configuration (PostgreSQL)
|
||||
# Default: uses local postgres container
|
||||
# External DB: set DATABASE_HOST and DATABASE_SSLMODE in .env
|
||||
# =======================================================================
|
||||
- DATABASE_HOST=postgres
|
||||
- DATABASE_PORT=5432
|
||||
- DATABASE_HOST=${DATABASE_HOST:-postgres}
|
||||
- DATABASE_PORT=${DATABASE_PORT:-5432}
|
||||
- DATABASE_USER=${POSTGRES_USER:-sub2api}
|
||||
- DATABASE_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
|
||||
- DATABASE_DBNAME=${POSTGRES_DB:-sub2api}
|
||||
- DATABASE_SSLMODE=disable
|
||||
- DATABASE_SSLMODE=${DATABASE_SSLMODE:-disable}
|
||||
|
||||
# =======================================================================
|
||||
# Redis Configuration
|
||||
@@ -128,8 +130,6 @@ services:
|
||||
# Examples: http://host:port, socks5://host:port
|
||||
- UPDATE_PROXY_URL=${UPDATE_PROXY_URL:-}
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
networks:
|
||||
@@ -141,35 +141,6 @@ services:
|
||||
retries: 3
|
||||
start_period: 30s
|
||||
|
||||
# ===========================================================================
|
||||
# PostgreSQL Database
|
||||
# ===========================================================================
|
||||
postgres:
|
||||
image: postgres:18-alpine
|
||||
container_name: sub2api-postgres
|
||||
restart: unless-stopped
|
||||
ulimits:
|
||||
nofile:
|
||||
soft: 100000
|
||||
hard: 100000
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
environment:
|
||||
- POSTGRES_USER=${POSTGRES_USER:-sub2api}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
|
||||
- POSTGRES_DB=${POSTGRES_DB:-sub2api}
|
||||
- TZ=${TZ:-Asia/Shanghai}
|
||||
networks:
|
||||
- sub2api-network
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-sub2api} -d ${POSTGRES_DB:-sub2api}"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
start_period: 10s
|
||||
# 注意:不暴露端口到宿主机,应用通过内部网络连接
|
||||
# 如需调试,可临时添加:ports: ["127.0.0.1:5433:5432"]
|
||||
|
||||
# ===========================================================================
|
||||
# Redis Cache
|
||||
# ===========================================================================
|
||||
@@ -209,8 +180,6 @@ services:
|
||||
volumes:
|
||||
sub2api_data:
|
||||
driver: local
|
||||
postgres_data:
|
||||
driver: local
|
||||
redis_data:
|
||||
driver: local
|
||||
|
||||
|
||||
BIN
frontend/public/wechat-qr.jpg
Normal file
BIN
frontend/public/wechat-qr.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 148 KiB |
@@ -376,7 +376,6 @@ export interface PlatformAvailability {
|
||||
total_accounts: number
|
||||
available_count: number
|
||||
rate_limit_count: number
|
||||
scope_rate_limit_count?: Record<string, number>
|
||||
error_count: number
|
||||
}
|
||||
|
||||
@@ -387,7 +386,6 @@ export interface GroupAvailability {
|
||||
total_accounts: number
|
||||
available_count: number
|
||||
rate_limit_count: number
|
||||
scope_rate_limit_count?: Record<string, number>
|
||||
error_count: number
|
||||
}
|
||||
|
||||
@@ -402,7 +400,6 @@ export interface AccountAvailability {
|
||||
is_rate_limited: boolean
|
||||
rate_limit_reset_at?: string
|
||||
rate_limit_remaining_sec?: number
|
||||
scope_rate_limits?: Record<string, number>
|
||||
is_overloaded: boolean
|
||||
overload_until?: string
|
||||
overload_remaining_sec?: number
|
||||
|
||||
@@ -76,26 +76,6 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Scope Rate Limit Indicators (Antigravity) -->
|
||||
<template v-if="activeScopeRateLimits.length > 0">
|
||||
<div v-for="item in activeScopeRateLimits" :key="item.scope" class="group relative">
|
||||
<span
|
||||
class="inline-flex items-center gap-1 rounded bg-orange-100 px-1.5 py-0.5 text-xs font-medium text-orange-700 dark:bg-orange-900/30 dark:text-orange-400"
|
||||
>
|
||||
<Icon name="exclamationTriangle" size="xs" :stroke-width="2" />
|
||||
{{ formatScopeName(item.scope) }}
|
||||
</span>
|
||||
<!-- Tooltip -->
|
||||
<div
|
||||
class="pointer-events-none absolute bottom-full left-1/2 z-50 mb-2 -translate-x-1/2 whitespace-nowrap rounded bg-gray-900 px-2 py-1 text-xs text-white opacity-0 transition-opacity group-hover:opacity-100 dark:bg-gray-700"
|
||||
>
|
||||
{{ t('admin.accounts.status.scopeRateLimitedUntil', { scope: formatScopeName(item.scope), time: formatTime(item.reset_at) }) }}
|
||||
<div
|
||||
class="absolute left-1/2 top-full -translate-x-1/2 border-4 border-transparent border-t-gray-900 dark:border-t-gray-700" ></div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<!-- Model Rate Limit Indicators (Antigravity OAuth Smart Retry) -->
|
||||
<template v-if="activeModelRateLimits.length > 0">
|
||||
<div v-for="item in activeModelRateLimits" :key="item.model" class="group relative">
|
||||
@@ -160,16 +140,6 @@ const isRateLimited = computed(() => {
|
||||
return new Date(props.account.rate_limit_reset_at) > new Date()
|
||||
})
|
||||
|
||||
// Computed: active scope rate limits (Antigravity)
|
||||
const activeScopeRateLimits = computed(() => {
|
||||
const scopeLimits = props.account.scope_rate_limits
|
||||
if (!scopeLimits) return []
|
||||
const now = new Date()
|
||||
return Object.entries(scopeLimits)
|
||||
.filter(([, info]) => new Date(info.reset_at) > now)
|
||||
.map(([scope, info]) => ({ scope, reset_at: info.reset_at }))
|
||||
})
|
||||
|
||||
// Computed: active model rate limits (Antigravity OAuth Smart Retry)
|
||||
const activeModelRateLimits = computed(() => {
|
||||
const modelLimits = (props.account.extra as Record<string, unknown> | undefined)?.model_rate_limits as
|
||||
|
||||
@@ -1038,10 +1038,7 @@
|
||||
</div>
|
||||
|
||||
<!-- Custom Error Codes Section -->
|
||||
<div
|
||||
v-if="form.platform !== 'gemini'"
|
||||
class="border-t border-gray-200 pt-4 dark:border-dark-600"
|
||||
>
|
||||
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<div class="mb-3 flex items-center justify-between">
|
||||
<div>
|
||||
<label class="input-label mb-0">{{ t('admin.accounts.customErrorCodes') }}</label>
|
||||
|
||||
104
frontend/src/components/common/WechatServiceButton.vue
Normal file
104
frontend/src/components/common/WechatServiceButton.vue
Normal file
@@ -0,0 +1,104 @@
|
||||
<template>
|
||||
<!-- 悬浮按钮 - 使用主题色 -->
|
||||
<button
|
||||
@click="showModal = true"
|
||||
class="fixed bottom-6 right-6 z-50 flex items-center gap-2 rounded-full bg-gradient-to-r from-primary-500 to-primary-600 px-4 py-3 text-white shadow-lg shadow-primary-500/25 transition-all hover:from-primary-600 hover:to-primary-700 hover:shadow-xl hover:shadow-primary-500/30"
|
||||
>
|
||||
<svg class="h-5 w-5" viewBox="0 0 24 24" fill="currentColor">
|
||||
<path d="M8.691 2.188C3.891 2.188 0 5.476 0 9.53c0 2.212 1.17 4.203 3.002 5.55a.59.59 0 01.213.665l-.39 1.48c-.019.07-.048.141-.048.213 0 .163.13.295.29.295a.328.328 0 00.186-.059l2.114-1.225a.87.87 0 01.415-.106.807.807 0 01.213.026 10.07 10.07 0 002.696.37c.262 0 .52-.011.776-.028a5.91 5.91 0 01-.193-1.479c0-3.644 3.374-6.6 7.536-6.6.262 0 .52.011.776.028-.628-3.513-4.27-6.472-8.885-6.472zM5.785 5.97a1.1 1.1 0 110 2.2 1.1 1.1 0 010-2.2zm5.813 0a1.1 1.1 0 110 2.2 1.1 1.1 0 010-2.2zm5.192 2.642c-3.703 0-6.71 2.567-6.71 5.73 0 3.163 3.007 5.73 6.71 5.73a7.9 7.9 0 002.126-.288.644.644 0 01.17-.022.69.69 0 01.329.085l1.672.97a.262.262 0 00.147.046c.128 0 .23-.104.23-.233a.403.403 0 00-.038-.168l-.309-1.17a.468.468 0 01.168-.527c1.449-1.065 2.374-2.643 2.374-4.423 0-3.163-3.007-5.73-6.71-5.73h-.159zm-2.434 3.34a.88.88 0 110 1.76.88.88 0 010-1.76zm4.868 0a.88.88 0 110 1.76.88.88 0 010-1.76z"/>
|
||||
</svg>
|
||||
<span class="text-sm font-medium">客服</span>
|
||||
</button>
|
||||
|
||||
<!-- 弹窗 -->
|
||||
<Teleport to="body">
|
||||
<Transition name="fade">
|
||||
<div
|
||||
v-if="showModal"
|
||||
class="fixed inset-0 z-[100] flex items-center justify-center bg-black/50 p-4 backdrop-blur-sm"
|
||||
@click.self="showModal = false"
|
||||
>
|
||||
<Transition name="scale">
|
||||
<div
|
||||
v-if="showModal"
|
||||
class="relative w-full max-w-sm rounded-2xl bg-white p-6 shadow-2xl dark:bg-dark-700"
|
||||
>
|
||||
<!-- 关闭按钮 -->
|
||||
<button
|
||||
@click="showModal = false"
|
||||
class="absolute right-4 top-4 text-gray-400 transition-colors hover:text-gray-600 dark:text-dark-400 dark:hover:text-dark-200"
|
||||
>
|
||||
<svg class="h-5 w-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M6 18L18 6M6 6l12 12" />
|
||||
</svg>
|
||||
</button>
|
||||
|
||||
<!-- 标题 -->
|
||||
<div class="mb-4 flex items-center gap-3">
|
||||
<div class="flex h-10 w-10 items-center justify-center rounded-full bg-gradient-to-br from-primary-500 to-primary-600">
|
||||
<svg class="h-6 w-6 text-white" viewBox="0 0 24 24" fill="currentColor">
|
||||
<path d="M8.691 2.188C3.891 2.188 0 5.476 0 9.53c0 2.212 1.17 4.203 3.002 5.55a.59.59 0 01.213.665l-.39 1.48c-.019.07-.048.141-.048.213 0 .163.13.295.29.295a.328.328 0 00.186-.059l2.114-1.225a.87.87 0 01.415-.106.807.807 0 01.213.026 10.07 10.07 0 002.696.37c.262 0 .52-.011.776-.028a5.91 5.91 0 01-.193-1.479c0-3.644 3.374-6.6 7.536-6.6.262 0 .52.011.776.028-.628-3.513-4.27-6.472-8.885-6.472zM5.785 5.97a1.1 1.1 0 110 2.2 1.1 1.1 0 010-2.2zm5.813 0a1.1 1.1 0 110 2.2 1.1 1.1 0 010-2.2zm5.192 2.642c-3.703 0-6.71 2.567-6.71 5.73 0 3.163 3.007 5.73 6.71 5.73a7.9 7.9 0 002.126-.288.644.644 0 01.17-.022.69.69 0 01.329.085l1.672.97a.262.262 0 00.147.046c.128 0 .23-.104.23-.233a.403.403 0 00-.038-.168l-.309-1.17a.468.468 0 01.168-.527c1.449-1.065 2.374-2.643 2.374-4.423 0-3.163-3.007-5.73-6.71-5.73h-.159zm-2.434 3.34a.88.88 0 110 1.76.88.88 0 010-1.76zm4.868 0a.88.88 0 110 1.76.88.88 0 010-1.76z"/>
|
||||
</svg>
|
||||
</div>
|
||||
<div>
|
||||
<h3 class="text-lg font-semibold text-gray-900 dark:text-white">联系客服</h3>
|
||||
<p class="text-sm text-gray-500 dark:text-dark-400">扫码添加好友</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 二维码卡片 -->
|
||||
<div class="mb-4 overflow-hidden rounded-xl border border-primary-100 bg-gradient-to-br from-primary-50 to-white p-3 dark:border-primary-800/30 dark:from-primary-900/10 dark:to-dark-800">
|
||||
<img
|
||||
src="/wechat-qr.jpg"
|
||||
alt="微信二维码"
|
||||
class="w-full rounded-lg"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<!-- 提示文字 -->
|
||||
<div class="text-center">
|
||||
<p class="mb-2 text-sm font-medium text-primary-600 dark:text-primary-400">
|
||||
微信扫码添加客服
|
||||
</p>
|
||||
<p class="flex items-center justify-center gap-1 text-xs text-gray-500 dark:text-dark-400">
|
||||
<svg class="h-4 w-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 8v4l3 3m6-3a9 9 0 11-18 0 9 9 0 0118 0z" />
|
||||
</svg>
|
||||
工作时间:周一至周五 9:00-18:00
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</Transition>
|
||||
</div>
|
||||
</Transition>
|
||||
</Teleport>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref } from 'vue'
|
||||
|
||||
const showModal = ref(false)
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.fade-enter-active,
|
||||
.fade-leave-active {
|
||||
transition: opacity 0.2s ease;
|
||||
}
|
||||
|
||||
.fade-enter-from,
|
||||
.fade-leave-to {
|
||||
opacity: 0;
|
||||
}
|
||||
|
||||
.scale-enter-active,
|
||||
.scale-leave-active {
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.scale-enter-from,
|
||||
.scale-leave-to {
|
||||
opacity: 0;
|
||||
transform: scale(0.95);
|
||||
}
|
||||
</style>
|
||||
@@ -121,23 +121,6 @@
|
||||
<Icon name="key" size="sm" />
|
||||
{{ t('nav.apiKeys') }}
|
||||
</router-link>
|
||||
|
||||
<a
|
||||
href="https://github.com/Wei-Shaw/sub2api"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
@click="closeDropdown"
|
||||
class="dropdown-item"
|
||||
>
|
||||
<svg class="h-4 w-4" fill="currentColor" viewBox="0 0 24 24">
|
||||
<path
|
||||
fill-rule="evenodd"
|
||||
clip-rule="evenodd"
|
||||
d="M12 2C6.477 2 2 6.477 2 12c0 4.42 2.865 8.17 6.839 9.49.5.092.682-.217.682-.482 0-.237-.008-.866-.013-1.7-2.782.604-3.369-1.34-3.369-1.34-.454-1.156-1.11-1.464-1.11-1.464-.908-.62.069-.608.069-.608 1.003.07 1.531 1.03 1.531 1.03.892 1.529 2.341 1.087 2.91.831.092-.646.35-1.086.636-1.336-2.22-.253-4.555-1.11-4.555-4.943 0-1.091.39-1.984 1.029-2.683-.103-.253-.446-1.27.098-2.647 0 0 .84-.269 2.75 1.025A9.578 9.578 0 0112 6.836c.85.004 1.705.114 2.504.336 1.909-1.294 2.747-1.025 2.747-1.025.546 1.377.203 2.394.1 2.647.64.699 1.028 1.592 1.028 2.683 0 3.842-2.339 4.687-4.566 4.935.359.309.678.919.678 1.852 0 1.336-.012 2.415-.012 2.743 0 .267.18.578.688.48C19.138 20.167 22 16.418 22 12c0-5.523-4.477-10-10-10z"
|
||||
/>
|
||||
</svg>
|
||||
{{ t('nav.github') }}
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<!-- Contact Support (only show if configured) -->
|
||||
|
||||
@@ -1356,7 +1356,6 @@ export default {
|
||||
overloaded: 'Overloaded',
|
||||
tempUnschedulable: 'Temp Unschedulable',
|
||||
rateLimitedUntil: 'Rate limited until {time}',
|
||||
scopeRateLimitedUntil: '{scope} rate limited until {time}',
|
||||
modelRateLimitedUntil: '{model} rate limited until {time}',
|
||||
overloadedUntil: 'Overloaded until {time}',
|
||||
viewTempUnschedDetails: 'View temp unschedulable details'
|
||||
@@ -3059,7 +3058,6 @@ export default {
|
||||
empty: 'No data',
|
||||
queued: 'Queue {count}',
|
||||
rateLimited: 'Rate-limited {count}',
|
||||
scopeRateLimitedTooltip: '{scope} rate-limited ({count} accounts)',
|
||||
errorAccounts: 'Errors {count}',
|
||||
loadFailed: 'Failed to load concurrency data'
|
||||
},
|
||||
|
||||
@@ -1492,7 +1492,6 @@ export default {
|
||||
overloaded: '过载中',
|
||||
tempUnschedulable: '临时不可调度',
|
||||
rateLimitedUntil: '限流中,重置时间:{time}',
|
||||
scopeRateLimitedUntil: '{scope} 限流中,重置时间:{time}',
|
||||
modelRateLimitedUntil: '{model} 限流至 {time}',
|
||||
overloadedUntil: '负载过重,重置时间:{time}',
|
||||
viewTempUnschedDetails: '查看临时不可调度详情'
|
||||
@@ -3232,7 +3231,6 @@ export default {
|
||||
empty: '暂无数据',
|
||||
queued: '队列 {count}',
|
||||
rateLimited: '限流 {count}',
|
||||
scopeRateLimitedTooltip: '{scope} 限流中 ({count} 个账号)',
|
||||
errorAccounts: '异常 {count}',
|
||||
loadFailed: '加载并发数据失败'
|
||||
},
|
||||
|
||||
@@ -591,9 +591,6 @@ export interface Account {
|
||||
temp_unschedulable_until: string | null
|
||||
temp_unschedulable_reason: string | null
|
||||
|
||||
// Antigravity scope 级限流状态
|
||||
scope_rate_limits?: Record<string, { reset_at: string; remaining_sec: number }>
|
||||
|
||||
// Session window fields (5-hour window)
|
||||
session_window_start: string | null
|
||||
session_window_end: string | null
|
||||
|
||||
@@ -122,8 +122,11 @@
|
||||
>
|
||||
{{ siteName }}
|
||||
</h1>
|
||||
<p class="mb-8 text-lg text-gray-600 dark:text-dark-300 md:text-xl">
|
||||
{{ siteSubtitle }}
|
||||
<p class="mb-3 text-xl font-semibold text-primary-600 dark:text-primary-400 md:text-2xl">
|
||||
{{ t('home.heroSubtitle') }}
|
||||
</p>
|
||||
<p class="mb-8 text-base text-gray-600 dark:text-dark-300 md:text-lg">
|
||||
{{ t('home.heroDescription') }}
|
||||
</p>
|
||||
|
||||
<!-- CTA Button -->
|
||||
@@ -177,7 +180,7 @@
|
||||
</div>
|
||||
|
||||
<!-- Feature Tags - Centered -->
|
||||
<div class="mb-12 flex flex-wrap items-center justify-center gap-4 md:gap-6">
|
||||
<div class="mb-16 flex flex-wrap items-center justify-center gap-4 md:gap-6">
|
||||
<div
|
||||
class="inline-flex items-center gap-2.5 rounded-full border border-gray-200/50 bg-white/80 px-5 py-2.5 shadow-sm backdrop-blur-sm dark:border-dark-700/50 dark:bg-dark-800/80"
|
||||
>
|
||||
@@ -204,6 +207,63 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Pain Points Section -->
|
||||
<div class="mb-16">
|
||||
<h2 class="mb-8 text-center text-2xl font-bold text-gray-900 dark:text-white md:text-3xl">
|
||||
{{ t('home.painPoints.title') }}
|
||||
</h2>
|
||||
<div class="grid gap-4 sm:grid-cols-2 lg:grid-cols-4">
|
||||
<!-- Pain Point 1: Expensive -->
|
||||
<div class="rounded-xl border border-red-200/50 bg-red-50/50 p-5 dark:border-red-900/30 dark:bg-red-950/20">
|
||||
<div class="mb-3 flex h-10 w-10 items-center justify-center rounded-lg bg-red-100 dark:bg-red-900/30">
|
||||
<svg class="h-5 w-5 text-red-500" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M12 8c-1.657 0-3 .895-3 2s1.343 2 3 2 3 .895 3 2-1.343 2-3 2m0-8c1.11 0 2.08.402 2.599 1M12 8V7m0 1v8m0 0v1m0-1c-1.11 0-2.08-.402-2.599-1M21 12a9 9 0 11-18 0 9 9 0 0118 0z" />
|
||||
</svg>
|
||||
</div>
|
||||
<h3 class="mb-1.5 font-semibold text-gray-900 dark:text-white">{{ t('home.painPoints.items.expensive.title') }}</h3>
|
||||
<p class="text-sm text-gray-600 dark:text-dark-400">{{ t('home.painPoints.items.expensive.desc') }}</p>
|
||||
</div>
|
||||
<!-- Pain Point 2: Complex -->
|
||||
<div class="rounded-xl border border-orange-200/50 bg-orange-50/50 p-5 dark:border-orange-900/30 dark:bg-orange-950/20">
|
||||
<div class="mb-3 flex h-10 w-10 items-center justify-center rounded-lg bg-orange-100 dark:bg-orange-900/30">
|
||||
<svg class="h-5 w-5 text-orange-500" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M19 11H5m14 0a2 2 0 012 2v6a2 2 0 01-2 2H5a2 2 0 01-2-2v-6a2 2 0 012-2m14 0V9a2 2 0 00-2-2M5 11V9a2 2 0 012-2m0 0V5a2 2 0 012-2h6a2 2 0 012 2v2M7 7h10" />
|
||||
</svg>
|
||||
</div>
|
||||
<h3 class="mb-1.5 font-semibold text-gray-900 dark:text-white">{{ t('home.painPoints.items.complex.title') }}</h3>
|
||||
<p class="text-sm text-gray-600 dark:text-dark-400">{{ t('home.painPoints.items.complex.desc') }}</p>
|
||||
</div>
|
||||
<!-- Pain Point 3: Unstable -->
|
||||
<div class="rounded-xl border border-yellow-200/50 bg-yellow-50/50 p-5 dark:border-yellow-900/30 dark:bg-yellow-950/20">
|
||||
<div class="mb-3 flex h-10 w-10 items-center justify-center rounded-lg bg-yellow-100 dark:bg-yellow-900/30">
|
||||
<svg class="h-5 w-5 text-yellow-600" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z" />
|
||||
</svg>
|
||||
</div>
|
||||
<h3 class="mb-1.5 font-semibold text-gray-900 dark:text-white">{{ t('home.painPoints.items.unstable.title') }}</h3>
|
||||
<p class="text-sm text-gray-600 dark:text-dark-400">{{ t('home.painPoints.items.unstable.desc') }}</p>
|
||||
</div>
|
||||
<!-- Pain Point 4: No Control -->
|
||||
<div class="rounded-xl border border-gray-200/50 bg-gray-50/50 p-5 dark:border-dark-700/50 dark:bg-dark-800/50">
|
||||
<div class="mb-3 flex h-10 w-10 items-center justify-center rounded-lg bg-gray-100 dark:bg-dark-700">
|
||||
<svg class="h-5 w-5 text-gray-500" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M18.364 18.364A9 9 0 005.636 5.636m12.728 12.728A9 9 0 015.636 5.636m12.728 12.728L5.636 5.636" />
|
||||
</svg>
|
||||
</div>
|
||||
<h3 class="mb-1.5 font-semibold text-gray-900 dark:text-white">{{ t('home.painPoints.items.noControl.title') }}</h3>
|
||||
<p class="text-sm text-gray-600 dark:text-dark-400">{{ t('home.painPoints.items.noControl.desc') }}</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Solutions Section Title -->
|
||||
<div class="mb-8 text-center">
|
||||
<h2 class="mb-2 text-2xl font-bold text-gray-900 dark:text-white md:text-3xl">
|
||||
{{ t('home.solutions.title') }}
|
||||
</h2>
|
||||
<p class="text-gray-600 dark:text-dark-400">{{ t('home.solutions.subtitle') }}</p>
|
||||
</div>
|
||||
|
||||
<!-- Features Grid -->
|
||||
<div class="mb-12 grid gap-6 md:grid-cols-3">
|
||||
<!-- Feature 1: Unified Gateway -->
|
||||
@@ -369,6 +429,77 @@
|
||||
>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Comparison Table -->
|
||||
<div class="mb-16">
|
||||
<h2 class="mb-8 text-center text-2xl font-bold text-gray-900 dark:text-white md:text-3xl">
|
||||
{{ t('home.comparison.title') }}
|
||||
</h2>
|
||||
<div class="overflow-x-auto">
|
||||
<table class="w-full rounded-xl border border-gray-200/50 bg-white/60 backdrop-blur-sm dark:border-dark-700/50 dark:bg-dark-800/60">
|
||||
<thead>
|
||||
<tr class="border-b border-gray-200/50 dark:border-dark-700/50">
|
||||
<th class="px-6 py-4 text-left text-sm font-semibold text-gray-900 dark:text-white">{{ t('home.comparison.headers.feature') }}</th>
|
||||
<th class="px-6 py-4 text-center text-sm font-semibold text-gray-500 dark:text-dark-400">{{ t('home.comparison.headers.official') }}</th>
|
||||
<th class="px-6 py-4 text-center text-sm font-semibold text-primary-600 dark:text-primary-400">{{ t('home.comparison.headers.us') }}</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody class="divide-y divide-gray-200/50 dark:divide-dark-700/50">
|
||||
<tr>
|
||||
<td class="px-6 py-4 text-sm font-medium text-gray-900 dark:text-white">{{ t('home.comparison.items.pricing.feature') }}</td>
|
||||
<td class="px-6 py-4 text-center text-sm text-gray-500 dark:text-dark-400">{{ t('home.comparison.items.pricing.official') }}</td>
|
||||
<td class="px-6 py-4 text-center text-sm font-medium text-primary-600 dark:text-primary-400">{{ t('home.comparison.items.pricing.us') }}</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="px-6 py-4 text-sm font-medium text-gray-900 dark:text-white">{{ t('home.comparison.items.models.feature') }}</td>
|
||||
<td class="px-6 py-4 text-center text-sm text-gray-500 dark:text-dark-400">{{ t('home.comparison.items.models.official') }}</td>
|
||||
<td class="px-6 py-4 text-center text-sm font-medium text-primary-600 dark:text-primary-400">{{ t('home.comparison.items.models.us') }}</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="px-6 py-4 text-sm font-medium text-gray-900 dark:text-white">{{ t('home.comparison.items.management.feature') }}</td>
|
||||
<td class="px-6 py-4 text-center text-sm text-gray-500 dark:text-dark-400">{{ t('home.comparison.items.management.official') }}</td>
|
||||
<td class="px-6 py-4 text-center text-sm font-medium text-primary-600 dark:text-primary-400">{{ t('home.comparison.items.management.us') }}</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="px-6 py-4 text-sm font-medium text-gray-900 dark:text-white">{{ t('home.comparison.items.stability.feature') }}</td>
|
||||
<td class="px-6 py-4 text-center text-sm text-gray-500 dark:text-dark-400">{{ t('home.comparison.items.stability.official') }}</td>
|
||||
<td class="px-6 py-4 text-center text-sm font-medium text-primary-600 dark:text-primary-400">{{ t('home.comparison.items.stability.us') }}</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="px-6 py-4 text-sm font-medium text-gray-900 dark:text-white">{{ t('home.comparison.items.control.feature') }}</td>
|
||||
<td class="px-6 py-4 text-center text-sm text-gray-500 dark:text-dark-400">{{ t('home.comparison.items.control.official') }}</td>
|
||||
<td class="px-6 py-4 text-center text-sm font-medium text-primary-600 dark:text-primary-400">{{ t('home.comparison.items.control.us') }}</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- CTA Section -->
|
||||
<div class="mb-8 rounded-2xl bg-gradient-to-r from-primary-500 to-primary-600 p-8 text-center shadow-xl shadow-primary-500/20 md:p-12">
|
||||
<h2 class="mb-3 text-2xl font-bold text-white md:text-3xl">
|
||||
{{ t('home.cta.title') }}
|
||||
</h2>
|
||||
<p class="mb-6 text-primary-100">
|
||||
{{ t('home.cta.description') }}
|
||||
</p>
|
||||
<router-link
|
||||
v-if="!isAuthenticated"
|
||||
to="/register"
|
||||
class="inline-flex items-center gap-2 rounded-full bg-white px-8 py-3 font-semibold text-primary-600 shadow-lg transition-all hover:bg-gray-50 hover:shadow-xl"
|
||||
>
|
||||
{{ t('home.cta.button') }}
|
||||
<Icon name="arrowRight" size="md" :stroke-width="2" />
|
||||
</router-link>
|
||||
<router-link
|
||||
v-else
|
||||
:to="dashboardPath"
|
||||
class="inline-flex items-center gap-2 rounded-full bg-white px-8 py-3 font-semibold text-primary-600 shadow-lg transition-all hover:bg-gray-50 hover:shadow-xl"
|
||||
>
|
||||
{{ t('home.goToDashboard') }}
|
||||
<Icon name="arrowRight" size="md" :stroke-width="2" />
|
||||
</router-link>
|
||||
</div>
|
||||
</div>
|
||||
</main>
|
||||
|
||||
@@ -380,27 +511,20 @@
|
||||
<p class="text-sm text-gray-500 dark:text-dark-400">
|
||||
© {{ currentYear }} {{ siteName }}. {{ t('home.footer.allRightsReserved') }}
|
||||
</p>
|
||||
<div class="flex items-center gap-4">
|
||||
<a
|
||||
v-if="docUrl"
|
||||
:href="docUrl"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
class="text-sm text-gray-500 transition-colors hover:text-gray-700 dark:text-dark-400 dark:hover:text-white"
|
||||
>
|
||||
{{ t('home.docs') }}
|
||||
</a>
|
||||
<a
|
||||
:href="githubUrl"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
class="text-sm text-gray-500 transition-colors hover:text-gray-700 dark:text-dark-400 dark:hover:text-white"
|
||||
>
|
||||
GitHub
|
||||
</a>
|
||||
</div>
|
||||
<a
|
||||
v-if="docUrl"
|
||||
:href="docUrl"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
class="text-sm text-gray-500 transition-colors hover:text-gray-700 dark:text-dark-400 dark:hover:text-white"
|
||||
>
|
||||
{{ t('home.docs') }}
|
||||
</a>
|
||||
</div>
|
||||
</footer>
|
||||
|
||||
<!-- 微信客服悬浮按钮 -->
|
||||
<WechatServiceButton />
|
||||
</div>
|
||||
</template>
|
||||
|
||||
@@ -410,6 +534,7 @@ import { useI18n } from 'vue-i18n'
|
||||
import { useAuthStore, useAppStore } from '@/stores'
|
||||
import LocaleSwitcher from '@/components/common/LocaleSwitcher.vue'
|
||||
import Icon from '@/components/icons/Icon.vue'
|
||||
import WechatServiceButton from '@/components/common/WechatServiceButton.vue'
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
@@ -419,7 +544,6 @@ const appStore = useAppStore()
|
||||
// Site settings - directly from appStore (already initialized from injected config)
|
||||
const siteName = computed(() => appStore.cachedPublicSettings?.site_name || appStore.siteName || 'Sub2API')
|
||||
const siteLogo = computed(() => appStore.cachedPublicSettings?.site_logo || appStore.siteLogo || '')
|
||||
const siteSubtitle = computed(() => appStore.cachedPublicSettings?.site_subtitle || 'AI API Gateway Platform')
|
||||
const docUrl = computed(() => appStore.cachedPublicSettings?.doc_url || appStore.docUrl || '')
|
||||
const homeContent = computed(() => appStore.cachedPublicSettings?.home_content || '')
|
||||
|
||||
@@ -432,9 +556,6 @@ const isHomeContentUrl = computed(() => {
|
||||
// Theme
|
||||
const isDark = ref(document.documentElement.classList.contains('dark'))
|
||||
|
||||
// GitHub URL
|
||||
const githubUrl = 'https://github.com/Wei-Shaw/sub2api'
|
||||
|
||||
// Auth state
|
||||
const isAuthenticated = computed(() => authStore.isAuthenticated)
|
||||
const isAdmin = computed(() => authStore.isAdmin)
|
||||
|
||||
@@ -56,7 +56,6 @@ interface SummaryRow {
|
||||
total_accounts: number
|
||||
available_accounts: number
|
||||
rate_limited_accounts: number
|
||||
scope_rate_limit_count?: Record<string, number>
|
||||
error_accounts: number
|
||||
// 并发统计
|
||||
total_concurrency: number
|
||||
@@ -122,7 +121,6 @@ const platformRows = computed((): SummaryRow[] => {
|
||||
total_accounts: totalAccounts,
|
||||
available_accounts: availableAccounts,
|
||||
rate_limited_accounts: safeNumber(avail.rate_limit_count),
|
||||
scope_rate_limit_count: avail.scope_rate_limit_count,
|
||||
error_accounts: safeNumber(avail.error_count),
|
||||
total_concurrency: totalConcurrency,
|
||||
used_concurrency: usedConcurrency,
|
||||
@@ -162,7 +160,6 @@ const groupRows = computed((): SummaryRow[] => {
|
||||
total_accounts: totalAccounts,
|
||||
available_accounts: availableAccounts,
|
||||
rate_limited_accounts: safeNumber(avail.rate_limit_count),
|
||||
scope_rate_limit_count: avail.scope_rate_limit_count,
|
||||
error_accounts: safeNumber(avail.error_count),
|
||||
total_concurrency: totalConcurrency,
|
||||
used_concurrency: usedConcurrency,
|
||||
@@ -329,15 +326,6 @@ function formatDuration(seconds: number): string {
|
||||
return `${hours}h`
|
||||
}
|
||||
|
||||
function formatScopeName(scope: string): string {
|
||||
const names: Record<string, string> = {
|
||||
claude: 'Claude',
|
||||
gemini_text: 'Gemini',
|
||||
gemini_image: 'Image'
|
||||
}
|
||||
return names[scope] || scope
|
||||
}
|
||||
|
||||
watch(
|
||||
() => realtimeEnabled.value,
|
||||
async (enabled) => {
|
||||
@@ -505,18 +493,6 @@ watch(
|
||||
{{ t('admin.ops.concurrency.rateLimited', { count: row.rate_limited_accounts }) }}
|
||||
</span>
|
||||
|
||||
<!-- Scope 限流 (仅 Antigravity) -->
|
||||
<template v-if="row.scope_rate_limit_count && Object.keys(row.scope_rate_limit_count).length > 0">
|
||||
<span
|
||||
v-for="(count, scope) in row.scope_rate_limit_count"
|
||||
:key="scope"
|
||||
class="rounded-full bg-orange-100 px-1.5 py-0.5 font-semibold text-orange-700 dark:bg-orange-900/30 dark:text-orange-400"
|
||||
:title="t('admin.ops.concurrency.scopeRateLimitedTooltip', { scope, count })"
|
||||
>
|
||||
{{ formatScopeName(scope as string) }} {{ count }}
|
||||
</span>
|
||||
</template>
|
||||
|
||||
<!-- 异常账号 -->
|
||||
<span
|
||||
v-if="row.error_accounts > 0"
|
||||
|
||||
127
stress_test_gemini_session.sh
Normal file
127
stress_test_gemini_session.sh
Normal file
@@ -0,0 +1,127 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Gemini 粘性会话压力测试脚本
|
||||
# 测试目标:验证不同会话分配不同账号,同一会话保持同一账号
|
||||
|
||||
BASE_URL="http://host.clicodeplus.com:8080"
|
||||
API_KEY="sk-32ad0a3197e528c840ea84f0dc6b2056dd3fead03526b5c605a60709bd408f7e"
|
||||
MODEL="gemini-2.5-flash"
|
||||
|
||||
# 创建临时目录存放结果
|
||||
RESULT_DIR="/tmp/gemini_stress_test_$(date +%s)"
|
||||
mkdir -p "$RESULT_DIR"
|
||||
|
||||
echo "=========================================="
|
||||
echo "Gemini 粘性会话压力测试"
|
||||
echo "结果目录: $RESULT_DIR"
|
||||
echo "=========================================="
|
||||
|
||||
# 函数:发送请求并记录
|
||||
send_request() {
|
||||
local session_id=$1
|
||||
local round=$2
|
||||
local system_prompt=$3
|
||||
local contents=$4
|
||||
local output_file="$RESULT_DIR/session_${session_id}_round_${round}.json"
|
||||
|
||||
local request_body=$(cat <<EOF
|
||||
{
|
||||
"systemInstruction": {
|
||||
"parts": [{"text": "$system_prompt"}]
|
||||
},
|
||||
"contents": $contents
|
||||
}
|
||||
EOF
|
||||
)
|
||||
|
||||
curl -s -X POST "${BASE_URL}/v1beta/models/${MODEL}:generateContent" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "x-goog-api-key: ${API_KEY}" \
|
||||
-d "$request_body" > "$output_file" 2>&1
|
||||
|
||||
echo "[Session $session_id Round $round] 完成"
|
||||
}
|
||||
|
||||
# 会话1:数学计算器(累加序列)
|
||||
run_session_1() {
|
||||
local sys_prompt="你是一个数学计算器,只返回计算结果数字,不要任何解释"
|
||||
|
||||
# Round 1: 1+1=?
|
||||
send_request 1 1 "$sys_prompt" '[{"role":"user","parts":[{"text":"1+1=?"}]}]'
|
||||
|
||||
# Round 2: 继续 2+2=?(累加历史)
|
||||
send_request 1 2 "$sys_prompt" '[{"role":"user","parts":[{"text":"1+1=?"}]},{"role":"model","parts":[{"text":"2"}]},{"role":"user","parts":[{"text":"2+2=?"}]}]'
|
||||
|
||||
# Round 3: 继续 3+3=?
|
||||
send_request 1 3 "$sys_prompt" '[{"role":"user","parts":[{"text":"1+1=?"}]},{"role":"model","parts":[{"text":"2"}]},{"role":"user","parts":[{"text":"2+2=?"}]},{"role":"model","parts":[{"text":"4"}]},{"role":"user","parts":[{"text":"3+3=?"}]}]'
|
||||
|
||||
# Round 4: 批量计算 10+10, 20+20, 30+30
|
||||
send_request 1 4 "$sys_prompt" '[{"role":"user","parts":[{"text":"1+1=?"}]},{"role":"model","parts":[{"text":"2"}]},{"role":"user","parts":[{"text":"2+2=?"}]},{"role":"model","parts":[{"text":"4"}]},{"role":"user","parts":[{"text":"3+3=?"}]},{"role":"model","parts":[{"text":"6"}]},{"role":"user","parts":[{"text":"计算: 10+10=? 20+20=? 30+30=?"}]}]'
|
||||
}
|
||||
|
||||
# 会话2:英文翻译器(不同系统提示词 = 不同会话)
|
||||
run_session_2() {
|
||||
local sys_prompt="你是一个英文翻译器,将中文翻译成英文,只返回翻译结果"
|
||||
|
||||
send_request 2 1 "$sys_prompt" '[{"role":"user","parts":[{"text":"你好"}]}]'
|
||||
send_request 2 2 "$sys_prompt" '[{"role":"user","parts":[{"text":"你好"}]},{"role":"model","parts":[{"text":"Hello"}]},{"role":"user","parts":[{"text":"世界"}]}]'
|
||||
send_request 2 3 "$sys_prompt" '[{"role":"user","parts":[{"text":"你好"}]},{"role":"model","parts":[{"text":"Hello"}]},{"role":"user","parts":[{"text":"世界"}]},{"role":"model","parts":[{"text":"World"}]},{"role":"user","parts":[{"text":"早上好"}]}]'
|
||||
}
|
||||
|
||||
# 会话3:日文翻译器
|
||||
run_session_3() {
|
||||
local sys_prompt="你是一个日文翻译器,将中文翻译成日文,只返回翻译结果"
|
||||
|
||||
send_request 3 1 "$sys_prompt" '[{"role":"user","parts":[{"text":"你好"}]}]'
|
||||
send_request 3 2 "$sys_prompt" '[{"role":"user","parts":[{"text":"你好"}]},{"role":"model","parts":[{"text":"こんにちは"}]},{"role":"user","parts":[{"text":"谢谢"}]}]'
|
||||
send_request 3 3 "$sys_prompt" '[{"role":"user","parts":[{"text":"你好"}]},{"role":"model","parts":[{"text":"こんにちは"}]},{"role":"user","parts":[{"text":"谢谢"}]},{"role":"model","parts":[{"text":"ありがとう"}]},{"role":"user","parts":[{"text":"再见"}]}]'
|
||||
}
|
||||
|
||||
# 会话4:乘法计算器(另一个数学会话,但系统提示词不同)
|
||||
run_session_4() {
|
||||
local sys_prompt="你是一个乘法专用计算器,只计算乘法,返回数字结果"
|
||||
|
||||
send_request 4 1 "$sys_prompt" '[{"role":"user","parts":[{"text":"2*3=?"}]}]'
|
||||
send_request 4 2 "$sys_prompt" '[{"role":"user","parts":[{"text":"2*3=?"}]},{"role":"model","parts":[{"text":"6"}]},{"role":"user","parts":[{"text":"4*5=?"}]}]'
|
||||
send_request 4 3 "$sys_prompt" '[{"role":"user","parts":[{"text":"2*3=?"}]},{"role":"model","parts":[{"text":"6"}]},{"role":"user","parts":[{"text":"4*5=?"}]},{"role":"model","parts":[{"text":"20"}]},{"role":"user","parts":[{"text":"计算: 10*10=? 20*20=?"}]}]'
|
||||
}
|
||||
|
||||
# 会话5:诗人(完全不同的角色)
|
||||
run_session_5() {
|
||||
local sys_prompt="你是一位诗人,用简短的诗句回应每个话题,每次只写一句诗"
|
||||
|
||||
send_request 5 1 "$sys_prompt" '[{"role":"user","parts":[{"text":"春天"}]}]'
|
||||
send_request 5 2 "$sys_prompt" '[{"role":"user","parts":[{"text":"春天"}]},{"role":"model","parts":[{"text":"春风拂面花满枝"}]},{"role":"user","parts":[{"text":"夏天"}]}]'
|
||||
send_request 5 3 "$sys_prompt" '[{"role":"user","parts":[{"text":"春天"}]},{"role":"model","parts":[{"text":"春风拂面花满枝"}]},{"role":"user","parts":[{"text":"夏天"}]},{"role":"model","parts":[{"text":"蝉鸣蛙声伴荷香"}]},{"role":"user","parts":[{"text":"秋天"}]}]'
|
||||
}
|
||||
|
||||
echo ""
|
||||
echo "开始并发测试 5 个独立会话..."
|
||||
echo ""
|
||||
|
||||
# 并发运行所有会话
|
||||
run_session_1 &
|
||||
run_session_2 &
|
||||
run_session_3 &
|
||||
run_session_4 &
|
||||
run_session_5 &
|
||||
|
||||
# 等待所有后台任务完成
|
||||
wait
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "所有请求完成,结果保存在: $RESULT_DIR"
|
||||
echo "=========================================="
|
||||
|
||||
# 显示结果摘要
|
||||
echo ""
|
||||
echo "响应摘要:"
|
||||
for f in "$RESULT_DIR"/*.json; do
|
||||
filename=$(basename "$f")
|
||||
response=$(cat "$f" | head -c 200)
|
||||
echo "[$filename]: ${response}..."
|
||||
done
|
||||
|
||||
echo ""
|
||||
echo "请检查服务器日志确认账号分配情况"
|
||||
Reference in New Issue
Block a user