diff --git a/.github/workflows/backend-ci.yml b/.github/workflows/backend-ci.yml index d21d0684..342f48da 100644 --- a/.github/workflows/backend-ci.yml +++ b/.github/workflows/backend-ci.yml @@ -17,9 +17,10 @@ jobs: go-version-file: backend/go.mod check-latest: false cache: true + cache-dependency-path: backend/go.sum - name: Verify Go version run: | - go version | grep -q 'go1.25.7' + go version | grep -q 'go1.26.1' - name: Unit tests working-directory: backend run: make test-unit @@ -36,12 +37,13 @@ jobs: go-version-file: backend/go.mod check-latest: false cache: true + cache-dependency-path: backend/go.sum - name: Verify Go version run: | - go version | grep -q 'go1.25.7' + go version | grep -q 'go1.26.1' - name: golangci-lint uses: golangci/golangci-lint-action@v9 with: - version: v2.7 + version: v2.9 args: --timeout=30m working-directory: backend \ No newline at end of file diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index a1c6aa23..5c0524c8 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -115,7 +115,7 @@ jobs: - name: Verify Go version run: | - go version | grep -q 'go1.25.7' + go version | grep -q 'go1.26.1' # Docker setup for GoReleaser - name: Set up QEMU diff --git a/.github/workflows/security-scan.yml b/.github/workflows/security-scan.yml index db922509..cc5a90cf 100644 --- a/.github/workflows/security-scan.yml +++ b/.github/workflows/security-scan.yml @@ -23,7 +23,7 @@ jobs: cache-dependency-path: backend/go.sum - name: Verify Go version run: | - go version | grep -q 'go1.25.7' + go version | grep -q 'go1.26.1' - name: Run govulncheck working-directory: backend run: | diff --git a/.gitignore b/.gitignore index 297c1d6f..da112576 100644 --- a/.gitignore +++ b/.gitignore @@ -78,6 +78,7 @@ Desktop.ini # =================== tmp/ temp/ +logs/ *.tmp *.temp *.log @@ -128,8 +129,15 @@ deploy/docker-compose.override.yml vite.config.js docs/* .serena/ + +# =================== +# 压测工具 +# 
=================== +tools/loadtest/ +# Antigravity Manager +Antigravity-Manager/ +antigravity_projectid_fix.patch .codex/ frontend/coverage/ aicodex output/ - diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000..8edfa58b --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,1392 @@ +# Sub2API 开发说明 + +## 版本管理策略 + +### 版本号规则 + +我们在官方版本号后面添加自己的小版本号: + +- 官方版本:`v0.1.68` +- 我们的版本:`v0.1.68.1`、`v0.1.68.2`(递增) + +### 分支策略 + +| 分支 | 说明 | +|------|------| +| `main` | 我们的主分支,包含所有定制功能 | +| `release/custom-X.Y.Z` | 基于官方 `vX.Y.Z` 的发布分支 | +| `upstream/main` | 上游官方仓库 | + +--- + +## 发布流程(基于新官方版本) + +当官方发布新版本(如 `v0.1.69`)时: + +### 1. 同步上游并创建发布分支 + +```bash +# 获取上游最新代码 +git fetch upstream --tags + +# 基于官方标签创建新的发布分支 +git checkout v0.1.69 -b release/custom-0.1.69 + +# 合并我们的 main 分支(包含所有定制功能) +git merge main --no-edit + +# 解决可能的冲突后继续 +``` + +### 2. 更新版本号并打标签 + +```bash +# 更新版本号文件 +echo "0.1.69.1" > backend/cmd/server/VERSION +git add backend/cmd/server/VERSION +git commit -m "chore: bump version to 0.1.69.1" + +# 打上我们自己的标签 +git tag v0.1.69.1 + +# 推送分支和标签 +git push origin release/custom-0.1.69 +git push origin v0.1.69.1 +``` + +### 3. 更新 main 分支 + +```bash +# 将发布分支合并回 main,保持 main 包含最新定制功能 +git checkout main +git merge release/custom-0.1.69 +git push origin main +``` + +--- + +## 热修复发布(在现有版本上修复) + +当需要在当前版本上发布修复时: + +```bash +# 在当前发布分支上修复 +git checkout release/custom-0.1.68 +# ... 进行修复 ... 
+git commit -m "fix: 修复描述" + +# 递增小版本号 +echo "0.1.68.2" > backend/cmd/server/VERSION +git add backend/cmd/server/VERSION +git commit -m "chore: bump version to 0.1.68.2" + +# 打标签并推送 +git tag v0.1.68.2 +git push origin release/custom-0.1.68 +git push origin v0.1.68.2 + +# 同步修复到 main +git checkout main +git cherry-pick +git push origin main +``` + +--- + +## 服务器部署流程 + +### 前置条件 + +- 本地已配置 SSH 别名 `clicodeplus` 连接到生产服务器(运行服务) +- 本地已配置 SSH 别名 `us-asaki-root` 连接到构建服务器(拉取代码、构建镜像) +- 生产服务器部署目录:`/root/sub2api`(正式)、`/root/sub2api-beta`(测试)、`/root/sub2api-star`(Star) +- 生产服务器使用 Docker Compose 部署 +- **镜像统一在构建服务器上构建**,避免生产服务器因编译占用 CPU/内存影响线上服务 + +### 服务器角色说明 + +| 服务器 | SSH 别名 | 职责 | +|--------|----------|------| +| 构建服务器 | `us-asaki-root` | 拉取代码、`docker build` 构建镜像 | +| 生产服务器 | `clicodeplus` | 加载镜像、运行服务、部署验证 | +| 数据库服务器 | `db-clicodeplus` | PostgreSQL 16 + Redis 7,所有环境共用 | + +> 数据库服务器运维手册:`db-clicodeplus:/root/README.md` + +### 部署环境说明 + +| 环境 | 目录(生产服务器) | 端口 | 数据库 | Redis DB | 容器名 | +|------|------|------|--------|----------|--------| +| 正式 | `/root/sub2api` | 8080 | `sub2api` | 0 | `sub2api` | +| Beta | `/root/sub2api-beta` | 8084 | `beta` | 2 | `sub2api-beta` | +| OpenAI | `/root/sub2api-openai` | 8083 | `openai` | 3 | `sub2api-openai` | +| Star | `/root/sub2api-star` | 8086 | `star` | 4 | `sub2api-star` | + +### 外部数据库与 Redis + +所有环境(正式、Beta、OpenAI、Star)共用 `db.clicodeplus.com` 上的 **PostgreSQL 16** 和 **Redis 7**,不使用容器内数据库或 Redis。 + +**PostgreSQL**(端口 5432,TLS 加密,scram-sha-256 认证): + +| 环境 | 用户名 | 数据库 | +|------|--------|--------| +| 正式 | `sub2api` | `sub2api` | +| Beta | `beta` | `beta` | +| OpenAI | `openai` | `openai` | +| Star | `star` | `star` | + +**Redis**(端口 6379,密码认证): + +| 环境 | DB | +|------|-----| +| 正式 | 0 | +| Beta | 2 | +| OpenAI | 3 | +| Star | 4 | + +**配置方式**: +- 数据库通过 `.env` 中的 `DATABASE_HOST`、`DATABASE_SSLMODE`、`POSTGRES_USER`、`POSTGRES_PASSWORD`、`POSTGRES_DB` 配置 +- Redis 通过 `docker-compose.override.yml` 覆盖 `REDIS_HOST`(因主 compose 文件硬编码为 `redis`),密码通过 `.env` 
中的 `REDIS_PASSWORD` 配置 +- 各环境的 `docker-compose.override.yml` 已通过 `depends_on: !reset {}` 和 `redis: profiles: [disabled]` 去掉了对容器 Redis 的依赖 + +#### 数据库操作命令 + +通过 SSH 在服务器上执行数据库操作: + +```bash +# 正式环境 - 查询迁移记录 +ssh clicodeplus "source /root/sub2api/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c 'SELECT * FROM schema_migrations ORDER BY applied_at DESC LIMIT 5;'" + +# Beta 环境 - 查询迁移记录 +ssh clicodeplus "source /root/sub2api-beta/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c 'SELECT * FROM schema_migrations ORDER BY applied_at DESC LIMIT 5;'" + +# Beta 环境 - 清除指定迁移记录(重新执行迁移) +ssh clicodeplus "source /root/sub2api-beta/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c \"DELETE FROM schema_migrations WHERE filename LIKE '%049%';\"" + +# Beta 环境 - 更新账号数据 +ssh clicodeplus "source /root/sub2api-beta/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c \"UPDATE accounts SET credentials = credentials - 'model_mapping' WHERE platform = 'antigravity';\"" +``` + +> **注意**:使用 `source .env` 加载环境变量,避免在命令行中暴露密码。 + +### 部署步骤 + +**重要:每次部署都必须递增版本号!** + +#### 0. 递增版本号并推送(本地操作) + +每次部署前,先在本地递增小版本号并确保推送成功: + +```bash +# 查看当前版本号 +cat backend/cmd/server/VERSION +# 假设当前是 0.1.69.1 + +# 递增版本号 +echo "0.1.69.2" > backend/cmd/server/VERSION +git add backend/cmd/server/VERSION +git commit -m "chore: bump version to 0.1.69.2" +git push origin release/custom-0.1.69 + +# ⚠️ 确认推送成功(必须看到分支更新输出,不能有 rejected 错误) +``` + +> **检查点**:如果有其他未提交的改动,应先 commit 并 push,确保 release 分支上的所有代码都已推送到远程。 + +#### 1. 
构建服务器拉取代码 + +```bash +# 拉取最新代码并切换分支 +ssh us-asaki-root "cd /root/sub2api && git fetch origin && git checkout -B release/custom-0.1.69 origin/release/custom-0.1.69" + +# ⚠️ 验证版本号与步骤 0 一致 +ssh us-asaki-root "cat /root/sub2api/backend/cmd/server/VERSION" +``` + +> **首次使用构建服务器?** 需要先初始化仓库,参见下方「构建服务器首次初始化」章节。 + +#### 2. 构建服务器构建镜像 + +```bash +ssh us-asaki-root "cd /root/sub2api && docker build --no-cache -t sub2api:latest -f Dockerfile ." + +# ⚠️ 必须看到构建成功输出,如果失败需要先排查问题 +``` + +> **常见构建问题**: +> - `buildx` 版本过旧导致 API 版本不兼容 → 更新 buildx:`curl -fsSL "https://github.com/docker/buildx/releases/latest/download/buildx-$(curl -fsSL https://api.github.com/repos/docker/buildx/releases/latest | grep tag_name | cut -d'"' -f4).linux-amd64" -o ~/.docker/cli-plugins/docker-buildx && chmod +x ~/.docker/cli-plugins/docker-buildx` +> - 磁盘空间不足 → `docker system prune -f` 清理无用镜像 + +#### 3. 传输镜像到生产服务器并加载 + +```bash +# 导出镜像 → 通过管道传输 → 生产服务器加载 +ssh us-asaki-root "docker save sub2api:latest" | ssh clicodeplus "docker load" + +# ⚠️ 必须看到 "Loaded image: sub2api:latest" 输出 +``` + +#### 4. 生产服务器同步代码、更新标签并重启 + +```bash +# 同步代码(用于版本号确认和 deploy 配置) +ssh clicodeplus "cd /root/sub2api && git fetch fork && git checkout -B release/custom-0.1.69 fork/release/custom-0.1.69" + +# 更新镜像标签并重启 +ssh clicodeplus "docker tag sub2api:latest weishaw/sub2api:latest" +ssh clicodeplus "cd /root/sub2api/deploy && docker compose up -d --force-recreate sub2api" +``` + +#### 5. 
验证部署 + +```bash +# 查看启动日志 +ssh clicodeplus "docker logs sub2api --tail 20" + +# 确认版本号(必须与步骤 0 中设置的版本号一致) +ssh clicodeplus "cat /root/sub2api/backend/cmd/server/VERSION" + +# 检查容器状态(必须显示 healthy) +ssh clicodeplus "docker ps | grep sub2api" +``` + +--- + +### 构建服务器首次初始化 + +首次使用 `us-asaki-root` 作为构建服务器时,需要执行以下一次性操作: + +```bash +ssh us-asaki-root + +# 1) 克隆仓库 +cd /root +git clone https://github.com/touwaeriol/sub2api.git sub2api +cd sub2api + +# 2) 验证 Docker 和 buildx 版本 +docker version +docker buildx version +# 如果 buildx 版本过旧(< v0.14),执行更新: +# LATEST=$(curl -fsSL https://api.github.com/repos/docker/buildx/releases/latest | grep tag_name | cut -d'"' -f4) +# curl -fsSL "https://github.com/docker/buildx/releases/download/${LATEST}/buildx-${LATEST}.linux-amd64" -o ~/.docker/cli-plugins/docker-buildx +# chmod +x ~/.docker/cli-plugins/docker-buildx + +# 3) 验证构建能力 +docker build --no-cache -t sub2api:test -f Dockerfile . +docker rmi sub2api:test +``` + +--- + +## Beta 并行部署(不影响现网) + +目标:在同一台服务器上并行启动一个 beta 实例(例如端口 `8084`),**严禁改动/重启**现网实例(默认目录 `/root/sub2api`)。 + +### 设计原则 + +- **新目录**:beta 使用独立目录,例如 `/root/sub2api-beta`。 +- **敏感信息只放 `.env`**:beta 的数据库密码、JWT_SECRET 等只写入 `/root/sub2api-beta/deploy/.env`,不要提交到 git。 +- **独立 Compose Project**:通过 `docker compose -p sub2api-beta ...` 启动,确保 network/volume 隔离。 +- **独立端口**:通过 `.env` 的 `SERVER_PORT` 映射宿主机端口(例如 `8084:8080`)。 + +### 前置检查 + +```bash +# 1) 确保 8084 未被占用 +ssh clicodeplus "ss -ltnp | grep :8084 || echo '8084 is free'" + +# 2) 确认现网容器还在(只读检查) +ssh clicodeplus "docker ps --format 'table {{.Names}}\t{{.Image}}\t{{.Ports}}' | sed -n '1,200p'" +``` + +### 首次部署步骤 + +> **构建服务器说明**:正式和 beta 共用构建服务器上的 `/root/sub2api` 仓库,通过不同的镜像标签区分(`sub2api:latest` 用于正式,`sub2api:beta` 用于测试)。 + +```bash +# 1) 构建服务器构建 beta 镜像(共用 /root/sub2api 仓库,切到目标分支后打 beta 标签) +ssh us-asaki-root "cd /root/sub2api && git fetch origin && git checkout -B release/custom-0.1.71 origin/release/custom-0.1.71" +ssh us-asaki-root "cd /root/sub2api && docker build --no-cache -t 
sub2api:beta -f Dockerfile ." + +# ⚠️ 构建完成后如需恢复正式分支: +# ssh us-asaki-root "cd /root/sub2api && git checkout release/custom-<正式版本>" + +# 2) 传输镜像到生产服务器 +ssh us-asaki-root "docker save sub2api:beta" | ssh clicodeplus "docker load" +# ⚠️ 必须看到 "Loaded image: sub2api:beta" 输出 + +# 3) 在生产服务器上准备 beta 环境 +ssh clicodeplus + +# 克隆代码(仅用于 deploy 配置和版本号确认,不在此构建) +cd /root +git clone https://github.com/touwaeriol/sub2api.git sub2api-beta +cd /root/sub2api-beta +git checkout release/custom-0.1.71 + +# 4) 准备 beta 的 .env(敏感信息只写这里) +cd /root/sub2api-beta/deploy + +# 推荐:从现网 .env 复制,保证除 DB 名/用户/端口外完全一致 +cp -f /root/sub2api/deploy/.env ./.env + +# 仅修改以下三项(其他保持不变) +perl -pi -e 's/^SERVER_PORT=.*/SERVER_PORT=8084/' ./.env +perl -pi -e 's/^POSTGRES_USER=.*/POSTGRES_USER=beta/' ./.env +perl -pi -e 's/^POSTGRES_DB=.*/POSTGRES_DB=beta/' ./.env + +# 5) 写 compose override(避免与现网容器名冲突,镜像使用构建服务器传输的 sub2api:beta,Redis 使用外部服务) +cat > docker-compose.override.yml <<'YAML' +services: + sub2api: + image: sub2api:beta + container_name: sub2api-beta + environment: + - DATABASE_HOST=${DATABASE_HOST:-postgres} + - DATABASE_SSLMODE=${DATABASE_SSLMODE:-disable} + - REDIS_HOST=db.clicodeplus.com + depends_on: !reset {} + redis: + profiles: + - disabled +YAML + +# 6) 启动 beta(独立 project,确保不影响现网) +cd /root/sub2api-beta/deploy +docker compose -p sub2api-beta --env-file .env -f docker-compose.yml -f docker-compose.override.yml up -d + +# 7) 验证 beta +curl -fsS http://127.0.0.1:8084/health +docker logs sub2api-beta --tail 50 +``` + +### 数据库配置约定(beta) + +- 数据库地址/SSL/密码:与现网一致(从现网 `.env` 复制即可),均指向 `db.clicodeplus.com`。 +- 仅修改: + - `POSTGRES_USER=beta` + - `POSTGRES_DB=beta` + - `REDIS_DB=2` + +注意:需要数据库侧已存在 `beta` 用户与 `beta` 数据库,并授予权限;否则容器会启动失败并不断重启。 + +### 更新 beta(构建服务器构建 + 传输 + 仅重启 beta 容器) + +```bash +# 1) 构建服务器拉取代码并构建镜像(共用 /root/sub2api 仓库) +ssh us-asaki-root "cd /root/sub2api && git fetch origin && git checkout -B release/custom-0.1.71 origin/release/custom-0.1.71" +ssh us-asaki-root "cd /root/sub2api && docker 
build --no-cache -t sub2api:beta -f Dockerfile ." +# ⚠️ 必须看到构建成功输出 + +# 2) 传输镜像到生产服务器 +ssh us-asaki-root "docker save sub2api:beta" | ssh clicodeplus "docker load" +# ⚠️ 必须看到 "Loaded image: sub2api:beta" 输出 + +# 3) 生产服务器同步代码(用于版本号确认和 deploy 配置) +ssh clicodeplus "set -e; cd /root/sub2api-beta && git fetch --all --tags && git checkout -f release/custom-0.1.71 && git reset --hard origin/release/custom-0.1.71" + +# 4) 重启 beta 容器并验证 +ssh clicodeplus "cd /root/sub2api-beta/deploy && docker compose -p sub2api-beta --env-file .env -f docker-compose.yml -f docker-compose.override.yml up -d --no-deps --force-recreate sub2api" +ssh clicodeplus "sleep 5 && curl -fsS http://127.0.0.1:8084/health" +ssh clicodeplus "cat /root/sub2api-beta/backend/cmd/server/VERSION" +``` + +### 停止/回滚 beta(只影响 beta) + +```bash +ssh clicodeplus "cd /root/sub2api-beta/deploy && docker compose -p sub2api-beta -f docker-compose.yml -f docker-compose.override.yml down" +``` + +--- + +## 服务器首次部署 + +### 1. 构建服务器:克隆代码并配置远程仓库 + +```bash +ssh us-asaki-root +cd /root +git clone https://github.com/Wei-Shaw/sub2api.git +cd sub2api + +# 添加 fork 仓库 +git remote add fork https://github.com/touwaeriol/sub2api.git +``` + +### 2. 构建服务器:切换到定制分支并构建镜像 + +```bash +git fetch fork +git checkout -B release/custom-0.1.69 fork/release/custom-0.1.69 + +cd /root/sub2api +docker build -t sub2api:latest -f Dockerfile . +exit +``` + +### 3. 传输镜像到生产服务器 + +```bash +ssh us-asaki-root "docker save sub2api:latest" | ssh clicodeplus "docker load" +``` + +### 4. 
生产服务器:克隆代码并配置环境 + +```bash +ssh clicodeplus +cd /root +git clone https://github.com/Wei-Shaw/sub2api.git +cd sub2api + +# 添加 fork 仓库 +git remote add fork https://github.com/touwaeriol/sub2api.git +git fetch fork +git checkout -B release/custom-0.1.69 fork/release/custom-0.1.69 + +# 配置环境变量 +cd deploy +cp .env.example .env +vim .env # 配置 DATABASE_HOST=db.clicodeplus.com, POSTGRES_PASSWORD, REDIS_PASSWORD, JWT_SECRET 等 + +# 创建 override 文件(Redis 指向外部服务,去掉容器 Redis 依赖) +cat > docker-compose.override.yml <<'YAML' +services: + sub2api: + environment: + - REDIS_HOST=db.clicodeplus.com + depends_on: !reset {} + redis: + profiles: + - disabled +YAML +``` + +### 5. 生产服务器:更新镜像标签并启动服务 + +```bash +docker tag sub2api:latest weishaw/sub2api:latest +cd /root/sub2api/deploy && docker compose up -d +``` + +### 6. 验证部署 + +```bash +# 查看应用日志 +docker logs sub2api --tail 50 + +# 检查健康状态 +curl http://localhost:8080/health + +# 确认版本号 +cat /root/sub2api/backend/cmd/server/VERSION +``` + +### 7. 常用运维命令 + +```bash +# 查看实时日志 +docker logs -f sub2api + +# 重启服务 +docker compose restart sub2api + +# 停止所有服务 +docker compose down + +# 停止并删除数据卷(慎用!会删除数据库数据) +docker compose down -v + +# 查看资源使用情况 +docker stats sub2api +``` + +--- + +## 定制功能说明 + +当前定制分支包含以下功能(相对于官方版本): + +### UI/UX 定制 + +| 功能 | 说明 | +|------|------| +| 首页优化 | 面向用户的价值主张设计 | +| 移除 GitHub 链接 | 用户菜单中不显示 GitHub 导航 | +| 微信客服按钮 | 首页悬浮微信客服入口 | +| 限流时间精确显示 | 账号限流时间显示精确到秒 | + +### Antigravity 平台增强 + +| 功能 | 说明 | +|------|------| +| Scope 级别限流 | 按配额域(claude/gemini_text/gemini_image)独立限流,避免整个账号被锁定 | +| 模型级别限流 | 按具体模型(如 claude-opus-4-5)独立限流,更精细的限流控制 | +| 限流预检查 | 调度时预检查账号/模型限流状态,避免选中已限流账号 | +| 秒级冷却时间 | 支持 429 响应的秒级精确冷却时间 | +| 身份注入优化 | 模型身份信息注入 + 静默边界防止身份泄露 | +| thoughtSignature 修复 | Gemini 3 函数调用 400 错误修复 | +| max_tokens 自动修正 | 自动修正 max_tokens <= budget_tokens 导致的 400 错误 | + +### 调度算法优化 + +| 功能 | 说明 | +|------|------| +| 分层过滤选择 | 调度算法从全排序改为分层过滤,提升性能 | +| LRU 随机选择 | 相同 LRU 时间时随机选择,避免账号集中 | +| 限流等待阈值配置化 | 可配置的限流等待阈值 | + +### 运维增强 + +| 功能 | 说明 | 
+|------|------| +| Scope 限流统计 | 运维界面展示 Antigravity 账号 scope 级别限流统计 | +| 账号限流状态显示 | 账号列表显示 scope 和模型级别限流状态 | +| 清除限流按钮增强 | 有 scope/模型限流时也显示清除限流按钮 | + +### 其他修复 + +| 功能 | 说明 | +|------|------| +| .gitattributes | 确保迁移文件使用 LF 换行符(解决 Windows 下 SQL 摘要不一致) | +| 部署配置优化 | DATABASE_HOST 和 DATABASE_SSLMODE 可通过 .env 配置 | + +--- + +## Admin API 接口文档 + +### ⚠️ API 操作流程规范 + +当收到操作正式环境 Web 界面的新需求,但文档中未记录对应 API 接口时,**必须按以下流程执行**: + +1. **探索接口**:通过代码库搜索路由定义(`backend/internal/server/routes/`)、Handler(`backend/internal/handler/admin/`)和请求结构体,确定正确的 API 端点、请求方法、请求体格式 +2. **更新文档**:将新发现的接口补充到本文档的 Admin API 接口文档章节中,包含端点、参数说明和 curl 示例 +3. **执行操作**:根据最新文档中记录的接口完成用户需求 + +> **目的**:避免每次遇到相同需求都重复探索代码库,确保 API 文档持续完善,后续操作可直接查阅文档执行。 + +--- + +### 认证方式 + +所有 Admin API 通过 `x-api-key` 请求头传递 Admin API Key 认证。 + +``` +x-api-key: admin-xxx +``` + +> **使用说明**:Admin API Key 统一存放在项目根目录 `.env` 文件的 `ADMIN_API_KEY` 变量中(该文件已被 `.gitignore` 排除,不会提交到代码库)。操作前先从 `.env` 读取密钥;若密钥失效(返回 401),应提示用户提供新的密钥并更新到 `.env` 中。Token 格式为 `admin-` + 64 位十六进制字符,在管理后台 `设置 > Admin API Key` 中生成。**请勿将实际 token 写入文档或代码中。** + +### 环境地址 + +| 环境 | 基础地址 | 说明 | +|------|----------|------| +| 正式 | `https://clicodeplus.com` | 生产环境 | +| Beta | `http://<服务器IP>:8084` | 仅内网访问 | +| OpenAI | `http://<服务器IP>:8083` | 仅内网访问 | +| Star | `https://hyntoken.com` | 独立环境 | + +> 以下接口文档中,`${BASE}` 代表环境基础地址,`${KEY}` 代表 `.env` 中的 `ADMIN_API_KEY`。操作前执行 `source .env` 或 `export KEY=$ADMIN_API_KEY` 加载。 + +--- + +### 1. 
账号管理 + +#### 1.1 获取账号列表 + +``` +GET /api/v1/admin/accounts +``` + +**查询参数**: + +| 参数 | 类型 | 必填 | 说明 | +|------|------|------|------| +| `platform` | string | 否 | 平台筛选:`antigravity` / `anthropic` / `openai` / `gemini` | +| `type` | string | 否 | 账号类型:`oauth` / `api_key` / `cookie` | +| `status` | string | 否 | 状态:`active` / `disabled` / `error` | +| `search` | string | 否 | 搜索关键词(名称、备注) | +| `page` | int | 否 | 页码,默认 1 | +| `page_size` | int | 否 | 每页数量,默认 20 | + +```bash +curl -s "${BASE}/api/v1/admin/accounts?platform=antigravity&page=1&page_size=100" \ + -H "x-api-key: ${KEY}" +``` + +**响应**: +```json +{ + "code": 0, + "message": "success", + "data": { + "items": [{"id": 1, "name": "xxx@gmail.com", "platform": "antigravity", "status": "active", ...}], + "total": 66 + } +} +``` + +#### 1.2 获取账号详情 + +``` +GET /api/v1/admin/accounts/:id +``` + +```bash +curl -s "${BASE}/api/v1/admin/accounts/1" -H "x-api-key: ${KEY}" +``` + +#### 1.3 测试账号连接 + +``` +POST /api/v1/admin/accounts/:id/test +``` + +**请求体**(JSON,可选): + +| 字段 | 类型 | 必填 | 说明 | +|------|------|------|------| +| `model_id` | string | 否 | 指定测试模型,如 `claude-opus-4-6`;不传则使用默认模型 | + +**响应格式**:SSE(Server-Sent Events)流 + +```bash +curl -N -X POST "${BASE}/api/v1/admin/accounts/1/test" \ + -H "x-api-key: ${KEY}" \ + -H "Content-Type: application/json" \ + -d '{"model_id": "claude-opus-4-6"}' +``` + +**SSE 事件类型**: + +| type | 字段 | 说明 | +|------|------|------| +| `test_start` | `model` | 测试开始,返回测试模型名 | +| `content` | `text` | 模型响应内容(流式文本片段) | +| `test_end` | `success`, `error` | 测试结束,`success=true` 表示成功 | +| `error` | `text` | 错误信息 | + +#### 1.4 清除账号限流 + +``` +POST /api/v1/admin/accounts/:id/clear-rate-limit +``` + +```bash +curl -X POST "${BASE}/api/v1/admin/accounts/1/clear-rate-limit" \ + -H "x-api-key: ${KEY}" +``` + +#### 1.5 清除账号错误状态 + +``` +POST /api/v1/admin/accounts/:id/clear-error +``` + +```bash +curl -X POST "${BASE}/api/v1/admin/accounts/1/clear-error" \ + -H "x-api-key: ${KEY}" +``` + +#### 1.6 获取账号可用模型 + +``` 
+GET /api/v1/admin/accounts/:id/models +``` + +```bash +curl -s "${BASE}/api/v1/admin/accounts/1/models" -H "x-api-key: ${KEY}" +``` + +#### 1.7 刷新 OAuth Token + +``` +POST /api/v1/admin/accounts/:id/refresh +``` + +```bash +curl -X POST "${BASE}/api/v1/admin/accounts/1/refresh" -H "x-api-key: ${KEY}" +``` + +#### 1.8 刷新账号等级 + +``` +POST /api/v1/admin/accounts/:id/refresh-tier +``` + +```bash +curl -X POST "${BASE}/api/v1/admin/accounts/1/refresh-tier" -H "x-api-key: ${KEY}" +``` + +#### 1.9 获取账号统计 + +``` +GET /api/v1/admin/accounts/:id/stats +``` + +```bash +curl -s "${BASE}/api/v1/admin/accounts/1/stats" -H "x-api-key: ${KEY}" +``` + +#### 1.10 获取账号用量 + +``` +GET /api/v1/admin/accounts/:id/usage +``` + +```bash +curl -s "${BASE}/api/v1/admin/accounts/1/usage" -H "x-api-key: ${KEY}" +``` + +#### 1.11 更新单个账号 + +``` +PUT /api/v1/admin/accounts/:id +``` + +**请求体**(JSON,所有字段均为可选,仅传需要更新的字段): + +| 字段 | 类型 | 说明 | +|------|------|------| +| `name` | string | 账号名称 | +| `notes` | *string | 备注 | +| `type` | string | 类型:`oauth` / `setup-token` / `apikey` / `upstream` | +| `credentials` | object | 凭证信息 | +| `extra` | object | 额外配置 | +| `proxy_id` | *int64 | 代理 ID | +| `concurrency` | *int | 并发数 | +| `priority` | *int | 优先级(默认 50) | +| `rate_multiplier` | *float64 | 速率倍数 | +| `status` | string | 状态:`active` / `inactive` | +| `group_ids` | *[]int64 | 分组 ID 列表 | +| `expires_at` | *int64 | 过期时间戳 | +| `auto_pause_on_expired` | *bool | 过期后自动暂停 | + +> 使用指针类型(`*`)的字段可以区分"未提供"和"设置为零值"。 + +```bash +# 示例:更新账号优先级为 100 +curl -X PUT "${BASE}/api/v1/admin/accounts/1" \ + -H "x-api-key: ${KEY}" \ + -H "Content-Type: application/json" \ + -d '{"priority": 100}' +``` + +#### 1.12 批量更新账号 + +``` +POST /api/v1/admin/accounts/bulk-update +``` + +**请求体**(JSON): + +| 字段 | 类型 | 必填 | 说明 | +|------|------|------|------| +| `account_ids` | []int64 | **是** | 要更新的账号 ID 列表 | +| `priority` | *int | 否 | 优先级 | +| `concurrency` | *int | 否 | 并发数 | +| `rate_multiplier` | *float64 | 否 | 速率倍数 | +| `status` | string 
| 否 | 状态:`active` / `inactive` / `error` | +| `schedulable` | *bool | 否 | 是否可调度 | +| `group_ids` | *[]int64 | 否 | 分组 ID 列表 | +| `proxy_id` | *int64 | 否 | 代理 ID | +| `credentials` | object | 否 | 凭证信息(批量覆盖) | +| `extra` | object | 否 | 额外配置(批量覆盖) | + +```bash +# 示例:批量设置多个账号优先级为 100 +curl -X POST "${BASE}/api/v1/admin/accounts/bulk-update" \ + -H "x-api-key: ${KEY}" \ + -H "Content-Type: application/json" \ + -d '{"account_ids": [1, 2, 3], "priority": 100}' +``` + +#### 1.13 批量测试账号(脚本) + +批量测试指定平台所有账号的指定模型连通性: + +```bash +# 用户需提供:BASE(环境地址)、KEY(admin token)、MODEL(测试模型) +ACCOUNT_IDS=$(curl -s "${BASE}/api/v1/admin/accounts?platform=antigravity&page=1&page_size=100" \ + -H "x-api-key: ${KEY}" | python3 -c " +import json, sys +data = json.load(sys.stdin) +for item in data['data']['items']: + print(f\"{item['id']}|{item['name']}\") +") + +while IFS='|' read -r ID NAME; do + echo "测试账号 ID=${ID} (${NAME})..." + RESPONSE=$(curl -s --max-time 60 -N \ + -X POST "${BASE}/api/v1/admin/accounts/${ID}/test" \ + -H "x-api-key: ${KEY}" \ + -H "Content-Type: application/json" \ + -d "{\"model_id\": \"${MODEL}\"}" 2>&1) + if echo "$RESPONSE" | grep -q '"success":true'; then + echo " ✅ 成功" + elif echo "$RESPONSE" | grep -q '"type":"content"'; then + echo " ✅ 成功(有内容响应)" + else + ERROR_MSG=$(echo "$RESPONSE" | grep -o '"error":"[^"]*"' | tail -1) + echo " ❌ 失败: ${ERROR_MSG}" + fi +done <<< "$ACCOUNT_IDS" +``` + +--- + +### 2. 
运维监控 + +#### 2.1 并发统计 + +``` +GET /api/v1/admin/ops/concurrency +``` + +```bash +curl -s "${BASE}/api/v1/admin/ops/concurrency" -H "x-api-key: ${KEY}" +``` + +#### 2.2 账号可用性 + +``` +GET /api/v1/admin/ops/account-availability +``` + +```bash +curl -s "${BASE}/api/v1/admin/ops/account-availability" -H "x-api-key: ${KEY}" +``` + +#### 2.3 实时流量摘要 + +``` +GET /api/v1/admin/ops/realtime-traffic +``` + +```bash +curl -s "${BASE}/api/v1/admin/ops/realtime-traffic" -H "x-api-key: ${KEY}" +``` + +#### 2.4 请求错误列表 + +``` +GET /api/v1/admin/ops/request-errors +``` + +**查询参数**:`page`、`page_size` + +```bash +curl -s "${BASE}/api/v1/admin/ops/request-errors?page=1&page_size=50" \ + -H "x-api-key: ${KEY}" +``` + +#### 2.5 上游错误列表 + +``` +GET /api/v1/admin/ops/upstream-errors +``` + +```bash +curl -s "${BASE}/api/v1/admin/ops/upstream-errors?page=1&page_size=50" \ + -H "x-api-key: ${KEY}" +``` + +#### 2.6 仪表板概览 + +``` +GET /api/v1/admin/ops/dashboard/overview +``` + +```bash +curl -s "${BASE}/api/v1/admin/ops/dashboard/overview" -H "x-api-key: ${KEY}" +``` + +--- + +### 3. 系统设置 + +#### 3.1 获取系统设置 + +``` +GET /api/v1/admin/settings +``` + +```bash +curl -s "${BASE}/api/v1/admin/settings" -H "x-api-key: ${KEY}" +``` + +#### 3.2 更新系统设置 + +``` +PUT /api/v1/admin/settings +``` + +```bash +curl -X PUT "${BASE}/api/v1/admin/settings" \ + -H "x-api-key: ${KEY}" \ + -H "Content-Type: application/json" \ + -d '{ ... }' +``` + +#### 3.3 Admin API Key 状态(脱敏) + +``` +GET /api/v1/admin/settings/admin-api-key +``` + +```bash +curl -s "${BASE}/api/v1/admin/settings/admin-api-key" -H "x-api-key: ${KEY}" +``` + +--- + +### 4. 
用户管理 + +#### 4.1 用户列表 + +``` +GET /api/v1/admin/users +``` + +```bash +curl -s "${BASE}/api/v1/admin/users?page=1&page_size=20" -H "x-api-key: ${KEY}" +``` + +#### 4.2 用户详情 + +``` +GET /api/v1/admin/users/:id +``` + +```bash +curl -s "${BASE}/api/v1/admin/users/1" -H "x-api-key: ${KEY}" +``` + +#### 4.3 更新用户余额 + +``` +POST /api/v1/admin/users/:id/balance +``` + +```bash +curl -X POST "${BASE}/api/v1/admin/users/1/balance" \ + -H "x-api-key: ${KEY}" \ + -H "Content-Type: application/json" \ + -d '{"amount": 100, "reason": "充值"}' +``` + +--- + +### 5. 分组管理 + +#### 5.1 分组列表 + +``` +GET /api/v1/admin/groups +``` + +```bash +curl -s "${BASE}/api/v1/admin/groups" -H "x-api-key: ${KEY}" +``` + +#### 5.2 所有分组(不分页) + +``` +GET /api/v1/admin/groups/all +``` + +```bash +curl -s "${BASE}/api/v1/admin/groups/all" -H "x-api-key: ${KEY}" +``` + +--- + +## 注意事项 + +1. **前端必须打包进镜像**:使用 `docker build` 在构建服务器(`us-asaki-root`)上构建,Dockerfile 会自动编译前端并 embed 到后端二进制中,构建完成后通过 `docker save | docker load` 传输到生产服务器(`clicodeplus`) + +2. **镜像标签**:docker-compose.yml 使用 `weishaw/sub2api:latest`,本地构建后需要 `docker tag` 覆盖 + +3. **Windows 换行符问题**:已通过 `.gitattributes` 解决,确保 `*.sql` 文件始终使用 LF + +4. **版本号管理**:每次发布必须更新 `backend/cmd/server/VERSION` 并打标签 + +5. **合并冲突**:合并上游新版本时,重点关注以下文件可能的冲突: + - `backend/internal/service/antigravity_gateway_service.go` + - `backend/internal/service/gateway_service.go` + - `backend/internal/pkg/antigravity/request_transformer.go` + +--- + +## Go 代码规范 + +### 1. 函数设计 + +#### 单一职责原则 +- **函数行数**:单个函数常规不应超过 **30 行**,超过时应拆分为子函数。若某段逻辑确实不可拆分(如复杂的状态机、协议解析等),可以例外,但需添加注释说明原因 +- **嵌套层级**:避免超过 3 层嵌套,使用 early return 减少嵌套 + +```go +// ❌ 不推荐:深层嵌套 +func process(data []Item) { + for _, item := range data { + if item.Valid { + if item.Type == "A" { + if item.Status == "active" { + // 业务逻辑... 
+ } + } + } + } +} + +// ✅ 推荐:early return +func process(data []Item) { + for _, item := range data { + if !item.Valid { + continue + } + if item.Type != "A" { + continue + } + if item.Status != "active" { + continue + } + // 业务逻辑... + } +} +``` + +#### 复杂逻辑提取 +将复杂的条件判断或处理逻辑提取为独立函数: + +```go +// ❌ 不推荐:内联复杂逻辑 +if resp.StatusCode == 429 || resp.StatusCode == 503 { + // 80+ 行处理逻辑... +} + +// ✅ 推荐:提取为独立函数 +result := handleRateLimitResponse(resp, params) +switch result.action { +case actionRetry: + continue +case actionBreak: + return result.resp, nil +} +``` + +### 2. 重复代码消除 + +#### 配置获取模式 +将重复的配置获取逻辑提取为方法: + +```go +// ❌ 不推荐:重复代码 +logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody +maxBytes := 2048 +if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 { + maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes +} + +// ✅ 推荐:提取为方法 +func (s *Service) getLogConfig() (logBody bool, maxBytes int) { + maxBytes = 2048 + if s.settingService == nil || s.settingService.cfg == nil { + return false, maxBytes + } + cfg := s.settingService.cfg.Gateway + if cfg.LogUpstreamErrorBodyMaxBytes > 0 { + maxBytes = cfg.LogUpstreamErrorBodyMaxBytes + } + return cfg.LogUpstreamErrorBody, maxBytes +} +``` + +### 3. 常量管理 + +#### 避免魔法数字 +所有硬编码的数值都应定义为常量: + +```go +// ❌ 不推荐 +if retryDelay >= 10*time.Second { + resetAt := time.Now().Add(30 * time.Second) +} + +// ✅ 推荐 +const ( + rateLimitThreshold = 10 * time.Second + defaultRateLimitDuration = 30 * time.Second +) + +if retryDelay >= rateLimitThreshold { + resetAt := time.Now().Add(defaultRateLimitDuration) +} +``` + +#### 注释引用常量名 +在注释中引用常量名而非硬编码值: + +```go +// ❌ 不推荐 +// < 10s: 等待后重试 + +// ✅ 推荐 +// < rateLimitThreshold: 等待后重试 +``` + +### 4. 
错误处理 + +#### 使用结构化日志 +优先使用 `slog` 进行结构化日志记录: + +```go +// ❌ 不推荐 +log.Printf("%s status=%d model_rate_limit_failed model=%s error=%v", prefix, statusCode, modelName, err) + +// ✅ 推荐 +slog.Error("failed to set model rate limit", + "prefix", prefix, + "status_code", statusCode, + "model", modelName, + "error", err, +) +``` + +### 5. 测试规范 + +#### Mock 函数签名同步 +修改函数签名时,必须同步更新所有测试中的 mock 函数: + +```go +// 如果修改了 handleError 签名 +handleError func(..., groupID int64, sessionHash string) *Result + +// 必须同步更新测试中的 mock +handleError: func(..., groupID int64, sessionHash string) *Result { + return nil +}, +``` + +#### 测试构建标签 +统一使用测试构建标签: + +```go +//go:build unit + +package service +``` + +### 6. 时间格式解析 + +#### 使用标准库 +优先使用 `time.ParseDuration`,支持所有 Go duration 格式: + +```go +// ❌ 不推荐:手动限制格式 +if !strings.HasSuffix(delay, "s") || strings.Contains(delay, "m") { + continue +} + +// ✅ 推荐:使用标准库 +dur, err := time.ParseDuration(delay) // 支持 "0.5s", "4m50s", "1h30m" 等 +``` + +### 7. 接口设计 + +#### 接口隔离原则 +定义最小化接口,只包含必需的方法: + +```go +// ❌ 不推荐:使用过于宽泛的接口 +type AccountRepository interface { + // 20+ 个方法... +} + +// ✅ 推荐:定义最小化接口 +type ModelRateLimiter interface { + SetModelRateLimit(ctx context.Context, id int64, modelKey string, resetAt time.Time) error +} +``` + +### 8. 并发安全 + +#### 共享数据保护 +访问可能被并发修改的数据时,确保线程安全: + +```go +// 如果 Account.Extra 可能被并发修改 +// 需要使用互斥锁或原子操作保护读取 +func (a *Account) GetRateLimitRemainingTime(model string) time.Duration { + a.mu.RLock() + defer a.mu.RUnlock() + // 读取 Extra 字段... +} +``` + +### 9. 命名规范 + +#### 一致的命名风格 +- 常量使用 camelCase:`rateLimitThreshold` +- 类型使用 PascalCase:`AntigravityQuotaScope` +- 同一概念使用统一命名:`Threshold` 或 `Limit`,不要混用 + +```go +// ❌ 不推荐:命名不一致 +antigravitySmartRetryMinWait // 使用 Min +antigravityRateLimitThreshold // 使用 Threshold + +// ✅ 推荐:统一风格 +antigravityMinRetryWait +antigravityRateLimitThreshold +``` + +### 10. 代码审查清单 + +在提交代码前,检查以下项目: + +- [ ] 函数是否超过 30 行?(不可拆分的逻辑除外,需注释说明) +- [ ] 嵌套是否超过 3 层? +- [ ] 是否有重复代码可以提取? +- [ ] 是否使用了魔法数字? 
+- [ ] Mock 函数签名是否与实际函数一致? +- [ ] 测试是否覆盖了新增逻辑? +- [ ] 日志是否包含足够的上下文信息? +- [ ] 是否考虑了并发安全? + +--- + +## CI 检查与发布门禁 + +### GitHub Actions 检查项 + +本项目有 4 个 CI 任务,**任何代码推送或发布前都必须全部通过**: + +| Workflow | Job | 说明 | 本地验证命令 | +|----------|-----|------|-------------| +| CI | `test` | 单元测试 + 集成测试 | `cd backend && make test-unit && make test-integration` | +| CI | `golangci-lint` | Go 代码静态检查(golangci-lint v2.7) | `cd backend && golangci-lint run --timeout=5m` | +| Security Scan | `backend-security` | govulncheck + gosec 安全扫描 | `cd backend && govulncheck ./... && gosec -severity high -confidence high ./...` | +| Security Scan | `frontend-security` | pnpm audit 前端依赖安全检查 | `cd frontend && pnpm audit --prod --audit-level=high` | + +### 向上游提交 PR + +PR 目标是上游官方仓库,**只包含通用功能改动**(bug fix、新功能、性能优化等)。 + +**以下文件禁止出现在 PR 中**(属于我们 fork 的定制化内容): +- `CLAUDE.md`、`AGENTS.md` — 我们的开发文档 +- `backend/cmd/server/VERSION` — 我们的版本号文件 +- UI 定制改动(GitHub 链接移除、微信客服按钮、首页定制等) +- 部署配置(`deploy/` 目录下的定制修改) + +**PR 流程**: +1. 从 `develop` 创建功能分支,只包含要提交给上游的改动 +2. 推送分支后,**等待 4 个 CI job 全部通过** +3. 确认通过后再创建 PR +4. 使用 `gh run list --repo touwaeriol/sub2api --branch ` 检查状态 + +### 自有分支推送(develop / main) + +推送到我们自己的 `develop` 或 `main` 分支时,包含所有改动(定制化 + 通用功能)。 + +**推送前必须在本地执行全部 CI 检查**(不要等 GitHub Actions): + +```bash +# 确保 Go 工具链可用(macOS homebrew) +export PATH="/opt/homebrew/bin:$HOME/go/bin:$PATH" + +# 1. 单元测试(必须) +cd backend && make test-unit + +# 2. 集成测试(推荐,需要 Docker) +make test-integration + +# 3. golangci-lint 静态检查(必须) +golangci-lint run --timeout=5m + +# 4. gofmt 格式检查(必须) +gofmt -l ./... +# 如果有输出,运行 gofmt -w 修复 +``` + +**推送后确认**: +1. 使用 `gh run list --repo touwaeriol/sub2api --branch ` 检查 GitHub Actions 状态 +2. 确认 CI 和 Security Scan 两个 workflow 的 4 个 job 全部绿色 ✅ +3. 任何 job 失败必须立即修复,**禁止在 CI 未通过的状态下继续后续操作** + +### 发布版本 + +1. 本地执行上述全部 CI 检查通过 +2. 递增 `backend/cmd/server/VERSION`,提交并推送 +3. 推送后确认 GitHub Actions 的 4 个 CI job 全部通过 +4. **CI 未通过时禁止部署** — 必须先修复问题 +5. 
使用 `gh run list --repo touwaeriol/sub2api --limit 10` 确认状态 + +### 常见 CI 失败原因及修复 +- **gofmt**:struct 字段对齐不一致 → 运行 `gofmt -w ` 修复 +- **golangci-lint**:未使用的变量/导入 → 删除或使用 `_` 忽略 +- **test 失败**:mock 函数签名不一致 → 同步更新 mock +- **gosec**:安全漏洞 → 根据提示修复或添加例外 + +--- + +## PR 描述格式规范 + +所有 PR 描述使用中英文同步(先中文、后英文),包含以下三个部分: + +### 模板 + +```markdown +## 背景 / Background + +<一两句说明问题现状或触发原因> + + + +--- + +## 目的 / Purpose + +<本次改动要解决的问题或达到的目标> + + + +--- + +## 改动内容 / Changes + +### 后端 / Backend + +- **改动点 1**:说明 +- **改动点 2**:说明 + +--- + +- **Change 1**: description +- **Change 2**: description + +### 前端 / Frontend + +- **改动点 1**:说明 +- **改动点 2**:说明 + +--- + +- **Change 1**: description +- **Change 2**: description + +--- + +## 截图 / Screenshot(可选) + +ASCII 示意图或实际截图 +``` + +### 规范要点 + +- **标题**:使用 conventional commits 格式,如 `feat(scope): description` +- **中英文顺序**:同一段落先中文后英文,用空行分隔,不用 `---` 分割同段内容 +- **改动分类**:按 Backend / Frontend / Config 等模块分组,先列中文要点再列英文要点 +- **截图/示意图**:有 UI 变动时必须附上,可用 ASCII 示意布局 +- **目标分支**:提交到 `touwaeriol/sub2api` 的 `main` 分支 diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..b634af05 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,1337 @@ +# Sub2API 开发说明 + +## 版本管理策略 + +### 版本号规则 + +我们在官方版本号后面添加自己的小版本号: + +- 官方版本:`v0.1.68` +- 我们的版本:`v0.1.68.1`、`v0.1.68.2`(递增) + +### 分支策略 + +| 分支 | 说明 | +|------|------| +| `main` | 我们的主分支,包含所有定制功能 | +| `release/custom-X.Y.Z` | 基于官方 `vX.Y.Z` 的发布分支 | +| `upstream/main` | 上游官方仓库 | + +--- + +## 发布流程(基于新官方版本) + +当官方发布新版本(如 `v0.1.69`)时: + +### 1. 同步上游并创建发布分支 + +```bash +# 获取上游最新代码 +git fetch upstream --tags + +# 基于官方标签创建新的发布分支 +git checkout v0.1.69 -b release/custom-0.1.69 + +# 合并我们的 main 分支(包含所有定制功能) +git merge main --no-edit + +# 解决可能的冲突后继续 +``` + +### 2. 
更新版本号并打标签 + +```bash +# 更新版本号文件 +echo "0.1.69.1" > backend/cmd/server/VERSION +git add backend/cmd/server/VERSION +git commit -m "chore: bump version to 0.1.69.1" + +# 打上我们自己的标签 +git tag v0.1.69.1 + +# 推送分支和标签 +git push origin release/custom-0.1.69 +git push origin v0.1.69.1 +``` + +### 3. 更新 main 分支 + +```bash +# 将发布分支合并回 main,保持 main 包含最新定制功能 +git checkout main +git merge release/custom-0.1.69 +git push origin main +``` + +--- + +## 热修复发布(在现有版本上修复) + +当需要在当前版本上发布修复时: + +```bash +# 在当前发布分支上修复 +git checkout release/custom-0.1.68 +# ... 进行修复 ... +git commit -m "fix: 修复描述" + +# 递增小版本号 +echo "0.1.68.2" > backend/cmd/server/VERSION +git add backend/cmd/server/VERSION +git commit -m "chore: bump version to 0.1.68.2" + +# 打标签并推送 +git tag v0.1.68.2 +git push origin release/custom-0.1.68 +git push origin v0.1.68.2 + +# 同步修复到 main +git checkout main +git cherry-pick +git push origin main +``` + +--- + +## 服务器部署流程 + +### 前置条件 + +- 本地已配置 SSH 别名 `clicodeplus` 连接到生产服务器(运行服务 + 构建镜像) +- 生产服务器部署目录:`/root/sub2api`(正式)、`/root/sub2api-beta`(测试)、`/root/sub2api-star`(Star) +- 生产服务器使用 Docker Compose 部署 +- **镜像在生产服务器本机构建**,使用资源限制的 `limited-builder` 构建器(3 核 CPU、4G 内存),避免构建占满服务器资源影响线上服务 + +### 服务器角色说明 + +| 服务器 | SSH 别名 | 职责 | +|--------|----------|------| +| 生产服务器 | `clicodeplus` | 拉取代码、构建镜像、运行服务、部署验证 | +| 数据库服务器 | `db-clicodeplus` | PostgreSQL 16 + Redis 7,所有环境共用 | + +> 数据库服务器运维手册:`db-clicodeplus:/root/README.md` + +### 构建器说明 + +生产服务器上配置了资源限制的 Docker buildx 构建器 `limited-builder`,**所有构建操作必须使用此构建器**: + +- **构建器名称**:`limited-builder` +- **驱动**:`docker-container`(独立容器运行 BuildKit) +- **资源限制**:3 核 CPU、4G 内存(服务器共 6 核 8G,预留一半给线上服务) +- **容器名**:`buildx_buildkit_limited-builder0` + +```bash +# 构建命令格式(必须指定 --builder) +ssh clicodeplus "cd /root/sub2api && docker buildx build --builder limited-builder --no-cache --load -t sub2api:latest -f Dockerfile ." 
+ +# 查看构建器状态 +ssh clicodeplus "docker buildx inspect limited-builder" + +# 如果构建器容器被意外删除,重新创建: +ssh clicodeplus "docker buildx create --name limited-builder --driver docker-container --driver-opt 'default-load=true' && docker buildx inspect --builder limited-builder --bootstrap && docker update --cpus=3 --memory=4g --memory-swap=4g buildx_buildkit_limited-builder0" +``` + +### 部署环境说明 + +| 环境 | 目录(生产服务器) | 端口 | 数据库 | Redis DB | 容器名 | +|------|------|------|--------|----------|--------| +| 正式 | `/root/sub2api` | 8080 | `sub2api` | 0 | `sub2api` | +| Beta | `/root/sub2api-beta` | 8084 | `beta` | 2 | `sub2api-beta` | +| OpenAI | `/root/sub2api-openai` | 8083 | `openai` | 3 | `sub2api-openai` | +| Star | `/root/sub2api-star` | 8086 | `star` | 4 | `sub2api-star` | + +### 外部数据库与 Redis + +所有环境(正式、Beta、OpenAI、Star)共用 `db.clicodeplus.com` 上的 **PostgreSQL 16** 和 **Redis 7**,不使用容器内数据库或 Redis。 + +**PostgreSQL**(端口 5432,TLS 加密,scram-sha-256 认证): + +| 环境 | 用户名 | 数据库 | +|------|--------|--------| +| 正式 | `sub2api` | `sub2api` | +| Beta | `beta` | `beta` | +| OpenAI | `openai` | `openai` | +| Star | `star` | `star` | + +**Redis**(端口 6379,密码认证): + +| 环境 | DB | +|------|-----| +| 正式 | 0 | +| Beta | 2 | +| OpenAI | 3 | +| Star | 4 | + +**配置方式**: +- 数据库通过 `.env` 中的 `DATABASE_HOST`、`DATABASE_SSLMODE`、`POSTGRES_USER`、`POSTGRES_PASSWORD`、`POSTGRES_DB` 配置 +- Redis 通过 `docker-compose.override.yml` 覆盖 `REDIS_HOST`(因主 compose 文件硬编码为 `redis`),密码通过 `.env` 中的 `REDIS_PASSWORD` 配置 +- 各环境的 `docker-compose.override.yml` 已通过 `depends_on: !reset {}` 和 `redis: profiles: [disabled]` 去掉了对容器 Redis 的依赖 + +#### 数据库操作命令 + +通过 SSH 在服务器上执行数据库操作: + +```bash +# 正式环境 - 查询迁移记录 +ssh clicodeplus "source /root/sub2api/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c 'SELECT * FROM schema_migrations ORDER BY applied_at DESC LIMIT 5;'" + +# Beta 环境 - 查询迁移记录 +ssh clicodeplus "source /root/sub2api-beta/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h 
\$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c 'SELECT * FROM schema_migrations ORDER BY applied_at DESC LIMIT 5;'" + +# Beta 环境 - 清除指定迁移记录(重新执行迁移) +ssh clicodeplus "source /root/sub2api-beta/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c \"DELETE FROM schema_migrations WHERE filename LIKE '%049%';\"" + +# Beta 环境 - 更新账号数据 +ssh clicodeplus "source /root/sub2api-beta/deploy/.env && PGPASSWORD=\"\$POSTGRES_PASSWORD\" psql -h \$DATABASE_HOST -U \$POSTGRES_USER -d \$POSTGRES_DB -c \"UPDATE accounts SET credentials = credentials - 'model_mapping' WHERE platform = 'antigravity';\"" +``` + +> **注意**:使用 `source .env` 加载环境变量,避免在命令行中暴露密码。 + +### 部署步骤 + +**重要:每次部署都必须递增版本号!** + +#### 0. 递增版本号并推送(本地操作) + +每次部署前,先在本地递增小版本号并确保推送成功: + +```bash +# 查看当前版本号 +cat backend/cmd/server/VERSION +# 假设当前是 0.1.69.1 + +# 递增版本号 +echo "0.1.69.2" > backend/cmd/server/VERSION +git add backend/cmd/server/VERSION +git commit -m "chore: bump version to 0.1.69.2" +git push origin release/custom-0.1.69 + +# ⚠️ 确认推送成功(必须看到分支更新输出,不能有 rejected 错误) +``` + +> **检查点**:如果有其他未提交的改动,应先 commit 并 push,确保 release 分支上的所有代码都已推送到远程。 + +#### 1. 生产服务器拉取代码 + +```bash +# 拉取最新代码并切换分支 +ssh clicodeplus "cd /root/sub2api && git fetch fork && git checkout -B release/custom-0.1.69 fork/release/custom-0.1.69" + +# ⚠️ 验证版本号与步骤 0 一致 +ssh clicodeplus "cat /root/sub2api/backend/cmd/server/VERSION" +``` + +#### 2. 生产服务器构建镜像(使用 limited-builder) + +```bash +ssh clicodeplus "cd /root/sub2api && docker buildx build --builder limited-builder --no-cache --load -t sub2api:latest -f Dockerfile ." + +# ⚠️ 必须看到构建成功输出,如果失败需要先排查问题 +``` + +> **常见构建问题**: +> - 构建器未启动 → `docker buildx inspect --builder limited-builder --bootstrap` +> - 磁盘空间不足 → `docker system prune -f` 清理无用镜像 +> - 构建器被删除 → 参见上方「构建器说明」重新创建 + +#### 3. 
更新镜像标签并重启 + +```bash +# 更新镜像标签并重启 +ssh clicodeplus "docker tag sub2api:latest weishaw/sub2api:latest" +ssh clicodeplus "cd /root/sub2api/deploy && docker compose up -d --force-recreate sub2api" +``` + +#### 4. 验证部署 + +```bash +# 查看启动日志 +ssh clicodeplus "docker logs sub2api --tail 20" + +# 确认版本号(必须与步骤 0 中设置的版本号一致) +ssh clicodeplus "cat /root/sub2api/backend/cmd/server/VERSION" + +# 检查容器状态(必须显示 healthy) +ssh clicodeplus "docker ps | grep sub2api" +``` + +--- + +## Beta 并行部署(不影响现网) + +目标:在同一台服务器上并行启动一个 beta 实例(例如端口 `8084`),**严禁改动/重启**现网实例(默认目录 `/root/sub2api`)。 + +### 设计原则 + +- **新目录**:beta 使用独立目录,例如 `/root/sub2api-beta`。 +- **敏感信息只放 `.env`**:beta 的数据库密码、JWT_SECRET 等只写入 `/root/sub2api-beta/deploy/.env`,不要提交到 git。 +- **独立 Compose Project**:通过 `docker compose -p sub2api-beta ...` 启动,确保 network/volume 隔离。 +- **独立端口**:通过 `.env` 的 `SERVER_PORT` 映射宿主机端口(例如 `8084:8080`)。 + +### 前置检查 + +```bash +# 1) 确保 8084 未被占用 +ssh clicodeplus "ss -ltnp | grep :8084 || echo '8084 is free'" + +# 2) 确认现网容器还在(只读检查) +ssh clicodeplus "docker ps --format 'table {{.Names}}\t{{.Image}}\t{{.Ports}}' | sed -n '1,200p'" +``` + +### 首次部署步骤 + +> **构建说明**:正式和 beta 通过不同的镜像标签区分(`sub2api:latest` 用于正式,`sub2api:beta` 用于测试),均在生产服务器本机使用 `limited-builder` 构建。 + +```bash +# 1) 在生产服务器上拉取代码并构建 beta 镜像 +ssh clicodeplus "cd /root/sub2api-beta && git fetch --all --tags && git checkout -f release/custom-0.1.71 && git reset --hard origin/release/custom-0.1.71" +ssh clicodeplus "cd /root/sub2api-beta && docker buildx build --builder limited-builder --no-cache --load -t sub2api:beta -f Dockerfile ." 
+ +# 2) 在生产服务器上准备 beta 环境 +ssh clicodeplus + +# 克隆代码(仅用于 deploy 配置和版本号确认,不在此构建) +cd /root +git clone https://github.com/touwaeriol/sub2api.git sub2api-beta +cd /root/sub2api-beta +git checkout release/custom-0.1.71 + +# 4) 准备 beta 的 .env(敏感信息只写这里) +cd /root/sub2api-beta/deploy + +# 推荐:从现网 .env 复制,保证除 DB 名/用户/端口外完全一致 +cp -f /root/sub2api/deploy/.env ./.env + +# 仅修改以下三项(其他保持不变) +perl -pi -e 's/^SERVER_PORT=.*/SERVER_PORT=8084/' ./.env +perl -pi -e 's/^POSTGRES_USER=.*/POSTGRES_USER=beta/' ./.env +perl -pi -e 's/^POSTGRES_DB=.*/POSTGRES_DB=beta/' ./.env + +# 5) 写 compose override(避免与现网容器名冲突,镜像使用本机构建的 sub2api:beta,Redis 使用外部服务) +cat > docker-compose.override.yml <<'YAML' +services: + sub2api: + image: sub2api:beta + container_name: sub2api-beta + environment: + - DATABASE_HOST=${DATABASE_HOST:-postgres} + - DATABASE_SSLMODE=${DATABASE_SSLMODE:-disable} + - REDIS_HOST=db.clicodeplus.com + depends_on: !reset {} + redis: + profiles: + - disabled +YAML + +# 6) 启动 beta(独立 project,确保不影响现网) +cd /root/sub2api-beta/deploy +docker compose -p sub2api-beta --env-file .env -f docker-compose.yml -f docker-compose.override.yml up -d + +# 7) 验证 beta +curl -fsS http://127.0.0.1:8084/health +docker logs sub2api-beta --tail 50 +``` + +### 数据库配置约定(beta) + +- 数据库地址/SSL/密码:与现网一致(从现网 `.env` 复制即可),均指向 `db.clicodeplus.com`。 +- 仅修改: + - `POSTGRES_USER=beta` + - `POSTGRES_DB=beta` + - `REDIS_DB=2` + +注意:需要数据库侧已存在 `beta` 用户与 `beta` 数据库,并授予权限;否则容器会启动失败并不断重启。 + +### 更新 beta(本机构建 + 仅重启 beta 容器) + +```bash +# 1) 生产服务器拉取代码并构建镜像 +ssh clicodeplus "cd /root/sub2api-beta && git fetch --all --tags && git checkout -f release/custom-0.1.71 && git reset --hard origin/release/custom-0.1.71" +ssh clicodeplus "cd /root/sub2api-beta && docker buildx build --builder limited-builder --no-cache --load -t sub2api:beta -f Dockerfile ." 
+# ⚠️ 必须看到构建成功输出 + +# 2) 重启 beta 容器并验证 +ssh clicodeplus "cd /root/sub2api-beta/deploy && docker compose -p sub2api-beta --env-file .env -f docker-compose.yml -f docker-compose.override.yml up -d --no-deps --force-recreate sub2api" +ssh clicodeplus "sleep 5 && curl -fsS http://127.0.0.1:8084/health" +ssh clicodeplus "cat /root/sub2api-beta/backend/cmd/server/VERSION" +``` + +### 停止/回滚 beta(只影响 beta) + +```bash +ssh clicodeplus "cd /root/sub2api-beta/deploy && docker compose -p sub2api-beta -f docker-compose.yml -f docker-compose.override.yml down" +``` + +--- + +## 服务器首次部署 + +### 1. 生产服务器:克隆代码并配置环境 + +```bash +ssh clicodeplus +cd /root +git clone https://github.com/Wei-Shaw/sub2api.git +cd sub2api + +# 添加 fork 仓库 +git remote add fork https://github.com/touwaeriol/sub2api.git +git fetch fork +git checkout -B release/custom-0.1.69 fork/release/custom-0.1.69 + +# 配置环境变量 +cd deploy +cp .env.example .env +vim .env # 配置 DATABASE_HOST=db.clicodeplus.com, POSTGRES_PASSWORD, REDIS_PASSWORD, JWT_SECRET 等 + +# 创建 override 文件(Redis 指向外部服务,去掉容器 Redis 依赖) +cat > docker-compose.override.yml <<'YAML' +services: + sub2api: + environment: + - REDIS_HOST=db.clicodeplus.com + depends_on: !reset {} + redis: + profiles: + - disabled +YAML +``` + +### 2. 生产服务器:创建构建器并构建镜像 + +```bash +# 创建资源限制的构建器(首次执行一次即可) +docker buildx create --name limited-builder --driver docker-container --driver-opt "default-load=true" +docker buildx inspect --builder limited-builder --bootstrap +docker update --cpus=3 --memory=4g --memory-swap=4g buildx_buildkit_limited-builder0 + +# 构建镜像 +cd /root/sub2api +docker buildx build --builder limited-builder --no-cache --load -t sub2api:latest -f Dockerfile . + +# 更新镜像标签并启动 +docker tag sub2api:latest weishaw/sub2api:latest +cd /root/sub2api/deploy && docker compose up -d +``` + +### 3. 验证部署 + +```bash +# 查看应用日志 +docker logs sub2api --tail 50 + +# 检查健康状态 +curl http://localhost:8080/health + +# 确认版本号 +cat /root/sub2api/backend/cmd/server/VERSION +``` + +### 4. 
常用运维命令 + +```bash +# 查看实时日志 +docker logs -f sub2api + +# 重启服务 +docker compose restart sub2api + +# 停止所有服务 +docker compose down + +# 停止并删除数据卷(慎用!会删除数据库数据) +docker compose down -v + +# 查看资源使用情况 +docker stats sub2api +``` + +--- + +## 定制功能说明 + +当前定制分支包含以下功能(相对于官方版本): + +### UI/UX 定制 + +| 功能 | 说明 | +|------|------| +| 首页优化 | 面向用户的价值主张设计 | +| 移除 GitHub 链接 | 用户菜单中不显示 GitHub 导航 | +| 微信客服按钮 | 首页悬浮微信客服入口 | +| 限流时间精确显示 | 账号限流时间显示精确到秒 | + +### Antigravity 平台增强 + +| 功能 | 说明 | +|------|------| +| Scope 级别限流 | 按配额域(claude/gemini_text/gemini_image)独立限流,避免整个账号被锁定 | +| 模型级别限流 | 按具体模型(如 claude-opus-4-5)独立限流,更精细的限流控制 | +| 限流预检查 | 调度时预检查账号/模型限流状态,避免选中已限流账号 | +| 秒级冷却时间 | 支持 429 响应的秒级精确冷却时间 | +| 身份注入优化 | 模型身份信息注入 + 静默边界防止身份泄露 | +| thoughtSignature 修复 | Gemini 3 函数调用 400 错误修复 | +| max_tokens 自动修正 | 自动修正 max_tokens <= budget_tokens 导致的 400 错误 | + +### 调度算法优化 + +| 功能 | 说明 | +|------|------| +| 分层过滤选择 | 调度算法从全排序改为分层过滤,提升性能 | +| LRU 随机选择 | 相同 LRU 时间时随机选择,避免账号集中 | +| 限流等待阈值配置化 | 可配置的限流等待阈值 | + +### 运维增强 + +| 功能 | 说明 | +|------|------| +| Scope 限流统计 | 运维界面展示 Antigravity 账号 scope 级别限流统计 | +| 账号限流状态显示 | 账号列表显示 scope 和模型级别限流状态 | +| 清除限流按钮增强 | 有 scope/模型限流时也显示清除限流按钮 | + +### 其他修复 + +| 功能 | 说明 | +|------|------| +| .gitattributes | 确保迁移文件使用 LF 换行符(解决 Windows 下 SQL 摘要不一致) | +| 部署配置优化 | DATABASE_HOST 和 DATABASE_SSLMODE 可通过 .env 配置 | + +--- + +## Admin API 接口文档 + +### ⚠️ API 操作流程规范 + +当收到操作正式环境 Web 界面的新需求,但文档中未记录对应 API 接口时,**必须按以下流程执行**: + +1. **探索接口**:通过代码库搜索路由定义(`backend/internal/server/routes/`)、Handler(`backend/internal/handler/admin/`)和请求结构体,确定正确的 API 端点、请求方法、请求体格式 +2. **更新文档**:将新发现的接口补充到本文档的 Admin API 接口文档章节中,包含端点、参数说明和 curl 示例 +3. 
**执行操作**:根据最新文档中记录的接口完成用户需求 + +> **目的**:避免每次遇到相同需求都重复探索代码库,确保 API 文档持续完善,后续操作可直接查阅文档执行。 + +--- + +### 认证方式 + +所有 Admin API 通过 `x-api-key` 请求头传递 Admin API Key 认证。 + +``` +x-api-key: admin-xxx +``` + +> **使用说明**:Admin API Key 统一存放在项目根目录 `.env` 文件的 `ADMIN_API_KEY` 变量中(该文件已被 `.gitignore` 排除,不会提交到代码库)。操作前先从 `.env` 读取密钥;若密钥失效(返回 401),应提示用户提供新的密钥并更新到 `.env` 中。Token 格式为 `admin-` + 64 位十六进制字符,在管理后台 `设置 > Admin API Key` 中生成。**请勿将实际 token 写入文档或代码中。** + +### 环境地址 + +| 环境 | 基础地址 | 说明 | +|------|----------|------| +| 正式 | `https://clicodeplus.com` | 生产环境 | +| Beta | `http://<服务器IP>:8084` | 仅内网访问 | +| OpenAI | `http://<服务器IP>:8083` | 仅内网访问 | +| Star | `https://hyntoken.com` | 独立环境 | + +> 以下接口文档中,`${BASE}` 代表环境基础地址,`${KEY}` 代表 `.env` 中的 `ADMIN_API_KEY`。操作前执行 `source .env` 或 `export KEY=$ADMIN_API_KEY` 加载。 + +--- + +### 1. 账号管理 + +#### 1.1 获取账号列表 + +``` +GET /api/v1/admin/accounts +``` + +**查询参数**: + +| 参数 | 类型 | 必填 | 说明 | +|------|------|------|------| +| `platform` | string | 否 | 平台筛选:`antigravity` / `anthropic` / `openai` / `gemini` | +| `type` | string | 否 | 账号类型:`oauth` / `api_key` / `cookie` | +| `status` | string | 否 | 状态:`active` / `disabled` / `error` | +| `search` | string | 否 | 搜索关键词(名称、备注) | +| `page` | int | 否 | 页码,默认 1 | +| `page_size` | int | 否 | 每页数量,默认 20 | + +```bash +curl -s "${BASE}/api/v1/admin/accounts?platform=antigravity&page=1&page_size=100" \ + -H "x-api-key: ${KEY}" +``` + +**响应**: +```json +{ + "code": 0, + "message": "success", + "data": { + "items": [{"id": 1, "name": "xxx@gmail.com", "platform": "antigravity", "status": "active", ...}], + "total": 66 + } +} +``` + +#### 1.2 获取账号详情 + +``` +GET /api/v1/admin/accounts/:id +``` + +```bash +curl -s "${BASE}/api/v1/admin/accounts/1" -H "x-api-key: ${KEY}" +``` + +#### 1.3 测试账号连接 + +``` +POST /api/v1/admin/accounts/:id/test +``` + +**请求体**(JSON,可选): + +| 字段 | 类型 | 必填 | 说明 | +|------|------|------|------| +| `model_id` | string | 否 | 指定测试模型,如 `claude-opus-4-6`;不传则使用默认模型 | + +**响应格式**:SSE(Server-Sent Events)流 
+ +```bash +curl -N -X POST "${BASE}/api/v1/admin/accounts/1/test" \ + -H "x-api-key: ${KEY}" \ + -H "Content-Type: application/json" \ + -d '{"model_id": "claude-opus-4-6"}' +``` + +**SSE 事件类型**: + +| type | 字段 | 说明 | +|------|------|------| +| `test_start` | `model` | 测试开始,返回测试模型名 | +| `content` | `text` | 模型响应内容(流式文本片段) | +| `test_end` | `success`, `error` | 测试结束,`success=true` 表示成功 | +| `error` | `text` | 错误信息 | + +#### 1.4 清除账号限流 + +``` +POST /api/v1/admin/accounts/:id/clear-rate-limit +``` + +```bash +curl -X POST "${BASE}/api/v1/admin/accounts/1/clear-rate-limit" \ + -H "x-api-key: ${KEY}" +``` + +#### 1.5 清除账号错误状态 + +``` +POST /api/v1/admin/accounts/:id/clear-error +``` + +```bash +curl -X POST "${BASE}/api/v1/admin/accounts/1/clear-error" \ + -H "x-api-key: ${KEY}" +``` + +#### 1.6 获取账号可用模型 + +``` +GET /api/v1/admin/accounts/:id/models +``` + +```bash +curl -s "${BASE}/api/v1/admin/accounts/1/models" -H "x-api-key: ${KEY}" +``` + +#### 1.7 刷新 OAuth Token + +``` +POST /api/v1/admin/accounts/:id/refresh +``` + +```bash +curl -X POST "${BASE}/api/v1/admin/accounts/1/refresh" -H "x-api-key: ${KEY}" +``` + +#### 1.8 刷新账号等级 + +``` +POST /api/v1/admin/accounts/:id/refresh-tier +``` + +```bash +curl -X POST "${BASE}/api/v1/admin/accounts/1/refresh-tier" -H "x-api-key: ${KEY}" +``` + +#### 1.9 获取账号统计 + +``` +GET /api/v1/admin/accounts/:id/stats +``` + +```bash +curl -s "${BASE}/api/v1/admin/accounts/1/stats" -H "x-api-key: ${KEY}" +``` + +#### 1.10 获取账号用量 + +``` +GET /api/v1/admin/accounts/:id/usage +``` + +```bash +curl -s "${BASE}/api/v1/admin/accounts/1/usage" -H "x-api-key: ${KEY}" +``` + +#### 1.11 更新单个账号 + +``` +PUT /api/v1/admin/accounts/:id +``` + +**请求体**(JSON,所有字段均为可选,仅传需要更新的字段): + +| 字段 | 类型 | 说明 | +|------|------|------| +| `name` | string | 账号名称 | +| `notes` | *string | 备注 | +| `type` | string | 类型:`oauth` / `setup-token` / `apikey` / `upstream` | +| `credentials` | object | 凭证信息 | +| `extra` | object | 额外配置 | +| `proxy_id` | *int64 | 代理 ID | +| 
`concurrency` | *int | 并发数 | +| `priority` | *int | 优先级(默认 50) | +| `rate_multiplier` | *float64 | 速率倍数 | +| `status` | string | 状态:`active` / `inactive` | +| `group_ids` | *[]int64 | 分组 ID 列表 | +| `expires_at` | *int64 | 过期时间戳 | +| `auto_pause_on_expired` | *bool | 过期后自动暂停 | + +> 使用指针类型(`*`)的字段可以区分"未提供"和"设置为零值"。 + +```bash +# 示例:更新账号优先级为 100 +curl -X PUT "${BASE}/api/v1/admin/accounts/1" \ + -H "x-api-key: ${KEY}" \ + -H "Content-Type: application/json" \ + -d '{"priority": 100}' +``` + +#### 1.12 批量更新账号 + +``` +POST /api/v1/admin/accounts/bulk-update +``` + +**请求体**(JSON): + +| 字段 | 类型 | 必填 | 说明 | +|------|------|------|------| +| `account_ids` | []int64 | **是** | 要更新的账号 ID 列表 | +| `priority` | *int | 否 | 优先级 | +| `concurrency` | *int | 否 | 并发数 | +| `rate_multiplier` | *float64 | 否 | 速率倍数 | +| `status` | string | 否 | 状态:`active` / `inactive` / `error` | +| `schedulable` | *bool | 否 | 是否可调度 | +| `group_ids` | *[]int64 | 否 | 分组 ID 列表 | +| `proxy_id` | *int64 | 否 | 代理 ID | +| `credentials` | object | 否 | 凭证信息(批量覆盖) | +| `extra` | object | 否 | 额外配置(批量覆盖) | + +```bash +# 示例:批量设置多个账号优先级为 100 +curl -X POST "${BASE}/api/v1/admin/accounts/bulk-update" \ + -H "x-api-key: ${KEY}" \ + -H "Content-Type: application/json" \ + -d '{"account_ids": [1, 2, 3], "priority": 100}' +``` + +#### 1.13 批量测试账号(脚本) + +批量测试指定平台所有账号的指定模型连通性: + +```bash +# 用户需提供:BASE(环境地址)、KEY(admin token)、MODEL(测试模型) +ACCOUNT_IDS=$(curl -s "${BASE}/api/v1/admin/accounts?platform=antigravity&page=1&page_size=100" \ + -H "x-api-key: ${KEY}" | python3 -c " +import json, sys +data = json.load(sys.stdin) +for item in data['data']['items']: + print(f\"{item['id']}|{item['name']}\") +") + +while IFS='|' read -r ID NAME; do + echo "测试账号 ID=${ID} (${NAME})..." 
+ RESPONSE=$(curl -s --max-time 60 -N \ + -X POST "${BASE}/api/v1/admin/accounts/${ID}/test" \ + -H "x-api-key: ${KEY}" \ + -H "Content-Type: application/json" \ + -d "{\"model_id\": \"${MODEL}\"}" 2>&1) + if echo "$RESPONSE" | grep -q '"success":true'; then + echo " ✅ 成功" + elif echo "$RESPONSE" | grep -q '"type":"content"'; then + echo " ✅ 成功(有内容响应)" + else + ERROR_MSG=$(echo "$RESPONSE" | grep -o '"error":"[^"]*"' | tail -1) + echo " ❌ 失败: ${ERROR_MSG}" + fi +done <<< "$ACCOUNT_IDS" +``` + +--- + +### 2. 运维监控 + +#### 2.1 并发统计 + +``` +GET /api/v1/admin/ops/concurrency +``` + +```bash +curl -s "${BASE}/api/v1/admin/ops/concurrency" -H "x-api-key: ${KEY}" +``` + +#### 2.2 账号可用性 + +``` +GET /api/v1/admin/ops/account-availability +``` + +```bash +curl -s "${BASE}/api/v1/admin/ops/account-availability" -H "x-api-key: ${KEY}" +``` + +#### 2.3 实时流量摘要 + +``` +GET /api/v1/admin/ops/realtime-traffic +``` + +```bash +curl -s "${BASE}/api/v1/admin/ops/realtime-traffic" -H "x-api-key: ${KEY}" +``` + +#### 2.4 请求错误列表 + +``` +GET /api/v1/admin/ops/request-errors +``` + +**查询参数**:`page`、`page_size` + +```bash +curl -s "${BASE}/api/v1/admin/ops/request-errors?page=1&page_size=50" \ + -H "x-api-key: ${KEY}" +``` + +#### 2.5 上游错误列表 + +``` +GET /api/v1/admin/ops/upstream-errors +``` + +```bash +curl -s "${BASE}/api/v1/admin/ops/upstream-errors?page=1&page_size=50" \ + -H "x-api-key: ${KEY}" +``` + +#### 2.6 仪表板概览 + +``` +GET /api/v1/admin/ops/dashboard/overview +``` + +```bash +curl -s "${BASE}/api/v1/admin/ops/dashboard/overview" -H "x-api-key: ${KEY}" +``` + +--- + +### 3. 系统设置 + +#### 3.1 获取系统设置 + +``` +GET /api/v1/admin/settings +``` + +```bash +curl -s "${BASE}/api/v1/admin/settings" -H "x-api-key: ${KEY}" +``` + +#### 3.2 更新系统设置 + +``` +PUT /api/v1/admin/settings +``` + +```bash +curl -X PUT "${BASE}/api/v1/admin/settings" \ + -H "x-api-key: ${KEY}" \ + -H "Content-Type: application/json" \ + -d '{ ... 
}' +``` + +#### 3.3 Admin API Key 状态(脱敏) + +``` +GET /api/v1/admin/settings/admin-api-key +``` + +```bash +curl -s "${BASE}/api/v1/admin/settings/admin-api-key" -H "x-api-key: ${KEY}" +``` + +--- + +### 4. 用户管理 + +#### 4.1 用户列表 + +``` +GET /api/v1/admin/users +``` + +```bash +curl -s "${BASE}/api/v1/admin/users?page=1&page_size=20" -H "x-api-key: ${KEY}" +``` + +#### 4.2 用户详情 + +``` +GET /api/v1/admin/users/:id +``` + +```bash +curl -s "${BASE}/api/v1/admin/users/1" -H "x-api-key: ${KEY}" +``` + +#### 4.3 更新用户余额 + +``` +POST /api/v1/admin/users/:id/balance +``` + +```bash +curl -X POST "${BASE}/api/v1/admin/users/1/balance" \ + -H "x-api-key: ${KEY}" \ + -H "Content-Type: application/json" \ + -d '{"amount": 100, "reason": "充值"}' +``` + +--- + +### 5. 分组管理 + +#### 5.1 分组列表 + +``` +GET /api/v1/admin/groups +``` + +```bash +curl -s "${BASE}/api/v1/admin/groups" -H "x-api-key: ${KEY}" +``` + +#### 5.2 所有分组(不分页) + +``` +GET /api/v1/admin/groups/all +``` + +```bash +curl -s "${BASE}/api/v1/admin/groups/all" -H "x-api-key: ${KEY}" +``` + +--- + +## 注意事项 + +1. **前端必须打包进镜像**:使用 `docker buildx build --builder limited-builder` 在生产服务器(`clicodeplus`)本机构建,Dockerfile 会自动编译前端并 embed 到后端二进制中 + +2. **镜像标签**:docker-compose.yml 使用 `weishaw/sub2api:latest`,本地构建后需要 `docker tag` 覆盖 + +3. **Windows 换行符问题**:已通过 `.gitattributes` 解决,确保 `*.sql` 文件始终使用 LF + +4. **版本号管理**:每次发布必须更新 `backend/cmd/server/VERSION` 并打标签 + +5. **合并冲突**:合并上游新版本时,重点关注以下文件可能的冲突: + - `backend/internal/service/antigravity_gateway_service.go` + - `backend/internal/service/gateway_service.go` + - `backend/internal/pkg/antigravity/request_transformer.go` + +--- + +## Go 代码规范 + +### 1. 函数设计 + +#### 单一职责原则 +- **函数行数**:单个函数常规不应超过 **30 行**,超过时应拆分为子函数。若某段逻辑确实不可拆分(如复杂的状态机、协议解析等),可以例外,但需添加注释说明原因 +- **嵌套层级**:避免超过 3 层嵌套,使用 early return 减少嵌套 + +```go +// ❌ 不推荐:深层嵌套 +func process(data []Item) { + for _, item := range data { + if item.Valid { + if item.Type == "A" { + if item.Status == "active" { + // 业务逻辑... 
+ } + } + } + } +} + +// ✅ 推荐:early return +func process(data []Item) { + for _, item := range data { + if !item.Valid { + continue + } + if item.Type != "A" { + continue + } + if item.Status != "active" { + continue + } + // 业务逻辑... + } +} +``` + +#### 复杂逻辑提取 +将复杂的条件判断或处理逻辑提取为独立函数: + +```go +// ❌ 不推荐:内联复杂逻辑 +if resp.StatusCode == 429 || resp.StatusCode == 503 { + // 80+ 行处理逻辑... +} + +// ✅ 推荐:提取为独立函数 +result := handleRateLimitResponse(resp, params) +switch result.action { +case actionRetry: + continue +case actionBreak: + return result.resp, nil +} +``` + +### 2. 重复代码消除 + +#### 配置获取模式 +将重复的配置获取逻辑提取为方法: + +```go +// ❌ 不推荐:重复代码 +logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody +maxBytes := 2048 +if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 { + maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes +} + +// ✅ 推荐:提取为方法 +func (s *Service) getLogConfig() (logBody bool, maxBytes int) { + maxBytes = 2048 + if s.settingService == nil || s.settingService.cfg == nil { + return false, maxBytes + } + cfg := s.settingService.cfg.Gateway + if cfg.LogUpstreamErrorBodyMaxBytes > 0 { + maxBytes = cfg.LogUpstreamErrorBodyMaxBytes + } + return cfg.LogUpstreamErrorBody, maxBytes +} +``` + +### 3. 常量管理 + +#### 避免魔法数字 +所有硬编码的数值都应定义为常量: + +```go +// ❌ 不推荐 +if retryDelay >= 10*time.Second { + resetAt := time.Now().Add(30 * time.Second) +} + +// ✅ 推荐 +const ( + rateLimitThreshold = 10 * time.Second + defaultRateLimitDuration = 30 * time.Second +) + +if retryDelay >= rateLimitThreshold { + resetAt := time.Now().Add(defaultRateLimitDuration) +} +``` + +#### 注释引用常量名 +在注释中引用常量名而非硬编码值: + +```go +// ❌ 不推荐 +// < 10s: 等待后重试 + +// ✅ 推荐 +// < rateLimitThreshold: 等待后重试 +``` + +### 4. 
错误处理 + +#### 使用结构化日志 +优先使用 `slog` 进行结构化日志记录: + +```go +// ❌ 不推荐 +log.Printf("%s status=%d model_rate_limit_failed model=%s error=%v", prefix, statusCode, modelName, err) + +// ✅ 推荐 +slog.Error("failed to set model rate limit", + "prefix", prefix, + "status_code", statusCode, + "model", modelName, + "error", err, +) +``` + +### 5. 测试规范 + +#### Mock 函数签名同步 +修改函数签名时,必须同步更新所有测试中的 mock 函数: + +```go +// 如果修改了 handleError 签名 +handleError func(..., groupID int64, sessionHash string) *Result + +// 必须同步更新测试中的 mock +handleError: func(..., groupID int64, sessionHash string) *Result { + return nil +}, +``` + +#### 测试构建标签 +统一使用测试构建标签: + +```go +//go:build unit + +package service +``` + +### 6. 时间格式解析 + +#### 使用标准库 +优先使用 `time.ParseDuration`,支持所有 Go duration 格式: + +```go +// ❌ 不推荐:手动限制格式 +if !strings.HasSuffix(delay, "s") || strings.Contains(delay, "m") { + continue +} + +// ✅ 推荐:使用标准库 +dur, err := time.ParseDuration(delay) // 支持 "0.5s", "4m50s", "1h30m" 等 +``` + +### 7. 接口设计 + +#### 接口隔离原则 +定义最小化接口,只包含必需的方法: + +```go +// ❌ 不推荐:使用过于宽泛的接口 +type AccountRepository interface { + // 20+ 个方法... +} + +// ✅ 推荐:定义最小化接口 +type ModelRateLimiter interface { + SetModelRateLimit(ctx context.Context, id int64, modelKey string, resetAt time.Time) error +} +``` + +### 8. 并发安全 + +#### 共享数据保护 +访问可能被并发修改的数据时,确保线程安全: + +```go +// 如果 Account.Extra 可能被并发修改 +// 需要使用互斥锁或原子操作保护读取 +func (a *Account) GetRateLimitRemainingTime(model string) time.Duration { + a.mu.RLock() + defer a.mu.RUnlock() + // 读取 Extra 字段... +} +``` + +### 9. 命名规范 + +#### 一致的命名风格 +- 常量使用 camelCase:`rateLimitThreshold` +- 类型使用 PascalCase:`AntigravityQuotaScope` +- 同一概念使用统一命名:`Threshold` 或 `Limit`,不要混用 + +```go +// ❌ 不推荐:命名不一致 +antigravitySmartRetryMinWait // 使用 Min +antigravityRateLimitThreshold // 使用 Threshold + +// ✅ 推荐:统一风格 +antigravityMinRetryWait +antigravityRateLimitThreshold +``` + +### 10. 代码审查清单 + +在提交代码前,检查以下项目: + +- [ ] 函数是否超过 30 行?(不可拆分的逻辑除外,需注释说明) +- [ ] 嵌套是否超过 3 层? +- [ ] 是否有重复代码可以提取? +- [ ] 是否使用了魔法数字? 
+- [ ] Mock 函数签名是否与实际函数一致? +- [ ] 测试是否覆盖了新增逻辑? +- [ ] 日志是否包含足够的上下文信息? +- [ ] 是否考虑了并发安全? + +--- + +## CI 检查与发布门禁 + +### GitHub Actions 检查项 + +本项目有 4 个 CI 任务,**任何代码推送或发布前都必须全部通过**: + +| Workflow | Job | 说明 | 本地验证命令 | +|----------|-----|------|-------------| +| CI | `test` | 单元测试 + 集成测试 | `cd backend && make test-unit && make test-integration` | +| CI | `golangci-lint` | Go 代码静态检查(golangci-lint v2.7) | `cd backend && golangci-lint run --timeout=5m` | +| Security Scan | `backend-security` | govulncheck + gosec 安全扫描 | `cd backend && govulncheck ./... && gosec -severity high -confidence high ./...` | +| Security Scan | `frontend-security` | pnpm audit 前端依赖安全检查 | `cd frontend && pnpm audit --prod --audit-level=high` | + +### 向上游提交 PR + +PR 目标是上游官方仓库,**只包含通用功能改动**(bug fix、新功能、性能优化等)。 + +**以下文件禁止出现在 PR 中**(属于我们 fork 的定制化内容): +- `CLAUDE.md`、`AGENTS.md` — 我们的开发文档 +- `backend/cmd/server/VERSION` — 我们的版本号文件 +- UI 定制改动(GitHub 链接移除、微信客服按钮、首页定制等) +- 部署配置(`deploy/` 目录下的定制修改) + +**PR 流程**: +1. 从 `develop` 创建功能分支,只包含要提交给上游的改动 +2. 推送分支后,**等待 4 个 CI job 全部通过** +3. 确认通过后再创建 PR +4. 使用 `gh run list --repo touwaeriol/sub2api --branch ` 检查状态 + +### 自有分支推送(develop / main) + +推送到我们自己的 `develop` 或 `main` 分支时,包含所有改动(定制化 + 通用功能)。 + +**推送前必须在本地执行全部 CI 检查**(不要等 GitHub Actions): + +```bash +# 确保 Go 工具链可用(macOS homebrew) +export PATH="/opt/homebrew/bin:$HOME/go/bin:$PATH" + +# 1. 单元测试(必须) +cd backend && make test-unit + +# 2. 集成测试(推荐,需要 Docker) +make test-integration + +# 3. golangci-lint 静态检查(必须) +golangci-lint run --timeout=5m + +# 4. gofmt 格式检查(必须) +gofmt -l ./... +# 如果有输出,运行 gofmt -w 修复 +``` + +**推送后确认**: +1. 使用 `gh run list --repo touwaeriol/sub2api --branch ` 检查 GitHub Actions 状态 +2. 确认 CI 和 Security Scan 两个 workflow 的 4 个 job 全部绿色 ✅ +3. 任何 job 失败必须立即修复,**禁止在 CI 未通过的状态下继续后续操作** + +### 发布版本 + +1. 本地执行上述全部 CI 检查通过 +2. 递增 `backend/cmd/server/VERSION`,提交并推送 +3. 推送后确认 GitHub Actions 的 4 个 CI job 全部通过 +4. **CI 未通过时禁止部署** — 必须先修复问题 +5. 
使用 `gh run list --repo touwaeriol/sub2api --limit 10` 确认状态 + +### 常见 CI 失败原因及修复 +- **gofmt**:struct 字段对齐不一致 → 运行 `gofmt -w ` 修复 +- **golangci-lint**:未使用的变量/导入 → 删除或使用 `_` 忽略 +- **test 失败**:mock 函数签名不一致 → 同步更新 mock +- **gosec**:安全漏洞 → 根据提示修复或添加例外 + +--- + +## PR 描述格式规范 + +所有 PR 描述使用中英文同步(先中文、后英文),包含以下三个部分: + +### 模板 + +```markdown +## 背景 / Background + +<一两句说明问题现状或触发原因> + + + +--- + +## 目的 / Purpose + +<本次改动要解决的问题或达到的目标> + + + +--- + +## 改动内容 / Changes + +### 后端 / Backend + +- **改动点 1**:说明 +- **改动点 2**:说明 + +--- + +- **Change 1**: description +- **Change 2**: description + +### 前端 / Frontend + +- **改动点 1**:说明 +- **改动点 2**:说明 + +--- + +- **Change 1**: description +- **Change 2**: description + +--- + +## 截图 / Screenshot(可选) + +ASCII 示意图或实际截图 +``` + +### 规范要点 + +- **标题**:使用 conventional commits 格式,如 `feat(scope): description` +- **中英文顺序**:同一段落先中文后英文,用空行分隔,不用 `---` 分割同段内容 +- **改动分类**:按 Backend / Frontend / Config 等模块分组,先列中文要点再列英文要点 +- **截图/示意图**:有 UI 变动时必须附上,可用 ASCII 示意布局 +- **目标分支**:提交到 `touwaeriol/sub2api` 的 `main` 分支 diff --git a/Dockerfile b/Dockerfile index 1493e8a7..8517f2fa 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,7 +7,7 @@ # ============================================================================= ARG NODE_IMAGE=node:24-alpine -ARG GOLANG_IMAGE=golang:1.25.7-alpine +ARG GOLANG_IMAGE=golang:1.26.1-alpine ARG ALPINE_IMAGE=alpine:3.21 ARG GOPROXY=https://goproxy.cn,direct ARG GOSUMDB=sum.golang.google.cn diff --git a/README.md b/README.md index 1e2f2290..c83bd27e 100644 --- a/README.md +++ b/README.md @@ -150,14 +150,14 @@ mkdir -p sub2api-deploy && cd sub2api-deploy curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash # Start services -docker-compose -f docker-compose.local.yml up -d +docker-compose up -d # View logs -docker-compose -f docker-compose.local.yml logs -f sub2api +docker-compose logs -f sub2api ``` **What the script does:** -- Downloads `docker-compose.local.yml` and `.env.example` +- Downloads 
`docker-compose.local.yml` (saved as `docker-compose.yml`) and `.env.example` - Generates secure credentials (JWT_SECRET, TOTP_ENCRYPTION_KEY, POSTGRES_PASSWORD) - Creates `.env` file with auto-generated secrets - Creates data directories (uses local directories for easy backup/migration) @@ -522,6 +522,28 @@ sub2api/ └── install.sh # One-click installation script ``` +## Disclaimer + +> **Please read carefully before using this project:** +> +> :rotating_light: **Terms of Service Risk**: Using this project may violate Anthropic's Terms of Service. Please read Anthropic's user agreement carefully before use. All risks arising from the use of this project are borne solely by the user. +> +> :book: **Disclaimer**: This project is for technical learning and research purposes only. The author assumes no responsibility for account suspension, service interruption, or any other losses caused by the use of this project. + +--- + +## Star History + + + + + + Star History Chart + + + +--- + ## License MIT License diff --git a/README_CN.md b/README_CN.md index 316cab94..a5ad8a94 100644 --- a/README_CN.md +++ b/README_CN.md @@ -154,14 +154,14 @@ mkdir -p sub2api-deploy && cd sub2api-deploy curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash # 启动服务 -docker-compose -f docker-compose.local.yml up -d +docker-compose up -d # 查看日志 -docker-compose -f docker-compose.local.yml logs -f sub2api +docker-compose logs -f sub2api ``` **脚本功能:** -- 下载 `docker-compose.local.yml` 和 `.env.example` +- 下载 `docker-compose.local.yml`(本地保存为 `docker-compose.yml`)和 `.env.example` - 自动生成安全凭证(JWT_SECRET、TOTP_ENCRYPTION_KEY、POSTGRES_PASSWORD) - 创建 `.env` 文件并填充自动生成的密钥 - 创建数据目录(使用本地目录,便于备份和迁移) @@ -588,6 +588,28 @@ sub2api/ └── install.sh # 一键安装脚本 ``` +## 免责声明 + +> **使用本项目前请仔细阅读:** +> +> :rotating_light: **服务条款风险**: 使用本项目可能违反 Anthropic 的服务条款。请在使用前仔细阅读 Anthropic 的用户协议,使用本项目的一切风险由用户自行承担。 +> +> :book: **免责声明**: 本项目仅供技术学习和研究使用,作者不对因使用本项目导致的账户封禁、服务中断或其他损失承担任何责任。 + +--- 
+ +## Star History + + + + + + Star History Chart + + + +--- + ## 许可证 MIT License diff --git a/backend/.golangci.yml b/backend/.golangci.yml index 68b76751..92ba3916 100644 --- a/backend/.golangci.yml +++ b/backend/.golangci.yml @@ -93,20 +93,13 @@ linters: check-escaping-errors: true staticcheck: # https://staticcheck.dev/docs/configuration/options/#dot_import_whitelist - # Default: ["github.com/mmcloughlin/avo/build", "github.com/mmcloughlin/avo/operand", "github.com/mmcloughlin/avo/reg"] dot-import-whitelist: - fmt # https://staticcheck.dev/docs/configuration/options/#initialisms - # Default: ["ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "QPS", "RAM", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "GID", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS", "SIP", "RTP", "AMQP", "DB", "TS"] initialisms: [ "ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "QPS", "RAM", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "GID", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS", "SIP", "RTP", "AMQP", "DB", "TS" ] # https://staticcheck.dev/docs/configuration/options/#http_status_code_whitelist - # Default: ["200", "400", "404", "500"] http-status-code-whitelist: [ "200", "400", "404", "500" ] - # SAxxxx checks in https://staticcheck.dev/docs/configuration/options/#checks - # Example (to disable some checks): [ "all", "-SA1000", "-SA1001"] - # Run `GL_DEBUG=staticcheck golangci-lint run --enable=staticcheck` to see all available checks and enabled by config checks. - # Default: ["all", "-ST1000", "-ST1003", "-ST1016", "-ST1020", "-ST1021", "-ST1022"] - # Temporarily disable style checks to allow CI to pass + # "all" enables every SA/ST/S/QF check; only list the ones to disable. 
checks: - all - -ST1000 # Package comment format @@ -114,489 +107,19 @@ linters: - -ST1020 # Comment on exported method format - -ST1021 # Comment on exported type format - -ST1022 # Comment on exported variable format - # Invalid regular expression. - # https://staticcheck.dev/docs/checks/#SA1000 - - SA1000 - # Invalid template. - # https://staticcheck.dev/docs/checks/#SA1001 - - SA1001 - # Invalid format in 'time.Parse'. - # https://staticcheck.dev/docs/checks/#SA1002 - - SA1002 - # Unsupported argument to functions in 'encoding/binary'. - # https://staticcheck.dev/docs/checks/#SA1003 - - SA1003 - # Suspiciously small untyped constant in 'time.Sleep'. - # https://staticcheck.dev/docs/checks/#SA1004 - - SA1004 - # Invalid first argument to 'exec.Command'. - # https://staticcheck.dev/docs/checks/#SA1005 - - SA1005 - # 'Printf' with dynamic first argument and no further arguments. - # https://staticcheck.dev/docs/checks/#SA1006 - - SA1006 - # Invalid URL in 'net/url.Parse'. - # https://staticcheck.dev/docs/checks/#SA1007 - - SA1007 - # Non-canonical key in 'http.Header' map. - # https://staticcheck.dev/docs/checks/#SA1008 - - SA1008 - # '(*regexp.Regexp).FindAll' called with 'n == 0', which will always return zero results. - # https://staticcheck.dev/docs/checks/#SA1010 - - SA1010 - # Various methods in the "strings" package expect valid UTF-8, but invalid input is provided. - # https://staticcheck.dev/docs/checks/#SA1011 - - SA1011 - # A nil 'context.Context' is being passed to a function, consider using 'context.TODO' instead. - # https://staticcheck.dev/docs/checks/#SA1012 - - SA1012 - # 'io.Seeker.Seek' is being called with the whence constant as the first argument, but it should be the second. - # https://staticcheck.dev/docs/checks/#SA1013 - - SA1013 - # Non-pointer value passed to 'Unmarshal' or 'Decode'. - # https://staticcheck.dev/docs/checks/#SA1014 - - SA1014 - # Using 'time.Tick' in a way that will leak. 
Consider using 'time.NewTicker', and only use 'time.Tick' in tests, commands and endless functions. - # https://staticcheck.dev/docs/checks/#SA1015 - - SA1015 - # Trapping a signal that cannot be trapped. - # https://staticcheck.dev/docs/checks/#SA1016 - - SA1016 - # Channels used with 'os/signal.Notify' should be buffered. - # https://staticcheck.dev/docs/checks/#SA1017 - - SA1017 - # 'strings.Replace' called with 'n == 0', which does nothing. - # https://staticcheck.dev/docs/checks/#SA1018 - - SA1018 - # Using a deprecated function, variable, constant or field. - # https://staticcheck.dev/docs/checks/#SA1019 - - SA1019 - # Using an invalid host:port pair with a 'net.Listen'-related function. - # https://staticcheck.dev/docs/checks/#SA1020 - - SA1020 - # Using 'bytes.Equal' to compare two 'net.IP'. - # https://staticcheck.dev/docs/checks/#SA1021 - - SA1021 - # Modifying the buffer in an 'io.Writer' implementation. - # https://staticcheck.dev/docs/checks/#SA1023 - - SA1023 - # A string cutset contains duplicate characters. - # https://staticcheck.dev/docs/checks/#SA1024 - - SA1024 - # It is not possible to use '(*time.Timer).Reset''s return value correctly. - # https://staticcheck.dev/docs/checks/#SA1025 - - SA1025 - # Cannot marshal channels or functions. - # https://staticcheck.dev/docs/checks/#SA1026 - - SA1026 - # Atomic access to 64-bit variable must be 64-bit aligned. - # https://staticcheck.dev/docs/checks/#SA1027 - - SA1027 - # 'sort.Slice' can only be used on slices. - # https://staticcheck.dev/docs/checks/#SA1028 - - SA1028 - # Inappropriate key in call to 'context.WithValue'. - # https://staticcheck.dev/docs/checks/#SA1029 - - SA1029 - # Invalid argument in call to a 'strconv' function. - # https://staticcheck.dev/docs/checks/#SA1030 - - SA1030 - # Overlapping byte slices passed to an encoder. - # https://staticcheck.dev/docs/checks/#SA1031 - - SA1031 - # Wrong order of arguments to 'errors.Is'. 
- # https://staticcheck.dev/docs/checks/#SA1032 - - SA1032 - # 'sync.WaitGroup.Add' called inside the goroutine, leading to a race condition. - # https://staticcheck.dev/docs/checks/#SA2000 - - SA2000 - # Empty critical section, did you mean to defer the unlock?. - # https://staticcheck.dev/docs/checks/#SA2001 - - SA2001 - # Called 'testing.T.FailNow' or 'SkipNow' in a goroutine, which isn't allowed. - # https://staticcheck.dev/docs/checks/#SA2002 - - SA2002 - # Deferred 'Lock' right after locking, likely meant to defer 'Unlock' instead. - # https://staticcheck.dev/docs/checks/#SA2003 - - SA2003 - # 'TestMain' doesn't call 'os.Exit', hiding test failures. - # https://staticcheck.dev/docs/checks/#SA3000 - - SA3000 - # Assigning to 'b.N' in benchmarks distorts the results. - # https://staticcheck.dev/docs/checks/#SA3001 - - SA3001 - # Binary operator has identical expressions on both sides. - # https://staticcheck.dev/docs/checks/#SA4000 - - SA4000 - # '&*x' gets simplified to 'x', it does not copy 'x'. - # https://staticcheck.dev/docs/checks/#SA4001 - - SA4001 - # Comparing unsigned values against negative values is pointless. - # https://staticcheck.dev/docs/checks/#SA4003 - - SA4003 - # The loop exits unconditionally after one iteration. - # https://staticcheck.dev/docs/checks/#SA4004 - - SA4004 - # Field assignment that will never be observed. Did you mean to use a pointer receiver?. - # https://staticcheck.dev/docs/checks/#SA4005 - - SA4005 - # A value assigned to a variable is never read before being overwritten. Forgotten error check or dead code?. - # https://staticcheck.dev/docs/checks/#SA4006 - - SA4006 - # The variable in the loop condition never changes, are you incrementing the wrong variable?. - # https://staticcheck.dev/docs/checks/#SA4008 - - SA4008 - # A function argument is overwritten before its first use. - # https://staticcheck.dev/docs/checks/#SA4009 - - SA4009 - # The result of 'append' will never be observed anywhere. 
- # https://staticcheck.dev/docs/checks/#SA4010 - - SA4010 - # Break statement with no effect. Did you mean to break out of an outer loop?. - # https://staticcheck.dev/docs/checks/#SA4011 - - SA4011 - # Comparing a value against NaN even though no value is equal to NaN. - # https://staticcheck.dev/docs/checks/#SA4012 - - SA4012 - # Negating a boolean twice ('!!b') is the same as writing 'b'. This is either redundant, or a typo. - # https://staticcheck.dev/docs/checks/#SA4013 - - SA4013 - # An if/else if chain has repeated conditions and no side-effects; if the condition didn't match the first time, it won't match the second time, either. - # https://staticcheck.dev/docs/checks/#SA4014 - - SA4014 - # Calling functions like 'math.Ceil' on floats converted from integers doesn't do anything useful. - # https://staticcheck.dev/docs/checks/#SA4015 - - SA4015 - # Certain bitwise operations, such as 'x ^ 0', do not do anything useful. - # https://staticcheck.dev/docs/checks/#SA4016 - - SA4016 - # Discarding the return values of a function without side effects, making the call pointless. - # https://staticcheck.dev/docs/checks/#SA4017 - - SA4017 - # Self-assignment of variables. - # https://staticcheck.dev/docs/checks/#SA4018 - - SA4018 - # Multiple, identical build constraints in the same file. - # https://staticcheck.dev/docs/checks/#SA4019 - - SA4019 - # Unreachable case clause in a type switch. - # https://staticcheck.dev/docs/checks/#SA4020 - - SA4020 - # "x = append(y)" is equivalent to "x = y". - # https://staticcheck.dev/docs/checks/#SA4021 - - SA4021 - # Comparing the address of a variable against nil. - # https://staticcheck.dev/docs/checks/#SA4022 - - SA4022 - # Impossible comparison of interface value with untyped nil. - # https://staticcheck.dev/docs/checks/#SA4023 - - SA4023 - # Checking for impossible return value from a builtin function. - # https://staticcheck.dev/docs/checks/#SA4024 - - SA4024 - # Integer division of literals that results in zero. 
- # https://staticcheck.dev/docs/checks/#SA4025 - - SA4025 - # Go constants cannot express negative zero. - # https://staticcheck.dev/docs/checks/#SA4026 - - SA4026 - # '(*net/url.URL).Query' returns a copy, modifying it doesn't change the URL. - # https://staticcheck.dev/docs/checks/#SA4027 - - SA4027 - # 'x % 1' is always zero. - # https://staticcheck.dev/docs/checks/#SA4028 - - SA4028 - # Ineffective attempt at sorting slice. - # https://staticcheck.dev/docs/checks/#SA4029 - - SA4029 - # Ineffective attempt at generating random number. - # https://staticcheck.dev/docs/checks/#SA4030 - - SA4030 - # Checking never-nil value against nil. - # https://staticcheck.dev/docs/checks/#SA4031 - - SA4031 - # Comparing 'runtime.GOOS' or 'runtime.GOARCH' against impossible value. - # https://staticcheck.dev/docs/checks/#SA4032 - - SA4032 - # Assignment to nil map. - # https://staticcheck.dev/docs/checks/#SA5000 - - SA5000 - # Deferring 'Close' before checking for a possible error. - # https://staticcheck.dev/docs/checks/#SA5001 - - SA5001 - # The empty for loop ("for {}") spins and can block the scheduler. - # https://staticcheck.dev/docs/checks/#SA5002 - - SA5002 - # Defers in infinite loops will never execute. - # https://staticcheck.dev/docs/checks/#SA5003 - - SA5003 - # "for { select { ..." with an empty default branch spins. - # https://staticcheck.dev/docs/checks/#SA5004 - - SA5004 - # The finalizer references the finalized object, preventing garbage collection. - # https://staticcheck.dev/docs/checks/#SA5005 - - SA5005 - # Infinite recursive call. - # https://staticcheck.dev/docs/checks/#SA5007 - - SA5007 - # Invalid struct tag. - # https://staticcheck.dev/docs/checks/#SA5008 - - SA5008 - # Invalid Printf call. - # https://staticcheck.dev/docs/checks/#SA5009 - - SA5009 - # Impossible type assertion. - # https://staticcheck.dev/docs/checks/#SA5010 - - SA5010 - # Possible nil pointer dereference. 
- # https://staticcheck.dev/docs/checks/#SA5011 - - SA5011 - # Passing odd-sized slice to function expecting even size. - # https://staticcheck.dev/docs/checks/#SA5012 - - SA5012 - # Using 'regexp.Match' or related in a loop, should use 'regexp.Compile'. - # https://staticcheck.dev/docs/checks/#SA6000 - - SA6000 - # Missing an optimization opportunity when indexing maps by byte slices. - # https://staticcheck.dev/docs/checks/#SA6001 - - SA6001 - # Storing non-pointer values in 'sync.Pool' allocates memory. - # https://staticcheck.dev/docs/checks/#SA6002 - - SA6002 - # Converting a string to a slice of runes before ranging over it. - # https://staticcheck.dev/docs/checks/#SA6003 - - SA6003 - # Inefficient string comparison with 'strings.ToLower' or 'strings.ToUpper'. - # https://staticcheck.dev/docs/checks/#SA6005 - - SA6005 - # Using io.WriteString to write '[]byte'. - # https://staticcheck.dev/docs/checks/#SA6006 - - SA6006 - # Defers in range loops may not run when you expect them to. - # https://staticcheck.dev/docs/checks/#SA9001 - - SA9001 - # Using a non-octal 'os.FileMode' that looks like it was meant to be in octal. - # https://staticcheck.dev/docs/checks/#SA9002 - - SA9002 - # Empty body in an if or else branch. - # https://staticcheck.dev/docs/checks/#SA9003 - - SA9003 - # Only the first constant has an explicit type. - # https://staticcheck.dev/docs/checks/#SA9004 - - SA9004 - # Trying to marshal a struct with no public fields nor custom marshaling. - # https://staticcheck.dev/docs/checks/#SA9005 - - SA9005 - # Dubious bit shifting of a fixed size integer value. - # https://staticcheck.dev/docs/checks/#SA9006 - - SA9006 - # Deleting a directory that shouldn't be deleted. - # https://staticcheck.dev/docs/checks/#SA9007 - - SA9007 - # 'else' branch of a type assertion is probably not reading the right value. - # https://staticcheck.dev/docs/checks/#SA9008 - - SA9008 - # Ineffectual Go compiler directive. 
- # https://staticcheck.dev/docs/checks/#SA9009 - - SA9009 - # NOTE: ST1000, ST1001, ST1003, ST1020, ST1021, ST1022 are disabled above - # Incorrectly formatted error string. - # https://staticcheck.dev/docs/checks/#ST1005 - - ST1005 - # Poorly chosen receiver name. - # https://staticcheck.dev/docs/checks/#ST1006 - - ST1006 - # A function's error value should be its last return value. - # https://staticcheck.dev/docs/checks/#ST1008 - - ST1008 - # Poorly chosen name for variable of type 'time.Duration'. - # https://staticcheck.dev/docs/checks/#ST1011 - - ST1011 - # Poorly chosen name for error variable. - # https://staticcheck.dev/docs/checks/#ST1012 - - ST1012 - # Should use constants for HTTP error codes, not magic numbers. - # https://staticcheck.dev/docs/checks/#ST1013 - - ST1013 - # A switch's default case should be the first or last case. - # https://staticcheck.dev/docs/checks/#ST1015 - - ST1015 - # Use consistent method receiver names. - # https://staticcheck.dev/docs/checks/#ST1016 - - ST1016 - # Don't use Yoda conditions. - # https://staticcheck.dev/docs/checks/#ST1017 - - ST1017 - # Avoid zero-width and control characters in string literals. - # https://staticcheck.dev/docs/checks/#ST1018 - - ST1018 - # Importing the same package multiple times. - # https://staticcheck.dev/docs/checks/#ST1019 - - ST1019 - # NOTE: ST1020, ST1021, ST1022 removed (disabled above) - # Redundant type in variable declaration. - # https://staticcheck.dev/docs/checks/#ST1023 - - ST1023 - # Use plain channel send or receive instead of single-case select. - # https://staticcheck.dev/docs/checks/#S1000 - - S1000 - # Replace for loop with call to copy. - # https://staticcheck.dev/docs/checks/#S1001 - - S1001 - # Omit comparison with boolean constant. - # https://staticcheck.dev/docs/checks/#S1002 - - S1002 - # Replace call to 'strings.Index' with 'strings.Contains'. - # https://staticcheck.dev/docs/checks/#S1003 - - S1003 - # Replace call to 'bytes.Compare' with 'bytes.Equal'. 
- # https://staticcheck.dev/docs/checks/#S1004 - - S1004 - # Drop unnecessary use of the blank identifier. - # https://staticcheck.dev/docs/checks/#S1005 - - S1005 - # Use "for { ... }" for infinite loops. - # https://staticcheck.dev/docs/checks/#S1006 - - S1006 - # Simplify regular expression by using raw string literal. - # https://staticcheck.dev/docs/checks/#S1007 - - S1007 - # Simplify returning boolean expression. - # https://staticcheck.dev/docs/checks/#S1008 - - S1008 - # Omit redundant nil check on slices, maps, and channels. - # https://staticcheck.dev/docs/checks/#S1009 - - S1009 - # Omit default slice index. - # https://staticcheck.dev/docs/checks/#S1010 - - S1010 - # Use a single 'append' to concatenate two slices. - # https://staticcheck.dev/docs/checks/#S1011 - - S1011 - # Replace 'time.Now().Sub(x)' with 'time.Since(x)'. - # https://staticcheck.dev/docs/checks/#S1012 - - S1012 - # Use a type conversion instead of manually copying struct fields. - # https://staticcheck.dev/docs/checks/#S1016 - - S1016 - # Replace manual trimming with 'strings.TrimPrefix'. - # https://staticcheck.dev/docs/checks/#S1017 - - S1017 - # Use "copy" for sliding elements. - # https://staticcheck.dev/docs/checks/#S1018 - - S1018 - # Simplify "make" call by omitting redundant arguments. - # https://staticcheck.dev/docs/checks/#S1019 - - S1019 - # Omit redundant nil check in type assertion. - # https://staticcheck.dev/docs/checks/#S1020 - - S1020 - # Merge variable declaration and assignment. - # https://staticcheck.dev/docs/checks/#S1021 - - S1021 - # Omit redundant control flow. - # https://staticcheck.dev/docs/checks/#S1023 - - S1023 - # Replace 'x.Sub(time.Now())' with 'time.Until(x)'. - # https://staticcheck.dev/docs/checks/#S1024 - - S1024 - # Don't use 'fmt.Sprintf("%s", x)' unnecessarily. - # https://staticcheck.dev/docs/checks/#S1025 - - S1025 - # Simplify error construction with 'fmt.Errorf'. 
- # https://staticcheck.dev/docs/checks/#S1028 - - S1028 - # Range over the string directly. - # https://staticcheck.dev/docs/checks/#S1029 - - S1029 - # Use 'bytes.Buffer.String' or 'bytes.Buffer.Bytes'. - # https://staticcheck.dev/docs/checks/#S1030 - - S1030 - # Omit redundant nil check around loop. - # https://staticcheck.dev/docs/checks/#S1031 - - S1031 - # Use 'sort.Ints(x)', 'sort.Float64s(x)', and 'sort.Strings(x)'. - # https://staticcheck.dev/docs/checks/#S1032 - - S1032 - # Unnecessary guard around call to "delete". - # https://staticcheck.dev/docs/checks/#S1033 - - S1033 - # Use result of type assertion to simplify cases. - # https://staticcheck.dev/docs/checks/#S1034 - - S1034 - # Redundant call to 'net/http.CanonicalHeaderKey' in method call on 'net/http.Header'. - # https://staticcheck.dev/docs/checks/#S1035 - - S1035 - # Unnecessary guard around map access. - # https://staticcheck.dev/docs/checks/#S1036 - - S1036 - # Elaborate way of sleeping. - # https://staticcheck.dev/docs/checks/#S1037 - - S1037 - # Unnecessarily complex way of printing formatted string. - # https://staticcheck.dev/docs/checks/#S1038 - - S1038 - # Unnecessary use of 'fmt.Sprint'. - # https://staticcheck.dev/docs/checks/#S1039 - - S1039 - # Type assertion to current type. - # https://staticcheck.dev/docs/checks/#S1040 - - S1040 - # Apply De Morgan's law. - # https://staticcheck.dev/docs/checks/#QF1001 - - QF1001 - # Convert untagged switch to tagged switch. - # https://staticcheck.dev/docs/checks/#QF1002 - - QF1002 - # Convert if/else-if chain to tagged switch. - # https://staticcheck.dev/docs/checks/#QF1003 - - QF1003 - # Use 'strings.ReplaceAll' instead of 'strings.Replace' with 'n == -1'. - # https://staticcheck.dev/docs/checks/#QF1004 - - QF1004 - # Expand call to 'math.Pow'. - # https://staticcheck.dev/docs/checks/#QF1005 - - QF1005 - # Lift 'if'+'break' into loop condition. 
- # https://staticcheck.dev/docs/checks/#QF1006 - - QF1006 - # Merge conditional assignment into variable declaration. - # https://staticcheck.dev/docs/checks/#QF1007 - - QF1007 - # Omit embedded fields from selector expression. - # https://staticcheck.dev/docs/checks/#QF1008 - - QF1008 - # Use 'time.Time.Equal' instead of '==' operator. - # https://staticcheck.dev/docs/checks/#QF1009 - - QF1009 - # Convert slice of bytes to string when printing it. - # https://staticcheck.dev/docs/checks/#QF1010 - - QF1010 - # Omit redundant type from variable declaration. - # https://staticcheck.dev/docs/checks/#QF1011 - - QF1011 - # Use 'fmt.Fprintf(x, ...)' instead of 'x.Write(fmt.Sprintf(...))'. - # https://staticcheck.dev/docs/checks/#QF1012 - - QF1012 unused: - # Mark all struct fields that have been written to as used. # Default: true - field-writes-are-uses: false - # Treat IncDec statement (e.g. `i++` or `i--`) as both read and write operation instead of just write. + field-writes-are-uses: true # Default: false post-statements-are-reads: true - # Mark all exported fields as used. - # default: true - exported-fields-are-used: false - # Mark all function parameters as used. - # default: true - parameters-are-used: true - # Mark all local variables as used. - # default: true - local-variables-are-used: false - # Mark all identifiers inside generated files as used. 
# Default: true - generated-is-used: false + exported-fields-are-used: true + # Default: true + parameters-are-used: true + # Default: true + local-variables-are-used: false + # Default: true — must be true, ent generates 130K+ lines of code + generated-is-used: true formatters: enable: diff --git a/backend/cmd/jwtgen/main.go b/backend/cmd/jwtgen/main.go index bc001693..7eabde62 100644 --- a/backend/cmd/jwtgen/main.go +++ b/backend/cmd/jwtgen/main.go @@ -33,7 +33,7 @@ func main() { }() userRepo := repository.NewUserRepository(client, sqlDB) - authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) + authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 32844913..5a4e9c9a 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.88 \ No newline at end of file +0.1.96.1 diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go index cbf89ba3..80364bf2 100644 --- a/backend/cmd/server/wire.go +++ b/backend/cmd/server/wire.go @@ -86,6 +86,7 @@ func provideCleanup( geminiOAuth *service.GeminiOAuthService, antigravityOAuth *service.AntigravityOAuthService, openAIGateway *service.OpenAIGatewayService, + scheduledTestRunner *service.ScheduledTestRunnerService, ) func() { return func() { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -216,6 +217,12 @@ func provideCleanup( } return nil }}, + {"ScheduledTestRunnerService", func() error { + if scheduledTestRunner != nil { + scheduledTestRunner.Stop() + } + return nil + }}, } infraSteps := []cleanupStep{ diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 8e7aefe1..034c70ec 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -67,7 +67,7 @@ func 
initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig) - authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService) + authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService) userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache) redeemCache := repository.NewRedeemCache(redisClient) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator) @@ -104,7 +104,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { proxyRepository := repository.NewProxyRepository(client, db) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) - adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService) + adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, 
userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) adminUserHandler := admin.NewUserHandler(adminService, concurrencyService) @@ -162,9 +162,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService) digestSessionStore := service.NewDigestSessionStore() - gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore) + gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService) openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService) - openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, 
rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) + openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) @@ -195,7 +195,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache) errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService) adminAPIKeyHandler := admin.NewAdminAPIKeyHandler(adminService) - adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler) + scheduledTestPlanRepository := repository.NewScheduledTestPlanRepository(db) + scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db) + scheduledTestService := 
service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository) + scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService) + adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler) usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient) userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig) @@ -225,7 +229,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache) accountExpiryService := service.ProvideAccountExpiryService(accountRepository) subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) - v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService) + scheduledTestRunnerService 
:= service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig) + v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService) application := &Application{ Server: httpServer, Cleanup: v, @@ -273,6 +278,7 @@ func provideCleanup( geminiOAuth *service.GeminiOAuthService, antigravityOAuth *service.AntigravityOAuthService, openAIGateway *service.OpenAIGatewayService, + scheduledTestRunner *service.ScheduledTestRunnerService, ) func() { return func() { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -402,6 +408,12 @@ func provideCleanup( } return nil }}, + {"ScheduledTestRunnerService", func() error { + if scheduledTestRunner != nil { + scheduledTestRunner.Stop() + } + return nil + }}, } infraSteps := []cleanupStep{ diff --git a/backend/cmd/server/wire_gen_test.go b/backend/cmd/server/wire_gen_test.go index 373bfd88..8e203cb8 100644 --- a/backend/cmd/server/wire_gen_test.go +++ b/backend/cmd/server/wire_gen_test.go @@ -74,6 +74,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) { geminiOAuthSvc, antigravityOAuthSvc, nil, // openAIGateway + nil, // scheduledTestRunner ) require.NotPanics(t, func() { diff --git a/backend/ent/account.go b/backend/ent/account.go index c77002b3..2dbfc3a2 100644 --- a/backend/ent/account.go +++ b/backend/ent/account.go @@ -41,6 +41,8 @@ type Account struct { ProxyID *int64 `json:"proxy_id,omitempty"` // 
Concurrency holds the value of the "concurrency" field. Concurrency int `json:"concurrency,omitempty"` + // LoadFactor holds the value of the "load_factor" field. + LoadFactor *int `json:"load_factor,omitempty"` // Priority holds the value of the "priority" field. Priority int `json:"priority,omitempty"` // RateMultiplier holds the value of the "rate_multiplier" field. @@ -143,7 +145,7 @@ func (*Account) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullBool) case account.FieldRateMultiplier: values[i] = new(sql.NullFloat64) - case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldPriority: + case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldLoadFactor, account.FieldPriority: values[i] = new(sql.NullInt64) case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldTempUnschedulableReason, account.FieldSessionWindowStatus: values[i] = new(sql.NullString) @@ -243,6 +245,13 @@ func (_m *Account) assignValues(columns []string, values []any) error { } else if value.Valid { _m.Concurrency = int(value.Int64) } + case account.FieldLoadFactor: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field load_factor", values[i]) + } else if value.Valid { + _m.LoadFactor = new(int) + *_m.LoadFactor = int(value.Int64) + } case account.FieldPriority: if value, ok := values[i].(*sql.NullInt64); !ok { return fmt.Errorf("unexpected type %T for field priority", values[i]) @@ -445,6 +454,11 @@ func (_m *Account) String() string { builder.WriteString("concurrency=") builder.WriteString(fmt.Sprintf("%v", _m.Concurrency)) builder.WriteString(", ") + if v := _m.LoadFactor; v != nil { + builder.WriteString("load_factor=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") builder.WriteString("priority=") builder.WriteString(fmt.Sprintf("%v", _m.Priority)) 
builder.WriteString(", ") diff --git a/backend/ent/account/account.go b/backend/ent/account/account.go index 1fc34620..4c134649 100644 --- a/backend/ent/account/account.go +++ b/backend/ent/account/account.go @@ -37,6 +37,8 @@ const ( FieldProxyID = "proxy_id" // FieldConcurrency holds the string denoting the concurrency field in the database. FieldConcurrency = "concurrency" + // FieldLoadFactor holds the string denoting the load_factor field in the database. + FieldLoadFactor = "load_factor" // FieldPriority holds the string denoting the priority field in the database. FieldPriority = "priority" // FieldRateMultiplier holds the string denoting the rate_multiplier field in the database. @@ -121,6 +123,7 @@ var Columns = []string{ FieldExtra, FieldProxyID, FieldConcurrency, + FieldLoadFactor, FieldPriority, FieldRateMultiplier, FieldStatus, @@ -250,6 +253,11 @@ func ByConcurrency(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldConcurrency, opts...).ToFunc() } +// ByLoadFactor orders the results by the load_factor field. +func ByLoadFactor(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLoadFactor, opts...).ToFunc() +} + // ByPriority orders the results by the priority field. func ByPriority(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldPriority, opts...).ToFunc() diff --git a/backend/ent/account/where.go b/backend/ent/account/where.go index 54db1dcb..3749b45c 100644 --- a/backend/ent/account/where.go +++ b/backend/ent/account/where.go @@ -100,6 +100,11 @@ func Concurrency(v int) predicate.Account { return predicate.Account(sql.FieldEQ(FieldConcurrency, v)) } +// LoadFactor applies equality check predicate on the "load_factor" field. It's identical to LoadFactorEQ. +func LoadFactor(v int) predicate.Account { + return predicate.Account(sql.FieldEQ(FieldLoadFactor, v)) +} + // Priority applies equality check predicate on the "priority" field. It's identical to PriorityEQ. 
func Priority(v int) predicate.Account { return predicate.Account(sql.FieldEQ(FieldPriority, v)) @@ -650,6 +655,56 @@ func ConcurrencyLTE(v int) predicate.Account { return predicate.Account(sql.FieldLTE(FieldConcurrency, v)) } +// LoadFactorEQ applies the EQ predicate on the "load_factor" field. +func LoadFactorEQ(v int) predicate.Account { + return predicate.Account(sql.FieldEQ(FieldLoadFactor, v)) +} + +// LoadFactorNEQ applies the NEQ predicate on the "load_factor" field. +func LoadFactorNEQ(v int) predicate.Account { + return predicate.Account(sql.FieldNEQ(FieldLoadFactor, v)) +} + +// LoadFactorIn applies the In predicate on the "load_factor" field. +func LoadFactorIn(vs ...int) predicate.Account { + return predicate.Account(sql.FieldIn(FieldLoadFactor, vs...)) +} + +// LoadFactorNotIn applies the NotIn predicate on the "load_factor" field. +func LoadFactorNotIn(vs ...int) predicate.Account { + return predicate.Account(sql.FieldNotIn(FieldLoadFactor, vs...)) +} + +// LoadFactorGT applies the GT predicate on the "load_factor" field. +func LoadFactorGT(v int) predicate.Account { + return predicate.Account(sql.FieldGT(FieldLoadFactor, v)) +} + +// LoadFactorGTE applies the GTE predicate on the "load_factor" field. +func LoadFactorGTE(v int) predicate.Account { + return predicate.Account(sql.FieldGTE(FieldLoadFactor, v)) +} + +// LoadFactorLT applies the LT predicate on the "load_factor" field. +func LoadFactorLT(v int) predicate.Account { + return predicate.Account(sql.FieldLT(FieldLoadFactor, v)) +} + +// LoadFactorLTE applies the LTE predicate on the "load_factor" field. +func LoadFactorLTE(v int) predicate.Account { + return predicate.Account(sql.FieldLTE(FieldLoadFactor, v)) +} + +// LoadFactorIsNil applies the IsNil predicate on the "load_factor" field. +func LoadFactorIsNil() predicate.Account { + return predicate.Account(sql.FieldIsNull(FieldLoadFactor)) +} + +// LoadFactorNotNil applies the NotNil predicate on the "load_factor" field. 
+func LoadFactorNotNil() predicate.Account { + return predicate.Account(sql.FieldNotNull(FieldLoadFactor)) +} + // PriorityEQ applies the EQ predicate on the "priority" field. func PriorityEQ(v int) predicate.Account { return predicate.Account(sql.FieldEQ(FieldPriority, v)) diff --git a/backend/ent/account_create.go b/backend/ent/account_create.go index 963ffee8..d6046c79 100644 --- a/backend/ent/account_create.go +++ b/backend/ent/account_create.go @@ -139,6 +139,20 @@ func (_c *AccountCreate) SetNillableConcurrency(v *int) *AccountCreate { return _c } +// SetLoadFactor sets the "load_factor" field. +func (_c *AccountCreate) SetLoadFactor(v int) *AccountCreate { + _c.mutation.SetLoadFactor(v) + return _c +} + +// SetNillableLoadFactor sets the "load_factor" field if the given value is not nil. +func (_c *AccountCreate) SetNillableLoadFactor(v *int) *AccountCreate { + if v != nil { + _c.SetLoadFactor(*v) + } + return _c +} + // SetPriority sets the "priority" field. func (_c *AccountCreate) SetPriority(v int) *AccountCreate { _c.mutation.SetPriority(v) @@ -623,6 +637,10 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) { _spec.SetField(account.FieldConcurrency, field.TypeInt, value) _node.Concurrency = value } + if value, ok := _c.mutation.LoadFactor(); ok { + _spec.SetField(account.FieldLoadFactor, field.TypeInt, value) + _node.LoadFactor = &value + } if value, ok := _c.mutation.Priority(); ok { _spec.SetField(account.FieldPriority, field.TypeInt, value) _node.Priority = value @@ -936,6 +954,30 @@ func (u *AccountUpsert) AddConcurrency(v int) *AccountUpsert { return u } +// SetLoadFactor sets the "load_factor" field. +func (u *AccountUpsert) SetLoadFactor(v int) *AccountUpsert { + u.Set(account.FieldLoadFactor, v) + return u +} + +// UpdateLoadFactor sets the "load_factor" field to the value that was provided on create. 
+func (u *AccountUpsert) UpdateLoadFactor() *AccountUpsert { + u.SetExcluded(account.FieldLoadFactor) + return u +} + +// AddLoadFactor adds v to the "load_factor" field. +func (u *AccountUpsert) AddLoadFactor(v int) *AccountUpsert { + u.Add(account.FieldLoadFactor, v) + return u +} + +// ClearLoadFactor clears the value of the "load_factor" field. +func (u *AccountUpsert) ClearLoadFactor() *AccountUpsert { + u.SetNull(account.FieldLoadFactor) + return u +} + // SetPriority sets the "priority" field. func (u *AccountUpsert) SetPriority(v int) *AccountUpsert { u.Set(account.FieldPriority, v) @@ -1419,6 +1461,34 @@ func (u *AccountUpsertOne) UpdateConcurrency() *AccountUpsertOne { }) } +// SetLoadFactor sets the "load_factor" field. +func (u *AccountUpsertOne) SetLoadFactor(v int) *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.SetLoadFactor(v) + }) +} + +// AddLoadFactor adds v to the "load_factor" field. +func (u *AccountUpsertOne) AddLoadFactor(v int) *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.AddLoadFactor(v) + }) +} + +// UpdateLoadFactor sets the "load_factor" field to the value that was provided on create. +func (u *AccountUpsertOne) UpdateLoadFactor() *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.UpdateLoadFactor() + }) +} + +// ClearLoadFactor clears the value of the "load_factor" field. +func (u *AccountUpsertOne) ClearLoadFactor() *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.ClearLoadFactor() + }) +} + // SetPriority sets the "priority" field. func (u *AccountUpsertOne) SetPriority(v int) *AccountUpsertOne { return u.Update(func(s *AccountUpsert) { @@ -2113,6 +2183,34 @@ func (u *AccountUpsertBulk) UpdateConcurrency() *AccountUpsertBulk { }) } +// SetLoadFactor sets the "load_factor" field. 
+func (u *AccountUpsertBulk) SetLoadFactor(v int) *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.SetLoadFactor(v) + }) +} + +// AddLoadFactor adds v to the "load_factor" field. +func (u *AccountUpsertBulk) AddLoadFactor(v int) *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.AddLoadFactor(v) + }) +} + +// UpdateLoadFactor sets the "load_factor" field to the value that was provided on create. +func (u *AccountUpsertBulk) UpdateLoadFactor() *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.UpdateLoadFactor() + }) +} + +// ClearLoadFactor clears the value of the "load_factor" field. +func (u *AccountUpsertBulk) ClearLoadFactor() *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.ClearLoadFactor() + }) +} + // SetPriority sets the "priority" field. func (u *AccountUpsertBulk) SetPriority(v int) *AccountUpsertBulk { return u.Update(func(s *AccountUpsert) { diff --git a/backend/ent/account_update.go b/backend/ent/account_update.go index 875888e0..6f443c65 100644 --- a/backend/ent/account_update.go +++ b/backend/ent/account_update.go @@ -172,6 +172,33 @@ func (_u *AccountUpdate) AddConcurrency(v int) *AccountUpdate { return _u } +// SetLoadFactor sets the "load_factor" field. +func (_u *AccountUpdate) SetLoadFactor(v int) *AccountUpdate { + _u.mutation.ResetLoadFactor() + _u.mutation.SetLoadFactor(v) + return _u +} + +// SetNillableLoadFactor sets the "load_factor" field if the given value is not nil. +func (_u *AccountUpdate) SetNillableLoadFactor(v *int) *AccountUpdate { + if v != nil { + _u.SetLoadFactor(*v) + } + return _u +} + +// AddLoadFactor adds value to the "load_factor" field. +func (_u *AccountUpdate) AddLoadFactor(v int) *AccountUpdate { + _u.mutation.AddLoadFactor(v) + return _u +} + +// ClearLoadFactor clears the value of the "load_factor" field. 
+func (_u *AccountUpdate) ClearLoadFactor() *AccountUpdate { + _u.mutation.ClearLoadFactor() + return _u +} + // SetPriority sets the "priority" field. func (_u *AccountUpdate) SetPriority(v int) *AccountUpdate { _u.mutation.ResetPriority() @@ -684,6 +711,15 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.AddedConcurrency(); ok { _spec.AddField(account.FieldConcurrency, field.TypeInt, value) } + if value, ok := _u.mutation.LoadFactor(); ok { + _spec.SetField(account.FieldLoadFactor, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedLoadFactor(); ok { + _spec.AddField(account.FieldLoadFactor, field.TypeInt, value) + } + if _u.mutation.LoadFactorCleared() { + _spec.ClearField(account.FieldLoadFactor, field.TypeInt) + } if value, ok := _u.mutation.Priority(); ok { _spec.SetField(account.FieldPriority, field.TypeInt, value) } @@ -1063,6 +1099,33 @@ func (_u *AccountUpdateOne) AddConcurrency(v int) *AccountUpdateOne { return _u } +// SetLoadFactor sets the "load_factor" field. +func (_u *AccountUpdateOne) SetLoadFactor(v int) *AccountUpdateOne { + _u.mutation.ResetLoadFactor() + _u.mutation.SetLoadFactor(v) + return _u +} + +// SetNillableLoadFactor sets the "load_factor" field if the given value is not nil. +func (_u *AccountUpdateOne) SetNillableLoadFactor(v *int) *AccountUpdateOne { + if v != nil { + _u.SetLoadFactor(*v) + } + return _u +} + +// AddLoadFactor adds value to the "load_factor" field. +func (_u *AccountUpdateOne) AddLoadFactor(v int) *AccountUpdateOne { + _u.mutation.AddLoadFactor(v) + return _u +} + +// ClearLoadFactor clears the value of the "load_factor" field. +func (_u *AccountUpdateOne) ClearLoadFactor() *AccountUpdateOne { + _u.mutation.ClearLoadFactor() + return _u +} + // SetPriority sets the "priority" field. 
func (_u *AccountUpdateOne) SetPriority(v int) *AccountUpdateOne { _u.mutation.ResetPriority() @@ -1605,6 +1668,15 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er if value, ok := _u.mutation.AddedConcurrency(); ok { _spec.AddField(account.FieldConcurrency, field.TypeInt, value) } + if value, ok := _u.mutation.LoadFactor(); ok { + _spec.SetField(account.FieldLoadFactor, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedLoadFactor(); ok { + _spec.AddField(account.FieldLoadFactor, field.TypeInt, value) + } + if _u.mutation.LoadFactorCleared() { + _spec.ClearField(account.FieldLoadFactor, field.TypeInt) + } if value, ok := _u.mutation.Priority(); ok { _spec.SetField(account.FieldPriority, field.TypeInt, value) } diff --git a/backend/ent/announcement.go b/backend/ent/announcement.go index 93d7a375..6c5b21da 100644 --- a/backend/ent/announcement.go +++ b/backend/ent/announcement.go @@ -25,6 +25,8 @@ type Announcement struct { Content string `json:"content,omitempty"` // 状态: draft, active, archived Status string `json:"status,omitempty"` + // 通知模式: silent(仅铃铛), popup(弹窗提醒) + NotifyMode string `json:"notify_mode,omitempty"` // 展示条件(JSON 规则) Targeting domain.AnnouncementTargeting `json:"targeting,omitempty"` // 开始展示时间(为空表示立即生效) @@ -72,7 +74,7 @@ func (*Announcement) scanValues(columns []string) ([]any, error) { values[i] = new([]byte) case announcement.FieldID, announcement.FieldCreatedBy, announcement.FieldUpdatedBy: values[i] = new(sql.NullInt64) - case announcement.FieldTitle, announcement.FieldContent, announcement.FieldStatus: + case announcement.FieldTitle, announcement.FieldContent, announcement.FieldStatus, announcement.FieldNotifyMode: values[i] = new(sql.NullString) case announcement.FieldStartsAt, announcement.FieldEndsAt, announcement.FieldCreatedAt, announcement.FieldUpdatedAt: values[i] = new(sql.NullTime) @@ -115,6 +117,12 @@ func (_m *Announcement) assignValues(columns []string, values []any) error { } else if 
value.Valid { _m.Status = value.String } + case announcement.FieldNotifyMode: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field notify_mode", values[i]) + } else if value.Valid { + _m.NotifyMode = value.String + } case announcement.FieldTargeting: if value, ok := values[i].(*[]byte); !ok { return fmt.Errorf("unexpected type %T for field targeting", values[i]) @@ -213,6 +221,9 @@ func (_m *Announcement) String() string { builder.WriteString("status=") builder.WriteString(_m.Status) builder.WriteString(", ") + builder.WriteString("notify_mode=") + builder.WriteString(_m.NotifyMode) + builder.WriteString(", ") builder.WriteString("targeting=") builder.WriteString(fmt.Sprintf("%v", _m.Targeting)) builder.WriteString(", ") diff --git a/backend/ent/announcement/announcement.go b/backend/ent/announcement/announcement.go index 4f34ee05..71ba25ff 100644 --- a/backend/ent/announcement/announcement.go +++ b/backend/ent/announcement/announcement.go @@ -20,6 +20,8 @@ const ( FieldContent = "content" // FieldStatus holds the string denoting the status field in the database. FieldStatus = "status" + // FieldNotifyMode holds the string denoting the notify_mode field in the database. + FieldNotifyMode = "notify_mode" // FieldTargeting holds the string denoting the targeting field in the database. FieldTargeting = "targeting" // FieldStartsAt holds the string denoting the starts_at field in the database. @@ -53,6 +55,7 @@ var Columns = []string{ FieldTitle, FieldContent, FieldStatus, + FieldNotifyMode, FieldTargeting, FieldStartsAt, FieldEndsAt, @@ -81,6 +84,10 @@ var ( DefaultStatus string // StatusValidator is a validator for the "status" field. It is called by the builders before save. StatusValidator func(string) error + // DefaultNotifyMode holds the default value on creation for the "notify_mode" field. + DefaultNotifyMode string + // NotifyModeValidator is a validator for the "notify_mode" field. 
It is called by the builders before save. + NotifyModeValidator func(string) error // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. @@ -112,6 +119,11 @@ func ByStatus(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldStatus, opts...).ToFunc() } +// ByNotifyMode orders the results by the notify_mode field. +func ByNotifyMode(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldNotifyMode, opts...).ToFunc() +} + // ByStartsAt orders the results by the starts_at field. func ByStartsAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldStartsAt, opts...).ToFunc() diff --git a/backend/ent/announcement/where.go b/backend/ent/announcement/where.go index d3cad2a5..2eea5f0b 100644 --- a/backend/ent/announcement/where.go +++ b/backend/ent/announcement/where.go @@ -70,6 +70,11 @@ func Status(v string) predicate.Announcement { return predicate.Announcement(sql.FieldEQ(FieldStatus, v)) } +// NotifyMode applies equality check predicate on the "notify_mode" field. It's identical to NotifyModeEQ. +func NotifyMode(v string) predicate.Announcement { + return predicate.Announcement(sql.FieldEQ(FieldNotifyMode, v)) +} + // StartsAt applies equality check predicate on the "starts_at" field. It's identical to StartsAtEQ. func StartsAt(v time.Time) predicate.Announcement { return predicate.Announcement(sql.FieldEQ(FieldStartsAt, v)) @@ -295,6 +300,71 @@ func StatusContainsFold(v string) predicate.Announcement { return predicate.Announcement(sql.FieldContainsFold(FieldStatus, v)) } +// NotifyModeEQ applies the EQ predicate on the "notify_mode" field. +func NotifyModeEQ(v string) predicate.Announcement { + return predicate.Announcement(sql.FieldEQ(FieldNotifyMode, v)) +} + +// NotifyModeNEQ applies the NEQ predicate on the "notify_mode" field. 
+func NotifyModeNEQ(v string) predicate.Announcement { + return predicate.Announcement(sql.FieldNEQ(FieldNotifyMode, v)) +} + +// NotifyModeIn applies the In predicate on the "notify_mode" field. +func NotifyModeIn(vs ...string) predicate.Announcement { + return predicate.Announcement(sql.FieldIn(FieldNotifyMode, vs...)) +} + +// NotifyModeNotIn applies the NotIn predicate on the "notify_mode" field. +func NotifyModeNotIn(vs ...string) predicate.Announcement { + return predicate.Announcement(sql.FieldNotIn(FieldNotifyMode, vs...)) +} + +// NotifyModeGT applies the GT predicate on the "notify_mode" field. +func NotifyModeGT(v string) predicate.Announcement { + return predicate.Announcement(sql.FieldGT(FieldNotifyMode, v)) +} + +// NotifyModeGTE applies the GTE predicate on the "notify_mode" field. +func NotifyModeGTE(v string) predicate.Announcement { + return predicate.Announcement(sql.FieldGTE(FieldNotifyMode, v)) +} + +// NotifyModeLT applies the LT predicate on the "notify_mode" field. +func NotifyModeLT(v string) predicate.Announcement { + return predicate.Announcement(sql.FieldLT(FieldNotifyMode, v)) +} + +// NotifyModeLTE applies the LTE predicate on the "notify_mode" field. +func NotifyModeLTE(v string) predicate.Announcement { + return predicate.Announcement(sql.FieldLTE(FieldNotifyMode, v)) +} + +// NotifyModeContains applies the Contains predicate on the "notify_mode" field. +func NotifyModeContains(v string) predicate.Announcement { + return predicate.Announcement(sql.FieldContains(FieldNotifyMode, v)) +} + +// NotifyModeHasPrefix applies the HasPrefix predicate on the "notify_mode" field. +func NotifyModeHasPrefix(v string) predicate.Announcement { + return predicate.Announcement(sql.FieldHasPrefix(FieldNotifyMode, v)) +} + +// NotifyModeHasSuffix applies the HasSuffix predicate on the "notify_mode" field. 
+func NotifyModeHasSuffix(v string) predicate.Announcement { + return predicate.Announcement(sql.FieldHasSuffix(FieldNotifyMode, v)) +} + +// NotifyModeEqualFold applies the EqualFold predicate on the "notify_mode" field. +func NotifyModeEqualFold(v string) predicate.Announcement { + return predicate.Announcement(sql.FieldEqualFold(FieldNotifyMode, v)) +} + +// NotifyModeContainsFold applies the ContainsFold predicate on the "notify_mode" field. +func NotifyModeContainsFold(v string) predicate.Announcement { + return predicate.Announcement(sql.FieldContainsFold(FieldNotifyMode, v)) +} + // TargetingIsNil applies the IsNil predicate on the "targeting" field. func TargetingIsNil() predicate.Announcement { return predicate.Announcement(sql.FieldIsNull(FieldTargeting)) diff --git a/backend/ent/announcement_create.go b/backend/ent/announcement_create.go index 151d4c11..d9029792 100644 --- a/backend/ent/announcement_create.go +++ b/backend/ent/announcement_create.go @@ -50,6 +50,20 @@ func (_c *AnnouncementCreate) SetNillableStatus(v *string) *AnnouncementCreate { return _c } +// SetNotifyMode sets the "notify_mode" field. +func (_c *AnnouncementCreate) SetNotifyMode(v string) *AnnouncementCreate { + _c.mutation.SetNotifyMode(v) + return _c +} + +// SetNillableNotifyMode sets the "notify_mode" field if the given value is not nil. +func (_c *AnnouncementCreate) SetNillableNotifyMode(v *string) *AnnouncementCreate { + if v != nil { + _c.SetNotifyMode(*v) + } + return _c +} + // SetTargeting sets the "targeting" field. 
func (_c *AnnouncementCreate) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementCreate { _c.mutation.SetTargeting(v) @@ -202,6 +216,10 @@ func (_c *AnnouncementCreate) defaults() { v := announcement.DefaultStatus _c.mutation.SetStatus(v) } + if _, ok := _c.mutation.NotifyMode(); !ok { + v := announcement.DefaultNotifyMode + _c.mutation.SetNotifyMode(v) + } if _, ok := _c.mutation.CreatedAt(); !ok { v := announcement.DefaultCreatedAt() _c.mutation.SetCreatedAt(v) @@ -238,6 +256,14 @@ func (_c *AnnouncementCreate) check() error { return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Announcement.status": %w`, err)} } } + if _, ok := _c.mutation.NotifyMode(); !ok { + return &ValidationError{Name: "notify_mode", err: errors.New(`ent: missing required field "Announcement.notify_mode"`)} + } + if v, ok := _c.mutation.NotifyMode(); ok { + if err := announcement.NotifyModeValidator(v); err != nil { + return &ValidationError{Name: "notify_mode", err: fmt.Errorf(`ent: validator failed for field "Announcement.notify_mode": %w`, err)} + } + } if _, ok := _c.mutation.CreatedAt(); !ok { return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "Announcement.created_at"`)} } @@ -283,6 +309,10 @@ func (_c *AnnouncementCreate) createSpec() (*Announcement, *sqlgraph.CreateSpec) _spec.SetField(announcement.FieldStatus, field.TypeString, value) _node.Status = value } + if value, ok := _c.mutation.NotifyMode(); ok { + _spec.SetField(announcement.FieldNotifyMode, field.TypeString, value) + _node.NotifyMode = value + } if value, ok := _c.mutation.Targeting(); ok { _spec.SetField(announcement.FieldTargeting, field.TypeJSON, value) _node.Targeting = value @@ -415,6 +445,18 @@ func (u *AnnouncementUpsert) UpdateStatus() *AnnouncementUpsert { return u } +// SetNotifyMode sets the "notify_mode" field. 
+func (u *AnnouncementUpsert) SetNotifyMode(v string) *AnnouncementUpsert { + u.Set(announcement.FieldNotifyMode, v) + return u +} + +// UpdateNotifyMode sets the "notify_mode" field to the value that was provided on create. +func (u *AnnouncementUpsert) UpdateNotifyMode() *AnnouncementUpsert { + u.SetExcluded(announcement.FieldNotifyMode) + return u +} + // SetTargeting sets the "targeting" field. func (u *AnnouncementUpsert) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpsert { u.Set(announcement.FieldTargeting, v) @@ -616,6 +658,20 @@ func (u *AnnouncementUpsertOne) UpdateStatus() *AnnouncementUpsertOne { }) } +// SetNotifyMode sets the "notify_mode" field. +func (u *AnnouncementUpsertOne) SetNotifyMode(v string) *AnnouncementUpsertOne { + return u.Update(func(s *AnnouncementUpsert) { + s.SetNotifyMode(v) + }) +} + +// UpdateNotifyMode sets the "notify_mode" field to the value that was provided on create. +func (u *AnnouncementUpsertOne) UpdateNotifyMode() *AnnouncementUpsertOne { + return u.Update(func(s *AnnouncementUpsert) { + s.UpdateNotifyMode() + }) +} + // SetTargeting sets the "targeting" field. func (u *AnnouncementUpsertOne) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpsertOne { return u.Update(func(s *AnnouncementUpsert) { @@ -1002,6 +1058,20 @@ func (u *AnnouncementUpsertBulk) UpdateStatus() *AnnouncementUpsertBulk { }) } +// SetNotifyMode sets the "notify_mode" field. +func (u *AnnouncementUpsertBulk) SetNotifyMode(v string) *AnnouncementUpsertBulk { + return u.Update(func(s *AnnouncementUpsert) { + s.SetNotifyMode(v) + }) +} + +// UpdateNotifyMode sets the "notify_mode" field to the value that was provided on create. +func (u *AnnouncementUpsertBulk) UpdateNotifyMode() *AnnouncementUpsertBulk { + return u.Update(func(s *AnnouncementUpsert) { + s.UpdateNotifyMode() + }) +} + // SetTargeting sets the "targeting" field. 
func (u *AnnouncementUpsertBulk) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpsertBulk { return u.Update(func(s *AnnouncementUpsert) { diff --git a/backend/ent/announcement_update.go b/backend/ent/announcement_update.go index 702d0817..f93f4f0e 100644 --- a/backend/ent/announcement_update.go +++ b/backend/ent/announcement_update.go @@ -72,6 +72,20 @@ func (_u *AnnouncementUpdate) SetNillableStatus(v *string) *AnnouncementUpdate { return _u } +// SetNotifyMode sets the "notify_mode" field. +func (_u *AnnouncementUpdate) SetNotifyMode(v string) *AnnouncementUpdate { + _u.mutation.SetNotifyMode(v) + return _u +} + +// SetNillableNotifyMode sets the "notify_mode" field if the given value is not nil. +func (_u *AnnouncementUpdate) SetNillableNotifyMode(v *string) *AnnouncementUpdate { + if v != nil { + _u.SetNotifyMode(*v) + } + return _u +} + // SetTargeting sets the "targeting" field. func (_u *AnnouncementUpdate) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpdate { _u.mutation.SetTargeting(v) @@ -286,6 +300,11 @@ func (_u *AnnouncementUpdate) check() error { return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Announcement.status": %w`, err)} } } + if v, ok := _u.mutation.NotifyMode(); ok { + if err := announcement.NotifyModeValidator(v); err != nil { + return &ValidationError{Name: "notify_mode", err: fmt.Errorf(`ent: validator failed for field "Announcement.notify_mode": %w`, err)} + } + } return nil } @@ -310,6 +329,9 @@ func (_u *AnnouncementUpdate) sqlSave(ctx context.Context) (_node int, err error if value, ok := _u.mutation.Status(); ok { _spec.SetField(announcement.FieldStatus, field.TypeString, value) } + if value, ok := _u.mutation.NotifyMode(); ok { + _spec.SetField(announcement.FieldNotifyMode, field.TypeString, value) + } if value, ok := _u.mutation.Targeting(); ok { _spec.SetField(announcement.FieldTargeting, field.TypeJSON, value) } @@ -456,6 +478,20 @@ func (_u *AnnouncementUpdateOne) 
SetNillableStatus(v *string) *AnnouncementUpdat return _u } +// SetNotifyMode sets the "notify_mode" field. +func (_u *AnnouncementUpdateOne) SetNotifyMode(v string) *AnnouncementUpdateOne { + _u.mutation.SetNotifyMode(v) + return _u +} + +// SetNillableNotifyMode sets the "notify_mode" field if the given value is not nil. +func (_u *AnnouncementUpdateOne) SetNillableNotifyMode(v *string) *AnnouncementUpdateOne { + if v != nil { + _u.SetNotifyMode(*v) + } + return _u +} + // SetTargeting sets the "targeting" field. func (_u *AnnouncementUpdateOne) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpdateOne { _u.mutation.SetTargeting(v) @@ -683,6 +719,11 @@ func (_u *AnnouncementUpdateOne) check() error { return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Announcement.status": %w`, err)} } } + if v, ok := _u.mutation.NotifyMode(); ok { + if err := announcement.NotifyModeValidator(v); err != nil { + return &ValidationError{Name: "notify_mode", err: fmt.Errorf(`ent: validator failed for field "Announcement.notify_mode": %w`, err)} + } + } return nil } @@ -724,6 +765,9 @@ func (_u *AnnouncementUpdateOne) sqlSave(ctx context.Context) (_node *Announceme if value, ok := _u.mutation.Status(); ok { _spec.SetField(announcement.FieldStatus, field.TypeString, value) } + if value, ok := _u.mutation.NotifyMode(); ok { + _spec.SetField(announcement.FieldNotifyMode, field.TypeString, value) + } if value, ok := _u.mutation.Targeting(); ok { _spec.SetField(announcement.FieldTargeting, field.TypeJSON, value) } diff --git a/backend/ent/group.go b/backend/ent/group.go index 76c3cae2..7ed49905 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -62,22 +62,28 @@ type Group struct { SoraVideoPricePerRequestHd *float64 `json:"sora_video_price_per_request_hd,omitempty"` // SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field. 
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"` - // 是否仅允许 Claude Code 客户端 + // allow Claude Code client only ClaudeCodeOnly bool `json:"claude_code_only,omitempty"` - // 非 Claude Code 请求降级使用的分组 ID + // fallback group for non-Claude-Code requests FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` - // 无效请求兜底使用的分组 ID + // fallback group for invalid request FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"` - // 模型路由配置:模型模式 -> 优先账号ID列表 + // model routing config: pattern -> account ids ModelRouting map[string][]int64 `json:"model_routing,omitempty"` - // 是否启用模型路由配置 + // whether model routing is enabled ModelRoutingEnabled bool `json:"model_routing_enabled,omitempty"` - // 是否注入 MCP XML 调用协议提示词(仅 antigravity 平台) + // whether MCP XML prompt injection is enabled McpXMLInject bool `json:"mcp_xml_inject,omitempty"` - // 支持的模型系列:claude, gemini_text, gemini_image + // supported model scopes: claude, gemini_text, gemini_image SupportedModelScopes []string `json:"supported_model_scopes,omitempty"` - // 分组显示排序,数值越小越靠前 + // group display order, lower comes first SortOrder int `json:"sort_order,omitempty"` + // 是否允许 /v1/messages 调度到此 OpenAI 分组 + AllowMessagesDispatch bool `json:"allow_messages_dispatch,omitempty"` + // 默认映射模型 ID,当账号级映射找不到时使用此值 + DefaultMappedModel string `json:"default_mapped_model,omitempty"` + // simulate claude usage as claude-max style (1h cache write) + SimulateClaudeMaxEnabled bool `json:"simulate_claude_max_enabled,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the GroupQuery when eager-loading is set. 
Edges GroupEdges `json:"edges"` @@ -186,13 +192,13 @@ func (*Group) scanValues(columns []string) ([]any, error) { switch columns[i] { case group.FieldModelRouting, group.FieldSupportedModelScopes: values[i] = new([]byte) - case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject: + case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldSimulateClaudeMaxEnabled: values[i] = new(sql.NullBool) case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd: values[i] = new(sql.NullFloat64) case group.FieldID, group.FieldDefaultValidityDays, group.FieldSoraStorageQuotaBytes, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder: values[i] = new(sql.NullInt64) - case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType: + case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType, group.FieldDefaultMappedModel: values[i] = new(sql.NullString) case group.FieldCreatedAt, group.FieldUpdatedAt, group.FieldDeletedAt: values[i] = new(sql.NullTime) @@ -415,6 +421,24 @@ func (_m *Group) assignValues(columns []string, values []any) error { } else if value.Valid { _m.SortOrder = int(value.Int64) } + case group.FieldAllowMessagesDispatch: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field allow_messages_dispatch", values[i]) + } else if value.Valid { + _m.AllowMessagesDispatch = value.Bool + } + case group.FieldDefaultMappedModel: + if value, ok := 
values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field default_mapped_model", values[i]) + } else if value.Valid { + _m.DefaultMappedModel = value.String + } + case group.FieldSimulateClaudeMaxEnabled: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field simulate_claude_max_enabled", values[i]) + } else if value.Valid { + _m.SimulateClaudeMaxEnabled = value.Bool + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -608,6 +632,15 @@ func (_m *Group) String() string { builder.WriteString(", ") builder.WriteString("sort_order=") builder.WriteString(fmt.Sprintf("%v", _m.SortOrder)) + builder.WriteString(", ") + builder.WriteString("allow_messages_dispatch=") + builder.WriteString(fmt.Sprintf("%v", _m.AllowMessagesDispatch)) + builder.WriteString(", ") + builder.WriteString("default_mapped_model=") + builder.WriteString(_m.DefaultMappedModel) + builder.WriteString(", ") + builder.WriteString("simulate_claude_max_enabled=") + builder.WriteString(fmt.Sprintf("%v", _m.SimulateClaudeMaxEnabled)) builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go index 6ac4eea1..970c7a85 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -75,6 +75,12 @@ const ( FieldSupportedModelScopes = "supported_model_scopes" // FieldSortOrder holds the string denoting the sort_order field in the database. FieldSortOrder = "sort_order" + // FieldAllowMessagesDispatch holds the string denoting the allow_messages_dispatch field in the database. + FieldAllowMessagesDispatch = "allow_messages_dispatch" + // FieldDefaultMappedModel holds the string denoting the default_mapped_model field in the database. + FieldDefaultMappedModel = "default_mapped_model" + // FieldSimulateClaudeMaxEnabled holds the string denoting the simulate_claude_max_enabled field in the database. 
+ FieldSimulateClaudeMaxEnabled = "simulate_claude_max_enabled" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. EdgeAPIKeys = "api_keys" // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. @@ -180,6 +186,9 @@ var Columns = []string{ FieldMcpXMLInject, FieldSupportedModelScopes, FieldSortOrder, + FieldAllowMessagesDispatch, + FieldDefaultMappedModel, + FieldSimulateClaudeMaxEnabled, } var ( @@ -247,6 +256,14 @@ var ( DefaultSupportedModelScopes []string // DefaultSortOrder holds the default value on creation for the "sort_order" field. DefaultSortOrder int + // DefaultAllowMessagesDispatch holds the default value on creation for the "allow_messages_dispatch" field. + DefaultAllowMessagesDispatch bool + // DefaultDefaultMappedModel holds the default value on creation for the "default_mapped_model" field. + DefaultDefaultMappedModel string + // DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save. + DefaultMappedModelValidator func(string) error + // DefaultSimulateClaudeMaxEnabled holds the default value on creation for the "simulate_claude_max_enabled" field. + DefaultSimulateClaudeMaxEnabled bool ) // OrderOption defines the ordering options for the Group queries. @@ -397,6 +414,21 @@ func BySortOrder(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldSortOrder, opts...).ToFunc() } +// ByAllowMessagesDispatch orders the results by the allow_messages_dispatch field. +func ByAllowMessagesDispatch(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAllowMessagesDispatch, opts...).ToFunc() +} + +// ByDefaultMappedModel orders the results by the default_mapped_model field. 
+func ByDefaultMappedModel(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDefaultMappedModel, opts...).ToFunc() +} + +// BySimulateClaudeMaxEnabled orders the results by the simulate_claude_max_enabled field. +func BySimulateClaudeMaxEnabled(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSimulateClaudeMaxEnabled, opts...).ToFunc() +} + // ByAPIKeysCount orders the results by api_keys count. func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go index 4cf65d0f..62c91d5a 100644 --- a/backend/ent/group/where.go +++ b/backend/ent/group/where.go @@ -195,6 +195,21 @@ func SortOrder(v int) predicate.Group { return predicate.Group(sql.FieldEQ(FieldSortOrder, v)) } +// AllowMessagesDispatch applies equality check predicate on the "allow_messages_dispatch" field. It's identical to AllowMessagesDispatchEQ. +func AllowMessagesDispatch(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldAllowMessagesDispatch, v)) +} + +// DefaultMappedModel applies equality check predicate on the "default_mapped_model" field. It's identical to DefaultMappedModelEQ. +func DefaultMappedModel(v string) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v)) +} + +// SimulateClaudeMaxEnabled applies equality check predicate on the "simulate_claude_max_enabled" field. It's identical to SimulateClaudeMaxEnabledEQ. +func SimulateClaudeMaxEnabled(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSimulateClaudeMaxEnabled, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. 
func CreatedAtEQ(v time.Time) predicate.Group { return predicate.Group(sql.FieldEQ(FieldCreatedAt, v)) @@ -1470,6 +1485,91 @@ func SortOrderLTE(v int) predicate.Group { return predicate.Group(sql.FieldLTE(FieldSortOrder, v)) } +// AllowMessagesDispatchEQ applies the EQ predicate on the "allow_messages_dispatch" field. +func AllowMessagesDispatchEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldAllowMessagesDispatch, v)) +} + +// AllowMessagesDispatchNEQ applies the NEQ predicate on the "allow_messages_dispatch" field. +func AllowMessagesDispatchNEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldAllowMessagesDispatch, v)) +} + +// DefaultMappedModelEQ applies the EQ predicate on the "default_mapped_model" field. +func DefaultMappedModelEQ(v string) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v)) +} + +// DefaultMappedModelNEQ applies the NEQ predicate on the "default_mapped_model" field. +func DefaultMappedModelNEQ(v string) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldDefaultMappedModel, v)) +} + +// DefaultMappedModelIn applies the In predicate on the "default_mapped_model" field. +func DefaultMappedModelIn(vs ...string) predicate.Group { + return predicate.Group(sql.FieldIn(FieldDefaultMappedModel, vs...)) +} + +// DefaultMappedModelNotIn applies the NotIn predicate on the "default_mapped_model" field. +func DefaultMappedModelNotIn(vs ...string) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldDefaultMappedModel, vs...)) +} + +// DefaultMappedModelGT applies the GT predicate on the "default_mapped_model" field. +func DefaultMappedModelGT(v string) predicate.Group { + return predicate.Group(sql.FieldGT(FieldDefaultMappedModel, v)) +} + +// DefaultMappedModelGTE applies the GTE predicate on the "default_mapped_model" field. 
+func DefaultMappedModelGTE(v string) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldDefaultMappedModel, v)) +} + +// DefaultMappedModelLT applies the LT predicate on the "default_mapped_model" field. +func DefaultMappedModelLT(v string) predicate.Group { + return predicate.Group(sql.FieldLT(FieldDefaultMappedModel, v)) +} + +// DefaultMappedModelLTE applies the LTE predicate on the "default_mapped_model" field. +func DefaultMappedModelLTE(v string) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldDefaultMappedModel, v)) +} + +// DefaultMappedModelContains applies the Contains predicate on the "default_mapped_model" field. +func DefaultMappedModelContains(v string) predicate.Group { + return predicate.Group(sql.FieldContains(FieldDefaultMappedModel, v)) +} + +// DefaultMappedModelHasPrefix applies the HasPrefix predicate on the "default_mapped_model" field. +func DefaultMappedModelHasPrefix(v string) predicate.Group { + return predicate.Group(sql.FieldHasPrefix(FieldDefaultMappedModel, v)) +} + +// DefaultMappedModelHasSuffix applies the HasSuffix predicate on the "default_mapped_model" field. +func DefaultMappedModelHasSuffix(v string) predicate.Group { + return predicate.Group(sql.FieldHasSuffix(FieldDefaultMappedModel, v)) +} + +// DefaultMappedModelEqualFold applies the EqualFold predicate on the "default_mapped_model" field. +func DefaultMappedModelEqualFold(v string) predicate.Group { + return predicate.Group(sql.FieldEqualFold(FieldDefaultMappedModel, v)) +} + +// DefaultMappedModelContainsFold applies the ContainsFold predicate on the "default_mapped_model" field. +func DefaultMappedModelContainsFold(v string) predicate.Group { + return predicate.Group(sql.FieldContainsFold(FieldDefaultMappedModel, v)) +} + +// SimulateClaudeMaxEnabledEQ applies the EQ predicate on the "simulate_claude_max_enabled" field. 
+func SimulateClaudeMaxEnabledEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSimulateClaudeMaxEnabled, v)) +} + +// SimulateClaudeMaxEnabledNEQ applies the NEQ predicate on the "simulate_claude_max_enabled" field. +func SimulateClaudeMaxEnabledNEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldSimulateClaudeMaxEnabled, v)) +} + // HasAPIKeys applies the HasEdge predicate on the "api_keys" edge. func HasAPIKeys() predicate.Group { return predicate.Group(func(s *sql.Selector) { diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index 0ce5f959..9418b02f 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -424,6 +424,48 @@ func (_c *GroupCreate) SetNillableSortOrder(v *int) *GroupCreate { return _c } +// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field. +func (_c *GroupCreate) SetAllowMessagesDispatch(v bool) *GroupCreate { + _c.mutation.SetAllowMessagesDispatch(v) + return _c +} + +// SetNillableAllowMessagesDispatch sets the "allow_messages_dispatch" field if the given value is not nil. +func (_c *GroupCreate) SetNillableAllowMessagesDispatch(v *bool) *GroupCreate { + if v != nil { + _c.SetAllowMessagesDispatch(*v) + } + return _c +} + +// SetDefaultMappedModel sets the "default_mapped_model" field. +func (_c *GroupCreate) SetDefaultMappedModel(v string) *GroupCreate { + _c.mutation.SetDefaultMappedModel(v) + return _c +} + +// SetNillableDefaultMappedModel sets the "default_mapped_model" field if the given value is not nil. +func (_c *GroupCreate) SetNillableDefaultMappedModel(v *string) *GroupCreate { + if v != nil { + _c.SetDefaultMappedModel(*v) + } + return _c +} + +// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field. 
+func (_c *GroupCreate) SetSimulateClaudeMaxEnabled(v bool) *GroupCreate { + _c.mutation.SetSimulateClaudeMaxEnabled(v) + return _c +} + +// SetNillableSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field if the given value is not nil. +func (_c *GroupCreate) SetNillableSimulateClaudeMaxEnabled(v *bool) *GroupCreate { + if v != nil { + _c.SetSimulateClaudeMaxEnabled(*v) + } + return _c +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate { _c.mutation.AddAPIKeyIDs(ids...) @@ -613,6 +655,18 @@ func (_c *GroupCreate) defaults() error { v := group.DefaultSortOrder _c.mutation.SetSortOrder(v) } + if _, ok := _c.mutation.AllowMessagesDispatch(); !ok { + v := group.DefaultAllowMessagesDispatch + _c.mutation.SetAllowMessagesDispatch(v) + } + if _, ok := _c.mutation.DefaultMappedModel(); !ok { + v := group.DefaultDefaultMappedModel + _c.mutation.SetDefaultMappedModel(v) + } + if _, ok := _c.mutation.SimulateClaudeMaxEnabled(); !ok { + v := group.DefaultSimulateClaudeMaxEnabled + _c.mutation.SetSimulateClaudeMaxEnabled(v) + } return nil } @@ -683,6 +737,20 @@ func (_c *GroupCreate) check() error { if _, ok := _c.mutation.SortOrder(); !ok { return &ValidationError{Name: "sort_order", err: errors.New(`ent: missing required field "Group.sort_order"`)} } + if _, ok := _c.mutation.AllowMessagesDispatch(); !ok { + return &ValidationError{Name: "allow_messages_dispatch", err: errors.New(`ent: missing required field "Group.allow_messages_dispatch"`)} + } + if _, ok := _c.mutation.DefaultMappedModel(); !ok { + return &ValidationError{Name: "default_mapped_model", err: errors.New(`ent: missing required field "Group.default_mapped_model"`)} + } + if v, ok := _c.mutation.DefaultMappedModel(); ok { + if err := group.DefaultMappedModelValidator(v); err != nil { + return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field 
"Group.default_mapped_model": %w`, err)} + } + } + if _, ok := _c.mutation.SimulateClaudeMaxEnabled(); !ok { + return &ValidationError{Name: "simulate_claude_max_enabled", err: errors.New(`ent: missing required field "Group.simulate_claude_max_enabled"`)} + } return nil } @@ -830,6 +898,18 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldSortOrder, field.TypeInt, value) _node.SortOrder = value } + if value, ok := _c.mutation.AllowMessagesDispatch(); ok { + _spec.SetField(group.FieldAllowMessagesDispatch, field.TypeBool, value) + _node.AllowMessagesDispatch = value + } + if value, ok := _c.mutation.DefaultMappedModel(); ok { + _spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value) + _node.DefaultMappedModel = value + } + if value, ok := _c.mutation.SimulateClaudeMaxEnabled(); ok { + _spec.SetField(group.FieldSimulateClaudeMaxEnabled, field.TypeBool, value) + _node.SimulateClaudeMaxEnabled = value + } if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1520,6 +1600,42 @@ func (u *GroupUpsert) AddSortOrder(v int) *GroupUpsert { return u } +// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field. +func (u *GroupUpsert) SetAllowMessagesDispatch(v bool) *GroupUpsert { + u.Set(group.FieldAllowMessagesDispatch, v) + return u +} + +// UpdateAllowMessagesDispatch sets the "allow_messages_dispatch" field to the value that was provided on create. +func (u *GroupUpsert) UpdateAllowMessagesDispatch() *GroupUpsert { + u.SetExcluded(group.FieldAllowMessagesDispatch) + return u +} + +// SetDefaultMappedModel sets the "default_mapped_model" field. +func (u *GroupUpsert) SetDefaultMappedModel(v string) *GroupUpsert { + u.Set(group.FieldDefaultMappedModel, v) + return u +} + +// UpdateDefaultMappedModel sets the "default_mapped_model" field to the value that was provided on create. 
+func (u *GroupUpsert) UpdateDefaultMappedModel() *GroupUpsert { + u.SetExcluded(group.FieldDefaultMappedModel) + return u +} + +// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field. +func (u *GroupUpsert) SetSimulateClaudeMaxEnabled(v bool) *GroupUpsert { + u.Set(group.FieldSimulateClaudeMaxEnabled, v) + return u +} + +// UpdateSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field to the value that was provided on create. +func (u *GroupUpsert) UpdateSimulateClaudeMaxEnabled() *GroupUpsert { + u.SetExcluded(group.FieldSimulateClaudeMaxEnabled) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -2188,6 +2304,48 @@ func (u *GroupUpsertOne) UpdateSortOrder() *GroupUpsertOne { }) } +// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field. +func (u *GroupUpsertOne) SetAllowMessagesDispatch(v bool) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetAllowMessagesDispatch(v) + }) +} + +// UpdateAllowMessagesDispatch sets the "allow_messages_dispatch" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateAllowMessagesDispatch() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateAllowMessagesDispatch() + }) +} + +// SetDefaultMappedModel sets the "default_mapped_model" field. +func (u *GroupUpsertOne) SetDefaultMappedModel(v string) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetDefaultMappedModel(v) + }) +} + +// UpdateDefaultMappedModel sets the "default_mapped_model" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateDefaultMappedModel() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateDefaultMappedModel() + }) +} + +// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field. 
+func (u *GroupUpsertOne) SetSimulateClaudeMaxEnabled(v bool) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetSimulateClaudeMaxEnabled(v) + }) +} + +// UpdateSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateSimulateClaudeMaxEnabled() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateSimulateClaudeMaxEnabled() + }) +} + // Exec executes the query. func (u *GroupUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -3022,6 +3180,48 @@ func (u *GroupUpsertBulk) UpdateSortOrder() *GroupUpsertBulk { }) } +// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field. +func (u *GroupUpsertBulk) SetAllowMessagesDispatch(v bool) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetAllowMessagesDispatch(v) + }) +} + +// UpdateAllowMessagesDispatch sets the "allow_messages_dispatch" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateAllowMessagesDispatch() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateAllowMessagesDispatch() + }) +} + +// SetDefaultMappedModel sets the "default_mapped_model" field. +func (u *GroupUpsertBulk) SetDefaultMappedModel(v string) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetDefaultMappedModel(v) + }) +} + +// UpdateDefaultMappedModel sets the "default_mapped_model" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateDefaultMappedModel() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateDefaultMappedModel() + }) +} + +// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field. 
+func (u *GroupUpsertBulk) SetSimulateClaudeMaxEnabled(v bool) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetSimulateClaudeMaxEnabled(v) + }) +} + +// UpdateSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateSimulateClaudeMaxEnabled() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateSimulateClaudeMaxEnabled() + }) +} + // Exec executes the query. func (u *GroupUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index 85575292..75955f7d 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -625,6 +625,48 @@ func (_u *GroupUpdate) AddSortOrder(v int) *GroupUpdate { return _u } +// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field. +func (_u *GroupUpdate) SetAllowMessagesDispatch(v bool) *GroupUpdate { + _u.mutation.SetAllowMessagesDispatch(v) + return _u +} + +// SetNillableAllowMessagesDispatch sets the "allow_messages_dispatch" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableAllowMessagesDispatch(v *bool) *GroupUpdate { + if v != nil { + _u.SetAllowMessagesDispatch(*v) + } + return _u +} + +// SetDefaultMappedModel sets the "default_mapped_model" field. +func (_u *GroupUpdate) SetDefaultMappedModel(v string) *GroupUpdate { + _u.mutation.SetDefaultMappedModel(v) + return _u +} + +// SetNillableDefaultMappedModel sets the "default_mapped_model" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableDefaultMappedModel(v *string) *GroupUpdate { + if v != nil { + _u.SetDefaultMappedModel(*v) + } + return _u +} + +// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field. 
+func (_u *GroupUpdate) SetSimulateClaudeMaxEnabled(v bool) *GroupUpdate { + _u.mutation.SetSimulateClaudeMaxEnabled(v) + return _u +} + +// SetNillableSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableSimulateClaudeMaxEnabled(v *bool) *GroupUpdate { + if v != nil { + _u.SetSimulateClaudeMaxEnabled(*v) + } + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate { _u.mutation.AddAPIKeyIDs(ids...) @@ -910,6 +952,11 @@ func (_u *GroupUpdate) check() error { return &ValidationError{Name: "subscription_type", err: fmt.Errorf(`ent: validator failed for field "Group.subscription_type": %w`, err)} } } + if v, ok := _u.mutation.DefaultMappedModel(); ok { + if err := group.DefaultMappedModelValidator(v); err != nil { + return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)} + } + } return nil } @@ -1110,6 +1157,15 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.AddedSortOrder(); ok { _spec.AddField(group.FieldSortOrder, field.TypeInt, value) } + if value, ok := _u.mutation.AllowMessagesDispatch(); ok { + _spec.SetField(group.FieldAllowMessagesDispatch, field.TypeBool, value) + } + if value, ok := _u.mutation.DefaultMappedModel(); ok { + _spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value) + } + if value, ok := _u.mutation.SimulateClaudeMaxEnabled(); ok { + _spec.SetField(group.FieldSimulateClaudeMaxEnabled, field.TypeBool, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -2014,6 +2070,48 @@ func (_u *GroupUpdateOne) AddSortOrder(v int) *GroupUpdateOne { return _u } +// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field. 
+func (_u *GroupUpdateOne) SetAllowMessagesDispatch(v bool) *GroupUpdateOne { + _u.mutation.SetAllowMessagesDispatch(v) + return _u +} + +// SetNillableAllowMessagesDispatch sets the "allow_messages_dispatch" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableAllowMessagesDispatch(v *bool) *GroupUpdateOne { + if v != nil { + _u.SetAllowMessagesDispatch(*v) + } + return _u +} + +// SetDefaultMappedModel sets the "default_mapped_model" field. +func (_u *GroupUpdateOne) SetDefaultMappedModel(v string) *GroupUpdateOne { + _u.mutation.SetDefaultMappedModel(v) + return _u +} + +// SetNillableDefaultMappedModel sets the "default_mapped_model" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableDefaultMappedModel(v *string) *GroupUpdateOne { + if v != nil { + _u.SetDefaultMappedModel(*v) + } + return _u +} + +// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field. +func (_u *GroupUpdateOne) SetSimulateClaudeMaxEnabled(v bool) *GroupUpdateOne { + _u.mutation.SetSimulateClaudeMaxEnabled(v) + return _u +} + +// SetNillableSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableSimulateClaudeMaxEnabled(v *bool) *GroupUpdateOne { + if v != nil { + _u.SetSimulateClaudeMaxEnabled(*v) + } + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne { _u.mutation.AddAPIKeyIDs(ids...) 
@@ -2312,6 +2410,11 @@ func (_u *GroupUpdateOne) check() error { return &ValidationError{Name: "subscription_type", err: fmt.Errorf(`ent: validator failed for field "Group.subscription_type": %w`, err)} } } + if v, ok := _u.mutation.DefaultMappedModel(); ok { + if err := group.DefaultMappedModelValidator(v); err != nil { + return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)} + } + } return nil } @@ -2529,6 +2632,15 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if value, ok := _u.mutation.AddedSortOrder(); ok { _spec.AddField(group.FieldSortOrder, field.TypeInt, value) } + if value, ok := _u.mutation.AllowMessagesDispatch(); ok { + _spec.SetField(group.FieldAllowMessagesDispatch, field.TypeBool, value) + } + if value, ok := _u.mutation.DefaultMappedModel(); ok { + _spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value) + } + if value, ok := _u.mutation.SimulateClaudeMaxEnabled(); ok { + _spec.SetField(group.FieldSimulateClaudeMaxEnabled, field.TypeBool, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 85e94072..03c66cd1 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -106,6 +106,7 @@ var ( {Name: "credentials", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, {Name: "extra", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, {Name: "concurrency", Type: field.TypeInt, Default: 3}, + {Name: "load_factor", Type: field.TypeInt, Nullable: true}, {Name: "priority", Type: field.TypeInt, Default: 50}, {Name: "rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}}, {Name: "status", Type: field.TypeString, Size: 20, Default: "active"}, @@ -132,7 +133,7 @@ var ( 
ForeignKeys: []*schema.ForeignKey{ { Symbol: "accounts_proxies_proxy", - Columns: []*schema.Column{AccountsColumns[27]}, + Columns: []*schema.Column{AccountsColumns[28]}, RefColumns: []*schema.Column{ProxiesColumns[0]}, OnDelete: schema.SetNull, }, @@ -151,52 +152,52 @@ var ( { Name: "account_status", Unique: false, - Columns: []*schema.Column{AccountsColumns[13]}, + Columns: []*schema.Column{AccountsColumns[14]}, }, { Name: "account_proxy_id", Unique: false, - Columns: []*schema.Column{AccountsColumns[27]}, + Columns: []*schema.Column{AccountsColumns[28]}, }, { Name: "account_priority", Unique: false, - Columns: []*schema.Column{AccountsColumns[11]}, + Columns: []*schema.Column{AccountsColumns[12]}, }, { Name: "account_last_used_at", Unique: false, - Columns: []*schema.Column{AccountsColumns[15]}, + Columns: []*schema.Column{AccountsColumns[16]}, }, { Name: "account_schedulable", Unique: false, - Columns: []*schema.Column{AccountsColumns[18]}, + Columns: []*schema.Column{AccountsColumns[19]}, }, { Name: "account_rate_limited_at", Unique: false, - Columns: []*schema.Column{AccountsColumns[19]}, + Columns: []*schema.Column{AccountsColumns[20]}, }, { Name: "account_rate_limit_reset_at", Unique: false, - Columns: []*schema.Column{AccountsColumns[20]}, + Columns: []*schema.Column{AccountsColumns[21]}, }, { Name: "account_overload_until", Unique: false, - Columns: []*schema.Column{AccountsColumns[21]}, + Columns: []*schema.Column{AccountsColumns[22]}, }, { Name: "account_platform_priority", Unique: false, - Columns: []*schema.Column{AccountsColumns[6], AccountsColumns[11]}, + Columns: []*schema.Column{AccountsColumns[6], AccountsColumns[12]}, }, { Name: "account_priority_status", Unique: false, - Columns: []*schema.Column{AccountsColumns[11], AccountsColumns[13]}, + Columns: []*schema.Column{AccountsColumns[12], AccountsColumns[14]}, }, { Name: "account_deleted_at", @@ -250,6 +251,7 @@ var ( {Name: "title", Type: field.TypeString, Size: 200}, {Name: "content", Type: 
field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, {Name: "status", Type: field.TypeString, Size: 20, Default: "draft"}, + {Name: "notify_mode", Type: field.TypeString, Size: 20, Default: "silent"}, {Name: "targeting", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, {Name: "starts_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "ends_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, @@ -272,17 +274,17 @@ var ( { Name: "announcement_created_at", Unique: false, - Columns: []*schema.Column{AnnouncementsColumns[9]}, + Columns: []*schema.Column{AnnouncementsColumns[10]}, }, { Name: "announcement_starts_at", Unique: false, - Columns: []*schema.Column{AnnouncementsColumns[5]}, + Columns: []*schema.Column{AnnouncementsColumns[6]}, }, { Name: "announcement_ends_at", Unique: false, - Columns: []*schema.Column{AnnouncementsColumns[6]}, + Columns: []*schema.Column{AnnouncementsColumns[7]}, }, }, } @@ -406,6 +408,9 @@ var ( {Name: "mcp_xml_inject", Type: field.TypeBool, Default: true}, {Name: "supported_model_scopes", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, {Name: "sort_order", Type: field.TypeInt, Default: 0}, + {Name: "allow_messages_dispatch", Type: field.TypeBool, Default: false}, + {Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""}, + {Name: "simulate_claude_max_enabled", Type: field.TypeBool, Default: false}, } // GroupsTable holds the schema information for the "groups" table. 
GroupsTable = &schema.Table{ diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 85e2ea71..8177d14d 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -2260,6 +2260,8 @@ type AccountMutation struct { extra *map[string]interface{} concurrency *int addconcurrency *int + load_factor *int + addload_factor *int priority *int addpriority *int rate_multiplier *float64 @@ -2845,6 +2847,76 @@ func (m *AccountMutation) ResetConcurrency() { m.addconcurrency = nil } +// SetLoadFactor sets the "load_factor" field. +func (m *AccountMutation) SetLoadFactor(i int) { + m.load_factor = &i + m.addload_factor = nil +} + +// LoadFactor returns the value of the "load_factor" field in the mutation. +func (m *AccountMutation) LoadFactor() (r int, exists bool) { + v := m.load_factor + if v == nil { + return + } + return *v, true +} + +// OldLoadFactor returns the old "load_factor" field's value of the Account entity. +// If the Account object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AccountMutation) OldLoadFactor(ctx context.Context) (v *int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLoadFactor is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLoadFactor requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLoadFactor: %w", err) + } + return oldValue.LoadFactor, nil +} + +// AddLoadFactor adds i to the "load_factor" field. +func (m *AccountMutation) AddLoadFactor(i int) { + if m.addload_factor != nil { + *m.addload_factor += i + } else { + m.addload_factor = &i + } +} + +// AddedLoadFactor returns the value that was added to the "load_factor" field in this mutation. 
+func (m *AccountMutation) AddedLoadFactor() (r int, exists bool) { + v := m.addload_factor + if v == nil { + return + } + return *v, true +} + +// ClearLoadFactor clears the value of the "load_factor" field. +func (m *AccountMutation) ClearLoadFactor() { + m.load_factor = nil + m.addload_factor = nil + m.clearedFields[account.FieldLoadFactor] = struct{}{} +} + +// LoadFactorCleared returns if the "load_factor" field was cleared in this mutation. +func (m *AccountMutation) LoadFactorCleared() bool { + _, ok := m.clearedFields[account.FieldLoadFactor] + return ok +} + +// ResetLoadFactor resets all changes to the "load_factor" field. +func (m *AccountMutation) ResetLoadFactor() { + m.load_factor = nil + m.addload_factor = nil + delete(m.clearedFields, account.FieldLoadFactor) +} + // SetPriority sets the "priority" field. func (m *AccountMutation) SetPriority(i int) { m.priority = &i @@ -3773,7 +3845,7 @@ func (m *AccountMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). 
func (m *AccountMutation) Fields() []string { - fields := make([]string, 0, 27) + fields := make([]string, 0, 28) if m.created_at != nil { fields = append(fields, account.FieldCreatedAt) } @@ -3807,6 +3879,9 @@ func (m *AccountMutation) Fields() []string { if m.concurrency != nil { fields = append(fields, account.FieldConcurrency) } + if m.load_factor != nil { + fields = append(fields, account.FieldLoadFactor) + } if m.priority != nil { fields = append(fields, account.FieldPriority) } @@ -3885,6 +3960,8 @@ func (m *AccountMutation) Field(name string) (ent.Value, bool) { return m.ProxyID() case account.FieldConcurrency: return m.Concurrency() + case account.FieldLoadFactor: + return m.LoadFactor() case account.FieldPriority: return m.Priority() case account.FieldRateMultiplier: @@ -3948,6 +4025,8 @@ func (m *AccountMutation) OldField(ctx context.Context, name string) (ent.Value, return m.OldProxyID(ctx) case account.FieldConcurrency: return m.OldConcurrency(ctx) + case account.FieldLoadFactor: + return m.OldLoadFactor(ctx) case account.FieldPriority: return m.OldPriority(ctx) case account.FieldRateMultiplier: @@ -4066,6 +4145,13 @@ func (m *AccountMutation) SetField(name string, value ent.Value) error { } m.SetConcurrency(v) return nil + case account.FieldLoadFactor: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLoadFactor(v) + return nil case account.FieldPriority: v, ok := value.(int) if !ok { @@ -4189,6 +4275,9 @@ func (m *AccountMutation) AddedFields() []string { if m.addconcurrency != nil { fields = append(fields, account.FieldConcurrency) } + if m.addload_factor != nil { + fields = append(fields, account.FieldLoadFactor) + } if m.addpriority != nil { fields = append(fields, account.FieldPriority) } @@ -4205,6 +4294,8 @@ func (m *AccountMutation) AddedField(name string) (ent.Value, bool) { switch name { case account.FieldConcurrency: return m.AddedConcurrency() + case account.FieldLoadFactor: + 
return m.AddedLoadFactor() case account.FieldPriority: return m.AddedPriority() case account.FieldRateMultiplier: @@ -4225,6 +4316,13 @@ func (m *AccountMutation) AddField(name string, value ent.Value) error { } m.AddConcurrency(v) return nil + case account.FieldLoadFactor: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddLoadFactor(v) + return nil case account.FieldPriority: v, ok := value.(int) if !ok { @@ -4256,6 +4354,9 @@ func (m *AccountMutation) ClearedFields() []string { if m.FieldCleared(account.FieldProxyID) { fields = append(fields, account.FieldProxyID) } + if m.FieldCleared(account.FieldLoadFactor) { + fields = append(fields, account.FieldLoadFactor) + } if m.FieldCleared(account.FieldErrorMessage) { fields = append(fields, account.FieldErrorMessage) } @@ -4312,6 +4413,9 @@ func (m *AccountMutation) ClearField(name string) error { case account.FieldProxyID: m.ClearProxyID() return nil + case account.FieldLoadFactor: + m.ClearLoadFactor() + return nil case account.FieldErrorMessage: m.ClearErrorMessage() return nil @@ -4386,6 +4490,9 @@ func (m *AccountMutation) ResetField(name string) error { case account.FieldConcurrency: m.ResetConcurrency() return nil + case account.FieldLoadFactor: + m.ResetLoadFactor() + return nil case account.FieldPriority: m.ResetPriority() return nil @@ -5060,6 +5167,7 @@ type AnnouncementMutation struct { title *string content *string status *string + notify_mode *string targeting *domain.AnnouncementTargeting starts_at *time.Time ends_at *time.Time @@ -5284,6 +5392,42 @@ func (m *AnnouncementMutation) ResetStatus() { m.status = nil } +// SetNotifyMode sets the "notify_mode" field. +func (m *AnnouncementMutation) SetNotifyMode(s string) { + m.notify_mode = &s +} + +// NotifyMode returns the value of the "notify_mode" field in the mutation. 
+func (m *AnnouncementMutation) NotifyMode() (r string, exists bool) { + v := m.notify_mode + if v == nil { + return + } + return *v, true +} + +// OldNotifyMode returns the old "notify_mode" field's value of the Announcement entity. +// If the Announcement object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AnnouncementMutation) OldNotifyMode(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldNotifyMode is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldNotifyMode requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldNotifyMode: %w", err) + } + return oldValue.NotifyMode, nil +} + +// ResetNotifyMode resets all changes to the "notify_mode" field. +func (m *AnnouncementMutation) ResetNotifyMode() { + m.notify_mode = nil +} + // SetTargeting sets the "targeting" field. func (m *AnnouncementMutation) SetTargeting(dt domain.AnnouncementTargeting) { m.targeting = &dt @@ -5731,7 +5875,7 @@ func (m *AnnouncementMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). 
func (m *AnnouncementMutation) Fields() []string { - fields := make([]string, 0, 10) + fields := make([]string, 0, 11) if m.title != nil { fields = append(fields, announcement.FieldTitle) } @@ -5741,6 +5885,9 @@ func (m *AnnouncementMutation) Fields() []string { if m.status != nil { fields = append(fields, announcement.FieldStatus) } + if m.notify_mode != nil { + fields = append(fields, announcement.FieldNotifyMode) + } if m.targeting != nil { fields = append(fields, announcement.FieldTargeting) } @@ -5776,6 +5923,8 @@ func (m *AnnouncementMutation) Field(name string) (ent.Value, bool) { return m.Content() case announcement.FieldStatus: return m.Status() + case announcement.FieldNotifyMode: + return m.NotifyMode() case announcement.FieldTargeting: return m.Targeting() case announcement.FieldStartsAt: @@ -5805,6 +5954,8 @@ func (m *AnnouncementMutation) OldField(ctx context.Context, name string) (ent.V return m.OldContent(ctx) case announcement.FieldStatus: return m.OldStatus(ctx) + case announcement.FieldNotifyMode: + return m.OldNotifyMode(ctx) case announcement.FieldTargeting: return m.OldTargeting(ctx) case announcement.FieldStartsAt: @@ -5849,6 +6000,13 @@ func (m *AnnouncementMutation) SetField(name string, value ent.Value) error { } m.SetStatus(v) return nil + case announcement.FieldNotifyMode: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetNotifyMode(v) + return nil case announcement.FieldTargeting: v, ok := value.(domain.AnnouncementTargeting) if !ok { @@ -6016,6 +6174,9 @@ func (m *AnnouncementMutation) ResetField(name string) error { case announcement.FieldStatus: m.ResetStatus() return nil + case announcement.FieldNotifyMode: + m.ResetNotifyMode() + return nil case announcement.FieldTargeting: m.ResetTargeting() return nil @@ -8089,6 +8250,9 @@ type GroupMutation struct { appendsupported_model_scopes []string sort_order *int addsort_order *int + allow_messages_dispatch *bool + 
default_mapped_model *string + simulate_claude_max_enabled *bool clearedFields map[string]struct{} api_keys map[int64]struct{} removedapi_keys map[int64]struct{} @@ -9833,6 +9997,114 @@ func (m *GroupMutation) ResetSortOrder() { m.addsort_order = nil } +// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field. +func (m *GroupMutation) SetAllowMessagesDispatch(b bool) { + m.allow_messages_dispatch = &b +} + +// AllowMessagesDispatch returns the value of the "allow_messages_dispatch" field in the mutation. +func (m *GroupMutation) AllowMessagesDispatch() (r bool, exists bool) { + v := m.allow_messages_dispatch + if v == nil { + return + } + return *v, true +} + +// OldAllowMessagesDispatch returns the old "allow_messages_dispatch" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldAllowMessagesDispatch(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAllowMessagesDispatch is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAllowMessagesDispatch requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAllowMessagesDispatch: %w", err) + } + return oldValue.AllowMessagesDispatch, nil +} + +// ResetAllowMessagesDispatch resets all changes to the "allow_messages_dispatch" field. +func (m *GroupMutation) ResetAllowMessagesDispatch() { + m.allow_messages_dispatch = nil +} + +// SetDefaultMappedModel sets the "default_mapped_model" field. +func (m *GroupMutation) SetDefaultMappedModel(s string) { + m.default_mapped_model = &s +} + +// DefaultMappedModel returns the value of the "default_mapped_model" field in the mutation. 
+func (m *GroupMutation) DefaultMappedModel() (r string, exists bool) { + v := m.default_mapped_model + if v == nil { + return + } + return *v, true +} + +// OldDefaultMappedModel returns the old "default_mapped_model" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldDefaultMappedModel(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDefaultMappedModel is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDefaultMappedModel requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDefaultMappedModel: %w", err) + } + return oldValue.DefaultMappedModel, nil +} + +// ResetDefaultMappedModel resets all changes to the "default_mapped_model" field. +func (m *GroupMutation) ResetDefaultMappedModel() { + m.default_mapped_model = nil +} + +// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field. +func (m *GroupMutation) SetSimulateClaudeMaxEnabled(b bool) { + m.simulate_claude_max_enabled = &b +} + +// SimulateClaudeMaxEnabled returns the value of the "simulate_claude_max_enabled" field in the mutation. +func (m *GroupMutation) SimulateClaudeMaxEnabled() (r bool, exists bool) { + v := m.simulate_claude_max_enabled + if v == nil { + return + } + return *v, true +} + +// OldSimulateClaudeMaxEnabled returns the old "simulate_claude_max_enabled" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *GroupMutation) OldSimulateClaudeMaxEnabled(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSimulateClaudeMaxEnabled is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSimulateClaudeMaxEnabled requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSimulateClaudeMaxEnabled: %w", err) + } + return oldValue.SimulateClaudeMaxEnabled, nil +} + +// ResetSimulateClaudeMaxEnabled resets all changes to the "simulate_claude_max_enabled" field. +func (m *GroupMutation) ResetSimulateClaudeMaxEnabled() { + m.simulate_claude_max_enabled = nil +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) { if m.api_keys == nil { @@ -10191,7 +10463,7 @@ func (m *GroupMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). 
func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 30) + fields := make([]string, 0, 33) if m.created_at != nil { fields = append(fields, group.FieldCreatedAt) } @@ -10282,6 +10554,15 @@ func (m *GroupMutation) Fields() []string { if m.sort_order != nil { fields = append(fields, group.FieldSortOrder) } + if m.allow_messages_dispatch != nil { + fields = append(fields, group.FieldAllowMessagesDispatch) + } + if m.default_mapped_model != nil { + fields = append(fields, group.FieldDefaultMappedModel) + } + if m.simulate_claude_max_enabled != nil { + fields = append(fields, group.FieldSimulateClaudeMaxEnabled) + } return fields } @@ -10350,6 +10631,12 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.SupportedModelScopes() case group.FieldSortOrder: return m.SortOrder() + case group.FieldAllowMessagesDispatch: + return m.AllowMessagesDispatch() + case group.FieldDefaultMappedModel: + return m.DefaultMappedModel() + case group.FieldSimulateClaudeMaxEnabled: + return m.SimulateClaudeMaxEnabled() } return nil, false } @@ -10419,6 +10706,12 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldSupportedModelScopes(ctx) case group.FieldSortOrder: return m.OldSortOrder(ctx) + case group.FieldAllowMessagesDispatch: + return m.OldAllowMessagesDispatch(ctx) + case group.FieldDefaultMappedModel: + return m.OldDefaultMappedModel(ctx) + case group.FieldSimulateClaudeMaxEnabled: + return m.OldSimulateClaudeMaxEnabled(ctx) } return nil, fmt.Errorf("unknown Group field %s", name) } @@ -10638,6 +10931,27 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetSortOrder(v) return nil + case group.FieldAllowMessagesDispatch: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAllowMessagesDispatch(v) + return nil + case group.FieldDefaultMappedModel: + v, ok := value.(string) + if !ok { + return 
fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDefaultMappedModel(v) + return nil + case group.FieldSimulateClaudeMaxEnabled: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSimulateClaudeMaxEnabled(v) + return nil } return fmt.Errorf("unknown Group field %s", name) } @@ -11065,6 +11379,15 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldSortOrder: m.ResetSortOrder() return nil + case group.FieldAllowMessagesDispatch: + m.ResetAllowMessagesDispatch() + return nil + case group.FieldDefaultMappedModel: + m.ResetDefaultMappedModel() + return nil + case group.FieldSimulateClaudeMaxEnabled: + m.ResetSimulateClaudeMaxEnabled() + return nil } return fmt.Errorf("unknown Group field %s", name) } diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index 2c7467f6..ff8a655b 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -212,29 +212,29 @@ func init() { // account.DefaultConcurrency holds the default value on creation for the concurrency field. account.DefaultConcurrency = accountDescConcurrency.Default.(int) // accountDescPriority is the schema descriptor for priority field. - accountDescPriority := accountFields[8].Descriptor() + accountDescPriority := accountFields[9].Descriptor() // account.DefaultPriority holds the default value on creation for the priority field. account.DefaultPriority = accountDescPriority.Default.(int) // accountDescRateMultiplier is the schema descriptor for rate_multiplier field. - accountDescRateMultiplier := accountFields[9].Descriptor() + accountDescRateMultiplier := accountFields[10].Descriptor() // account.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field. account.DefaultRateMultiplier = accountDescRateMultiplier.Default.(float64) // accountDescStatus is the schema descriptor for status field. 
- accountDescStatus := accountFields[10].Descriptor() + accountDescStatus := accountFields[11].Descriptor() // account.DefaultStatus holds the default value on creation for the status field. account.DefaultStatus = accountDescStatus.Default.(string) // account.StatusValidator is a validator for the "status" field. It is called by the builders before save. account.StatusValidator = accountDescStatus.Validators[0].(func(string) error) // accountDescAutoPauseOnExpired is the schema descriptor for auto_pause_on_expired field. - accountDescAutoPauseOnExpired := accountFields[14].Descriptor() + accountDescAutoPauseOnExpired := accountFields[15].Descriptor() // account.DefaultAutoPauseOnExpired holds the default value on creation for the auto_pause_on_expired field. account.DefaultAutoPauseOnExpired = accountDescAutoPauseOnExpired.Default.(bool) // accountDescSchedulable is the schema descriptor for schedulable field. - accountDescSchedulable := accountFields[15].Descriptor() + accountDescSchedulable := accountFields[16].Descriptor() // account.DefaultSchedulable holds the default value on creation for the schedulable field. account.DefaultSchedulable = accountDescSchedulable.Default.(bool) // accountDescSessionWindowStatus is the schema descriptor for session_window_status field. - accountDescSessionWindowStatus := accountFields[23].Descriptor() + accountDescSessionWindowStatus := accountFields[24].Descriptor() // account.SessionWindowStatusValidator is a validator for the "session_window_status" field. It is called by the builders before save. account.SessionWindowStatusValidator = accountDescSessionWindowStatus.Validators[0].(func(string) error) accountgroupFields := schema.AccountGroup{}.Fields() @@ -277,12 +277,18 @@ func init() { announcement.DefaultStatus = announcementDescStatus.Default.(string) // announcement.StatusValidator is a validator for the "status" field. It is called by the builders before save. 
announcement.StatusValidator = announcementDescStatus.Validators[0].(func(string) error) + // announcementDescNotifyMode is the schema descriptor for notify_mode field. + announcementDescNotifyMode := announcementFields[3].Descriptor() + // announcement.DefaultNotifyMode holds the default value on creation for the notify_mode field. + announcement.DefaultNotifyMode = announcementDescNotifyMode.Default.(string) + // announcement.NotifyModeValidator is a validator for the "notify_mode" field. It is called by the builders before save. + announcement.NotifyModeValidator = announcementDescNotifyMode.Validators[0].(func(string) error) // announcementDescCreatedAt is the schema descriptor for created_at field. - announcementDescCreatedAt := announcementFields[8].Descriptor() + announcementDescCreatedAt := announcementFields[9].Descriptor() // announcement.DefaultCreatedAt holds the default value on creation for the created_at field. announcement.DefaultCreatedAt = announcementDescCreatedAt.Default.(func() time.Time) // announcementDescUpdatedAt is the schema descriptor for updated_at field. - announcementDescUpdatedAt := announcementFields[9].Descriptor() + announcementDescUpdatedAt := announcementFields[10].Descriptor() // announcement.DefaultUpdatedAt holds the default value on creation for the updated_at field. announcement.DefaultUpdatedAt = announcementDescUpdatedAt.Default.(func() time.Time) // announcement.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. @@ -447,6 +453,20 @@ func init() { groupDescSortOrder := groupFields[26].Descriptor() // group.DefaultSortOrder holds the default value on creation for the sort_order field. group.DefaultSortOrder = groupDescSortOrder.Default.(int) + // groupDescAllowMessagesDispatch is the schema descriptor for allow_messages_dispatch field. 
+ groupDescAllowMessagesDispatch := groupFields[27].Descriptor() + // group.DefaultAllowMessagesDispatch holds the default value on creation for the allow_messages_dispatch field. + group.DefaultAllowMessagesDispatch = groupDescAllowMessagesDispatch.Default.(bool) + // groupDescDefaultMappedModel is the schema descriptor for default_mapped_model field. + groupDescDefaultMappedModel := groupFields[28].Descriptor() + // group.DefaultDefaultMappedModel holds the default value on creation for the default_mapped_model field. + group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string) + // group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save. + group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error) + // groupDescSimulateClaudeMaxEnabled is the schema descriptor for simulate_claude_max_enabled field. + groupDescSimulateClaudeMaxEnabled := groupFields[29].Descriptor() + // group.DefaultSimulateClaudeMaxEnabled holds the default value on creation for the simulate_claude_max_enabled field. + group.DefaultSimulateClaudeMaxEnabled = groupDescSimulateClaudeMaxEnabled.Default.(bool) idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin() idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields() _ = idempotencyrecordMixinFields0 diff --git a/backend/ent/schema/account.go b/backend/ent/schema/account.go index 443f9e09..5616d399 100644 --- a/backend/ent/schema/account.go +++ b/backend/ent/schema/account.go @@ -97,6 +97,8 @@ func (Account) Fields() []ent.Field { field.Int("concurrency"). Default(3), + field.Int("load_factor").Optional().Nillable(), + // priority: 账户优先级,数值越小优先级越高 // 调度器会优先使用高优先级的账户 field.Int("priority"). 
diff --git a/backend/ent/schema/announcement.go b/backend/ent/schema/announcement.go index 1568778f..14159fc3 100644 --- a/backend/ent/schema/announcement.go +++ b/backend/ent/schema/announcement.go @@ -41,6 +41,10 @@ func (Announcement) Fields() []ent.Field { MaxLen(20). Default(domain.AnnouncementStatusDraft). Comment("状态: draft, active, archived"), + field.String("notify_mode"). + MaxLen(20). + Default(domain.AnnouncementNotifyModeSilent). + Comment("通知模式: silent(仅铃铛), popup(弹窗提醒)"), field.JSON("targeting", domain.AnnouncementTargeting{}). Optional(). SchemaType(map[string]string{dialect.Postgres: "jsonb"}). diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index 3fcf8674..0842a0f8 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -33,8 +33,6 @@ func (Group) Mixin() []ent.Mixin { func (Group) Fields() []ent.Field { return []ent.Field{ - // 唯一约束通过部分索引实现(WHERE deleted_at IS NULL),支持软删除后重用 - // 见迁移文件 016_soft_delete_partial_unique_indexes.sql field.String("name"). MaxLen(100). NotEmpty(), @@ -51,7 +49,6 @@ func (Group) Fields() []ent.Field { MaxLen(20). Default(domain.StatusActive), - // Subscription-related fields (added by migration 003) field.String("platform"). MaxLen(50). Default(domain.PlatformAnthropic), @@ -73,7 +70,6 @@ func (Group) Fields() []ent.Field { field.Int("default_validity_days"). Default(30), - // 图片生成计费配置(antigravity 和 gemini 平台使用) field.Float("image_price_1k"). Optional(). Nillable(). @@ -87,7 +83,6 @@ func (Group) Fields() []ent.Field { Nillable(). SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), - // Sora 按次计费配置(阶段 1) field.Float("sora_image_price_360"). Optional(). Nillable(). @@ -109,45 +104,50 @@ func (Group) Fields() []ent.Field { field.Int64("sora_storage_quota_bytes"). Default(0), - // Claude Code 客户端限制 (added by migration 029) field.Bool("claude_code_only"). Default(false). 
- Comment("是否仅允许 Claude Code 客户端"), + Comment("allow Claude Code client only"), field.Int64("fallback_group_id"). Optional(). Nillable(). - Comment("非 Claude Code 请求降级使用的分组 ID"), + Comment("fallback group for non-Claude-Code requests"), field.Int64("fallback_group_id_on_invalid_request"). Optional(). Nillable(). - Comment("无效请求兜底使用的分组 ID"), + Comment("fallback group for invalid request"), - // 模型路由配置 (added by migration 040) field.JSON("model_routing", map[string][]int64{}). Optional(). SchemaType(map[string]string{dialect.Postgres: "jsonb"}). - Comment("模型路由配置:模型模式 -> 优先账号ID列表"), - - // 模型路由开关 (added by migration 041) + Comment("model routing config: pattern -> account ids"), field.Bool("model_routing_enabled"). Default(false). - Comment("是否启用模型路由配置"), + Comment("whether model routing is enabled"), - // MCP XML 协议注入开关 (added by migration 042) field.Bool("mcp_xml_inject"). Default(true). - Comment("是否注入 MCP XML 调用协议提示词(仅 antigravity 平台)"), + Comment("whether MCP XML prompt injection is enabled"), - // 支持的模型系列 (added by migration 046) field.JSON("supported_model_scopes", []string{}). Default([]string{"claude", "gemini_text", "gemini_image"}). SchemaType(map[string]string{dialect.Postgres: "jsonb"}). - Comment("支持的模型系列:claude, gemini_text, gemini_image"), + Comment("supported model scopes: claude, gemini_text, gemini_image"), - // 分组排序 (added by migration 052) field.Int("sort_order"). Default(0). - Comment("分组显示排序,数值越小越靠前"), + Comment("group display order, lower comes first"), + + // OpenAI Messages 调度配置 (added by migration 069) + field.Bool("allow_messages_dispatch"). + Default(false). + Comment("是否允许 /v1/messages 调度到此 OpenAI 分组"), + field.String("default_mapped_model"). + MaxLen(100). + Default(""). + Comment("默认映射模型 ID,当账号级映射找不到时使用此值"), + field.Bool("simulate_claude_max_enabled"). + Default(false). 
+ Comment("simulate claude usage as claude-max style (1h cache write)"), } } @@ -163,14 +163,11 @@ func (Group) Edges() []ent.Edge { edge.From("allowed_users", User.Type). Ref("allowed_groups"). Through("user_allowed_groups", UserAllowedGroup.Type), - // 注意:fallback_group_id 直接作为字段使用,不定义 edge - // 这样允许多个分组指向同一个降级分组(M2O 关系) } } func (Group) Indexes() []ent.Index { return []ent.Index{ - // name 字段已在 Fields() 中声明 Unique(),无需重复索引 index.Fields("status"), index.Fields("platform"), index.Fields("subscription_type"), diff --git a/backend/go.mod b/backend/go.mod index ab76258a..267fcf60 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -1,12 +1,13 @@ module github.com/Wei-Shaw/sub2api -go 1.25.7 +go 1.26.1 require ( entgo.io/ent v0.14.5 github.com/DATA-DOG/go-sqlmock v1.5.2 github.com/DouDOU-start/go-sora2api v1.1.0 github.com/alitto/pond/v2 v2.6.2 + github.com/aws/aws-sdk-go-v2 v1.41.2 github.com/aws/aws-sdk-go-v2/config v1.32.10 github.com/aws/aws-sdk-go-v2/credentials v1.19.10 github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2 @@ -38,8 +39,6 @@ require ( golang.org/x/net v0.49.0 golang.org/x/sync v0.19.0 golang.org/x/term v0.40.0 - google.golang.org/grpc v1.75.1 - google.golang.org/protobuf v1.36.10 gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/yaml.v3 v3.0.1 modernc.org/sqlite v1.44.3 @@ -53,7 +52,6 @@ require ( github.com/agext/levenshtein v1.2.3 // indirect github.com/andybalholm/brotli v1.2.0 // indirect github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect - github.com/aws/aws-sdk-go-v2 v1.41.2 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18 // indirect @@ -89,6 +87,7 @@ require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/distribution/reference v0.6.0 // indirect + 
github.com/dlclark/regexp2 v1.10.0 // indirect github.com/docker/docker v28.5.1+incompatible // indirect github.com/docker/go-connections v0.6.0 // indirect github.com/docker/go-units v0.5.0 // indirect @@ -109,7 +108,6 @@ require ( github.com/goccy/go-json v0.10.2 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/go-querystring v1.1.0 // indirect - github.com/google/subcommands v1.2.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl/v2 v2.18.1 // indirect @@ -140,6 +138,8 @@ require ( github.com/opencontainers/image-spec v1.1.1 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/pkoukk/tiktoken-go v0.1.8 // indirect + github.com/pkoukk/tiktoken-go-loader v0.0.2 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect github.com/quic-go/qpack v0.6.0 // indirect @@ -169,6 +169,7 @@ require ( go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect go.opentelemetry.io/otel v1.37.0 // indirect go.opentelemetry.io/otel/metric v1.37.0 // indirect + go.opentelemetry.io/otel/sdk v1.37.0 // indirect go.opentelemetry.io/otel/trace v1.37.0 // indirect go.uber.org/atomic v1.10.0 // indirect go.uber.org/automaxprocs v1.6.0 // indirect @@ -178,8 +179,8 @@ require ( golang.org/x/mod v0.32.0 // indirect golang.org/x/sys v0.41.0 // indirect golang.org/x/text v0.34.0 // indirect - golang.org/x/tools v0.41.0 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 // indirect + google.golang.org/grpc v1.75.1 // indirect + google.golang.org/protobuf v1.36.10 // indirect gopkg.in/ini.v1 v1.67.0 // indirect modernc.org/libc v1.67.6 // indirect modernc.org/mathutil v1.7.1 // indirect diff --git a/backend/go.sum b/backend/go.sum index 
32e389a7..965f7442 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -124,6 +124,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= +github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= +github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM= github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94= @@ -171,8 +173,6 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= -github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= -github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= @@ -182,7 +182,6 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e 
h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= -github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -203,6 +202,8 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4= github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y= github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI= github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -285,6 +286,10 @@ github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6 github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/QTRo= +github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= +github.com/pkoukk/tiktoken-go-loader v0.0.2 h1:LUKws63GV3pVHwH1srkBplBv+7URgmOmhSkRxsIvsK4= 
+github.com/pkoukk/tiktoken-go-loader v0.0.2/go.mod h1:4mIkYyZooFlnenDlormIo6cd5wrlUKNr97wp9nGgEKo= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -337,6 +342,8 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -345,8 +352,6 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= -github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/tam7t/hpkp 
v0.0.0-20160821193359-2b70b4024ed5 h1:YqAladjX7xpA6BM04leXMWAEjS0mTZ5kUU9KRBriQJc= @@ -398,8 +403,6 @@ go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/Wgbsd go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI= go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg= -go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc= -go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps= go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0= @@ -438,11 +441,11 @@ golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod 
h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= @@ -455,8 +458,6 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= -gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ= google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU= google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4/go.mod h1:NnuHhy+bxcg30o7FnVAZbXsPHUDQ9qKWAQKCD7VxFtk= diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index cc81ce54..de876098 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -516,7 +516,7 @@ func (c *UserMessageQueueConfig) GetEffectiveMode() string { type GatewayOpenAIWSConfig struct { // ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 false;关闭时保持 legacy 行为) ModeRouterV2Enabled bool `mapstructure:"mode_router_v2_enabled"` - // IngressModeDefault: ingress 默认模式(off/shared/dedicated) + // IngressModeDefault: ingress 默认模式(off/ctx_pool/passthrough) IngressModeDefault string `mapstructure:"ingress_mode_default"` // Enabled: 全局总开关(默认 true) Enabled bool `mapstructure:"enabled"` @@ -1227,7 +1227,7 @@ func setDefaults() { // Ops (vNext) viper.SetDefault("ops.enabled", true) - viper.SetDefault("ops.use_preaggregated_tables", false) + viper.SetDefault("ops.use_preaggregated_tables", 
true) viper.SetDefault("ops.cleanup.enabled", true) viper.SetDefault("ops.cleanup.schedule", "0 2 * * *") // Retention days: vNext defaults to 30 days across ops datasets. @@ -1335,7 +1335,7 @@ func setDefaults() { // OpenAI Responses WebSocket(默认开启;可通过 force_http 紧急回滚) viper.SetDefault("gateway.openai_ws.enabled", true) viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", false) - viper.SetDefault("gateway.openai_ws.ingress_mode_default", "shared") + viper.SetDefault("gateway.openai_ws.ingress_mode_default", "ctx_pool") viper.SetDefault("gateway.openai_ws.oauth_enabled", true) viper.SetDefault("gateway.openai_ws.apikey_enabled", true) viper.SetDefault("gateway.openai_ws.force_http", false) @@ -1402,7 +1402,7 @@ func setDefaults() { viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求) viper.SetDefault("gateway.stream_data_interval_timeout", 180) viper.SetDefault("gateway.stream_keepalive_interval", 10) - viper.SetDefault("gateway.max_line_size", 40*1024*1024) + viper.SetDefault("gateway.max_line_size", 500*1024*1024) viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3) viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second) viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second) @@ -2043,9 +2043,11 @@ func (c *Config) Validate() error { } if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.IngressModeDefault)); mode != "" { switch mode { - case "off", "shared", "dedicated": + case "off", "ctx_pool", "passthrough": + case "shared", "dedicated": + slog.Warn("gateway.openai_ws.ingress_mode_default is deprecated, treating as ctx_pool; please update to off|ctx_pool|passthrough", "value", mode) default: - return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|shared|dedicated") + return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|ctx_pool|passthrough") } } if mode := 
strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.StoreDisabledConnMode)); mode != "" { diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index e3b592e2..79fcc6d0 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -153,8 +153,8 @@ func TestLoadDefaultOpenAIWSConfig(t *testing.T) { if cfg.Gateway.OpenAIWS.ModeRouterV2Enabled { t.Fatalf("Gateway.OpenAIWS.ModeRouterV2Enabled = true, want false") } - if cfg.Gateway.OpenAIWS.IngressModeDefault != "shared" { - t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "shared") + if cfg.Gateway.OpenAIWS.IngressModeDefault != "ctx_pool" { + t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "ctx_pool") } } @@ -1373,7 +1373,7 @@ func TestValidateConfig_OpenAIWSRules(t *testing.T) { wantErr: "gateway.openai_ws.store_disabled_conn_mode", }, { - name: "ingress_mode_default 必须为 off|shared|dedicated", + name: "ingress_mode_default 必须为 off|ctx_pool|passthrough", mutate: func(c *Config) { c.Gateway.OpenAIWS.IngressModeDefault = "invalid" }, wantErr: "gateway.openai_ws.ingress_mode_default", }, diff --git a/backend/internal/domain/announcement.go b/backend/internal/domain/announcement.go index 7dc9a9cc..0e68fb0f 100644 --- a/backend/internal/domain/announcement.go +++ b/backend/internal/domain/announcement.go @@ -13,6 +13,11 @@ const ( AnnouncementStatusArchived = "archived" ) +const ( + AnnouncementNotifyModeSilent = "silent" + AnnouncementNotifyModePopup = "popup" +) + const ( AnnouncementConditionTypeSubscription = "subscription" AnnouncementConditionTypeBalance = "balance" @@ -195,17 +200,18 @@ func (c AnnouncementCondition) validate() error { } type Announcement struct { - ID int64 - Title string - Content string - Status string - Targeting AnnouncementTargeting - StartsAt *time.Time - EndsAt *time.Time - CreatedBy *int64 - UpdatedBy 
*int64 - CreatedAt time.Time - UpdatedAt time.Time + ID int64 + Title string + Content string + Status string + NotifyMode string + Targeting AnnouncementTargeting + StartsAt *time.Time + EndsAt *time.Time + CreatedBy *int64 + UpdatedBy *int64 + CreatedAt time.Time + UpdatedAt time.Time } func (a *Announcement) IsActiveAt(now time.Time) bool { diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go index d7bb50fc..8a6621a1 100644 --- a/backend/internal/domain/constants.go +++ b/backend/internal/domain/constants.go @@ -84,10 +84,12 @@ var DefaultAntigravityModelMapping = map[string]string{ "claude-haiku-4-5": "claude-sonnet-4-5", "claude-haiku-4-5-20251001": "claude-sonnet-4-5", // Gemini 2.5 白名单 - "gemini-2.5-flash": "gemini-2.5-flash", - "gemini-2.5-flash-lite": "gemini-2.5-flash-lite", - "gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking", - "gemini-2.5-pro": "gemini-2.5-pro", + "gemini-2.5-flash": "gemini-2.5-flash", + "gemini-2.5-flash-image": "gemini-2.5-flash-image", + "gemini-2.5-flash-image-preview": "gemini-2.5-flash-image", + "gemini-2.5-flash-lite": "gemini-2.5-flash-lite", + "gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking", + "gemini-2.5-pro": "gemini-2.5-pro", // Gemini 3 白名单 "gemini-3-flash": "gemini-3-flash", "gemini-3-pro-high": "gemini-3-pro-high", diff --git a/backend/internal/domain/constants_test.go b/backend/internal/domain/constants_test.go index 29605ac6..de66137f 100644 --- a/backend/internal/domain/constants_test.go +++ b/backend/internal/domain/constants_test.go @@ -6,6 +6,8 @@ func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T) t.Parallel() cases := map[string]string{ + "gemini-2.5-flash-image": "gemini-2.5-flash-image", + "gemini-2.5-flash-image-preview": "gemini-2.5-flash-image", "gemini-3.1-flash-image": "gemini-3.1-flash-image", "gemini-3.1-flash-image-preview": "gemini-3.1-flash-image", "gemini-3-pro-image": "gemini-3.1-flash-image", diff --git 
a/backend/internal/handler/admin/account_data.go b/backend/internal/handler/admin/account_data.go index 4ce17219..fbac73d3 100644 --- a/backend/internal/handler/admin/account_data.go +++ b/backend/internal/handler/admin/account_data.go @@ -8,6 +8,9 @@ import ( "strings" "time" + "log/slog" + + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -292,6 +295,8 @@ func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest) } } + enrichCredentialsFromIDToken(&item) + accountInput := &service.CreateAccountInput{ Name: item.Name, Notes: item.Notes, @@ -535,6 +540,57 @@ func defaultProxyName(name string) string { return name } +// enrichCredentialsFromIDToken performs best-effort extraction of user info fields +// (email, plan_type, chatgpt_account_id, etc.) from id_token in credentials. +// Only applies to OpenAI/Sora OAuth accounts. Skips expired token errors silently. +// Existing credential values are never overwritten — only missing fields are filled. 
+func enrichCredentialsFromIDToken(item *DataAccount) { + if item.Credentials == nil { + return + } + // Only enrich OpenAI/Sora OAuth accounts + platform := strings.ToLower(strings.TrimSpace(item.Platform)) + if platform != service.PlatformOpenAI && platform != service.PlatformSora { + return + } + if strings.ToLower(strings.TrimSpace(item.Type)) != service.AccountTypeOAuth { + return + } + + idToken, _ := item.Credentials["id_token"].(string) + if strings.TrimSpace(idToken) == "" { + return + } + + // DecodeIDToken skips expiry validation — safe for imported data + claims, err := openai.DecodeIDToken(idToken) + if err != nil { + slog.Debug("import_enrich_id_token_decode_failed", "account", item.Name, "error", err) + return + } + + userInfo := claims.GetUserInfo() + if userInfo == nil { + return + } + + // Fill missing fields only (never overwrite existing values) + setIfMissing := func(key, value string) { + if value == "" { + return + } + if existing, _ := item.Credentials[key].(string); existing == "" { + item.Credentials[key] = value + } + } + + setIfMissing("email", userInfo.Email) + setIfMissing("plan_type", userInfo.PlanType) + setIfMissing("chatgpt_account_id", userInfo.ChatGPTAccountID) + setIfMissing("chatgpt_user_id", userInfo.ChatGPTUserID) + setIfMissing("organization_id", userInfo.OrganizationID) +} + func normalizeProxyStatus(status string) string { normalized := strings.TrimSpace(strings.ToLower(status)) switch normalized { diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 98ead284..1f0d0b52 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -8,6 +8,7 @@ import ( "encoding/json" "errors" "fmt" + "log" "net/http" "strconv" "strings" @@ -18,6 +19,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" 
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/response" @@ -102,6 +104,7 @@ type CreateAccountRequest struct { Concurrency int `json:"concurrency"` Priority int `json:"priority"` RateMultiplier *float64 `json:"rate_multiplier"` + LoadFactor *int `json:"load_factor"` GroupIDs []int64 `json:"group_ids"` ExpiresAt *int64 `json:"expires_at"` AutoPauseOnExpired *bool `json:"auto_pause_on_expired"` @@ -120,7 +123,8 @@ type UpdateAccountRequest struct { Concurrency *int `json:"concurrency"` Priority *int `json:"priority"` RateMultiplier *float64 `json:"rate_multiplier"` - Status string `json:"status" binding:"omitempty,oneof=active inactive"` + LoadFactor *int `json:"load_factor"` + Status string `json:"status" binding:"omitempty,oneof=active inactive error"` GroupIDs *[]int64 `json:"group_ids"` ExpiresAt *int64 `json:"expires_at"` AutoPauseOnExpired *bool `json:"auto_pause_on_expired"` @@ -135,6 +139,7 @@ type BulkUpdateAccountsRequest struct { Concurrency *int `json:"concurrency"` Priority *int `json:"priority"` RateMultiplier *float64 `json:"rate_multiplier"` + LoadFactor *int `json:"load_factor"` Status string `json:"status" binding:"omitempty,oneof=active inactive error"` Schedulable *bool `json:"schedulable"` GroupIDs *[]int64 `json:"group_ids"` @@ -217,6 +222,7 @@ func (h *AccountHandler) List(c *gin.Context) { if len(search) > 100 { search = search[:100] } + lite := parseBoolQueryWithDefault(c.Query("lite"), false) var groupID int64 if groupIDStr := c.Query("group"); groupIDStr != "" { @@ -235,10 +241,16 @@ func (h *AccountHandler) List(c *gin.Context) { accountIDs[i] = acc.ID } - concurrencyCounts, err := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs) - if err != nil { - // Log error but don't fail the request, just use 
0 for all - concurrencyCounts = make(map[int64]int) + concurrencyCounts := make(map[int64]int) + var windowCosts map[int64]float64 + var activeSessions map[int64]int + var rpmCounts map[int64]int + + // 始终获取并发数(Redis ZCARD,极低开销) + if h.concurrencyService != nil { + if cc, ccErr := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs); ccErr == nil && cc != nil { + concurrencyCounts = cc + } } // 识别需要查询窗口费用、会话数和 RPM 的账号(Anthropic OAuth/SetupToken 且启用了相应功能) @@ -262,12 +274,7 @@ func (h *AccountHandler) List(c *gin.Context) { } } - // 并行获取窗口费用、活跃会话数和 RPM 计数 - var windowCosts map[int64]float64 - var activeSessions map[int64]int - var rpmCounts map[int64]int - - // 获取 RPM 计数(批量查询) + // 始终获取 RPM 计数(Redis GET,极低开销) if len(rpmAccountIDs) > 0 && h.rpmCache != nil { rpmCounts, _ = h.rpmCache.GetRPMBatch(c.Request.Context(), rpmAccountIDs) if rpmCounts == nil { @@ -275,7 +282,7 @@ func (h *AccountHandler) List(c *gin.Context) { } } - // 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置) + // 始终获取活跃会话数(Redis ZCARD,低开销) if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil { activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts) if activeSessions == nil { @@ -283,7 +290,7 @@ func (h *AccountHandler) List(c *gin.Context) { } } - // 获取窗口费用(并行查询) + // 始终获取窗口费用(PostgreSQL 聚合查询) if len(windowCostAccountIDs) > 0 { windowCosts = make(map[int64]float64) var mu sync.Mutex @@ -344,7 +351,7 @@ func (h *AccountHandler) List(c *gin.Context) { result[i] = item } - etag := buildAccountsListETag(result, total, page, pageSize, platform, accountType, status, search) + etag := buildAccountsListETag(result, total, page, pageSize, platform, accountType, status, search, lite) if etag != "" { c.Header("ETag", etag) c.Header("Vary", "If-None-Match") @@ -362,6 +369,7 @@ func buildAccountsListETag( total int64, page, pageSize int, platform, accountType, status, search string, + lite bool, ) string { 
payload := struct { Total int64 `json:"total"` @@ -371,6 +379,7 @@ func buildAccountsListETag( AccountType string `json:"type"` Status string `json:"status"` Search string `json:"search"` + Lite bool `json:"lite"` Items []AccountWithConcurrency `json:"items"` }{ Total: total, @@ -380,6 +389,7 @@ func buildAccountsListETag( AccountType: accountType, Status: status, Search: search, + Lite: lite, Items: items, } raw, err := json.Marshal(payload) @@ -501,6 +511,7 @@ func (h *AccountHandler) Create(c *gin.Context) { Concurrency: req.Concurrency, Priority: req.Priority, RateMultiplier: req.RateMultiplier, + LoadFactor: req.LoadFactor, GroupIDs: req.GroupIDs, ExpiresAt: req.ExpiresAt, AutoPauseOnExpired: req.AutoPauseOnExpired, @@ -570,6 +581,7 @@ func (h *AccountHandler) Update(c *gin.Context) { Concurrency: req.Concurrency, // 指针类型,nil 表示未提供 Priority: req.Priority, // 指针类型,nil 表示未提供 RateMultiplier: req.RateMultiplier, + LoadFactor: req.LoadFactor, Status: req.Status, GroupIDs: req.GroupIDs, ExpiresAt: req.ExpiresAt, @@ -616,6 +628,7 @@ func (h *AccountHandler) Delete(c *gin.Context) { // TestAccountRequest represents the request body for testing an account type TestAccountRequest struct { ModelID string `json:"model_id"` + Prompt string `json:"prompt"` } type SyncFromCRSRequest struct { @@ -646,10 +659,46 @@ func (h *AccountHandler) Test(c *gin.Context) { _ = c.ShouldBindJSON(&req) // Use AccountTestService to test the account with SSE streaming - if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID); err != nil { + if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID, req.Prompt); err != nil { // Error already sent via SSE, just log return } + + if h.rateLimitService != nil { + if _, err := h.rateLimitService.RecoverAccountAfterSuccessfulTest(c.Request.Context(), accountID); err != nil { + _ = c.Error(err) + } + } +} + +// RecoverState handles unified recovery of recoverable account runtime state. 
+// POST /api/v1/admin/accounts/:id/recover-state +func (h *AccountHandler) RecoverState(c *gin.Context) { + accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid account ID") + return + } + + if h.rateLimitService == nil { + response.Error(c, http.StatusServiceUnavailable, "Rate limit service unavailable") + return + } + + if _, err := h.rateLimitService.RecoverAccountState(c.Request.Context(), accountID, service.AccountRecoveryOptions{ + InvalidateToken: true, + }); err != nil { + response.ErrorFrom(c, err) + return + } + + account, err := h.adminService.GetAccount(c.Request.Context(), accountID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) } // SyncFromCRS handles syncing accounts from claude-relay-service (CRS) @@ -705,52 +754,31 @@ func (h *AccountHandler) PreviewFromCRS(c *gin.Context) { response.Success(c, result) } -// Refresh handles refreshing account credentials -// POST /api/v1/admin/accounts/:id/refresh -func (h *AccountHandler) Refresh(c *gin.Context) { - accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) - if err != nil { - response.BadRequest(c, "Invalid account ID") - return - } - - // Get account - account, err := h.adminService.GetAccount(c.Request.Context(), accountID) - if err != nil { - response.NotFound(c, "Account not found") - return - } - - // Only refresh OAuth-based accounts (oauth and setup-token) +// refreshSingleAccount refreshes credentials for a single OAuth account. +// Returns (updatedAccount, warning, error) where warning is used for Antigravity ProjectIDMissing scenario. 
+func (h *AccountHandler) refreshSingleAccount(ctx context.Context, account *service.Account) (*service.Account, string, error) { if !account.IsOAuth() { - response.BadRequest(c, "Cannot refresh non-OAuth account credentials") - return + return nil, "", infraerrors.BadRequest("NOT_OAUTH", "cannot refresh non-OAuth account") } var newCredentials map[string]any if account.IsOpenAI() { - // Use OpenAI OAuth service to refresh token - tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account) + tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(ctx, account) if err != nil { - response.ErrorFrom(c, err) - return + return nil, "", err } - // Build new credentials from token info newCredentials = h.openaiOAuthService.BuildAccountCredentials(tokenInfo) - - // Preserve non-token settings from existing credentials for k, v := range account.Credentials { if _, exists := newCredentials[k]; !exists { newCredentials[k] = v } } } else if account.Platform == service.PlatformGemini { - tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(c.Request.Context(), account) + tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(ctx, account) if err != nil { - response.InternalError(c, "Failed to refresh credentials: "+err.Error()) - return + return nil, "", fmt.Errorf("failed to refresh credentials: %w", err) } newCredentials = h.geminiOAuthService.BuildAccountCredentials(tokenInfo) @@ -760,10 +788,9 @@ func (h *AccountHandler) Refresh(c *gin.Context) { } } } else if account.Platform == service.PlatformAntigravity { - tokenInfo, err := h.antigravityOAuthService.RefreshAccountToken(c.Request.Context(), account) + tokenInfo, err := h.antigravityOAuthService.RefreshAccountToken(ctx, account) if err != nil { - response.ErrorFrom(c, err) - return + return nil, "", err } newCredentials = h.antigravityOAuthService.BuildAccountCredentials(tokenInfo) @@ -782,37 +809,27 @@ func (h *AccountHandler) Refresh(c *gin.Context) { } // 如果 project_id 
获取失败,更新凭证但不标记为 error - // LoadCodeAssist 失败可能是临时网络问题,给它机会在下次自动刷新时重试 if tokenInfo.ProjectIDMissing { - // 先更新凭证(token 本身刷新成功了) - _, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{ + updatedAccount, updateErr := h.adminService.UpdateAccount(ctx, account.ID, &service.UpdateAccountInput{ Credentials: newCredentials, }) if updateErr != nil { - response.InternalError(c, "Failed to update credentials: "+updateErr.Error()) - return + return nil, "", fmt.Errorf("failed to update credentials: %w", updateErr) } - // 不标记为 error,只返回警告信息 - response.Success(c, gin.H{ - "message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)", - "warning": "missing_project_id_temporary", - }) - return + return updatedAccount, "missing_project_id_temporary", nil } // 成功获取到 project_id,如果之前是 missing_project_id 错误则清除 if account.Status == service.StatusError && strings.Contains(account.ErrorMessage, "missing_project_id:") { - if _, clearErr := h.adminService.ClearAccountError(c.Request.Context(), accountID); clearErr != nil { - response.InternalError(c, "Failed to clear account error: "+clearErr.Error()) - return + if _, clearErr := h.adminService.ClearAccountError(ctx, account.ID); clearErr != nil { + return nil, "", fmt.Errorf("failed to clear account error: %w", clearErr) } } } else { // Use Anthropic/Claude OAuth service to refresh token - tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account) + tokenInfo, err := h.oauthService.RefreshAccountToken(ctx, account) if err != nil { - response.ErrorFrom(c, err) - return + return nil, "", err } // Copy existing credentials to preserve non-token settings (e.g., intercept_warmup_requests) @@ -834,20 +851,51 @@ func (h *AccountHandler) Refresh(c *gin.Context) { } } - updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{ + updatedAccount, err := 
h.adminService.UpdateAccount(ctx, account.ID, &service.UpdateAccountInput{ Credentials: newCredentials, }) + if err != nil { + return nil, "", err + } + + // 刷新成功后,清除 token 缓存,确保下次请求使用新 token + if h.tokenCacheInvalidator != nil { + if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(ctx, updatedAccount); invalidateErr != nil { + log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", updatedAccount.ID, invalidateErr) + } + } + + return updatedAccount, "", nil +} + +// Refresh handles refreshing account credentials +// POST /api/v1/admin/accounts/:id/refresh +func (h *AccountHandler) Refresh(c *gin.Context) { + accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid account ID") + return + } + + // Get account + account, err := h.adminService.GetAccount(c.Request.Context(), accountID) + if err != nil { + response.NotFound(c, "Account not found") + return + } + + updatedAccount, warning, err := h.refreshSingleAccount(c.Request.Context(), account) if err != nil { response.ErrorFrom(c, err) return } - // 刷新成功后,清除 token 缓存,确保下次请求使用新 token - if h.tokenCacheInvalidator != nil { - if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), updatedAccount); invalidateErr != nil { - // 缓存失效失败只记录日志,不影响主流程 - _ = c.Error(invalidateErr) - } + if warning == "missing_project_id_temporary" { + response.Success(c, gin.H{ + "message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)", + "warning": "missing_project_id_temporary", + }) + return } response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), updatedAccount)) @@ -903,14 +951,175 @@ func (h *AccountHandler) ClearError(c *gin.Context) { // 这解决了管理员重置账号状态后,旧的失效 token 仍在缓存中导致立即再次 401 的问题 if h.tokenCacheInvalidator != nil && account.IsOAuth() { if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), account); invalidateErr != nil { - // 
缓存失效失败只记录日志,不影响主流程 - _ = c.Error(invalidateErr) + log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", accountID, invalidateErr) } } response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) } +// BatchClearError handles batch clearing account errors +// POST /api/v1/admin/accounts/batch-clear-error +func (h *AccountHandler) BatchClearError(c *gin.Context) { + var req struct { + AccountIDs []int64 `json:"account_ids"` + } + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + if len(req.AccountIDs) == 0 { + response.BadRequest(c, "account_ids is required") + return + } + + ctx := c.Request.Context() + + const maxConcurrency = 10 + g, gctx := errgroup.WithContext(ctx) + g.SetLimit(maxConcurrency) + + var mu sync.Mutex + var successCount, failedCount int + var errors []gin.H + + // 注意:所有 goroutine 必须 return nil,避免 errgroup cancel 其他并发任务 + for _, id := range req.AccountIDs { + accountID := id // 闭包捕获 + g.Go(func() error { + account, err := h.adminService.ClearAccountError(gctx, accountID) + if err != nil { + mu.Lock() + failedCount++ + errors = append(errors, gin.H{ + "account_id": accountID, + "error": err.Error(), + }) + mu.Unlock() + return nil + } + + // 清除错误后,同时清除 token 缓存 + if h.tokenCacheInvalidator != nil && account.IsOAuth() { + if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(gctx, account); invalidateErr != nil { + log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", accountID, invalidateErr) + } + } + + mu.Lock() + successCount++ + mu.Unlock() + return nil + }) + } + + if err := g.Wait(); err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{ + "total": len(req.AccountIDs), + "success": successCount, + "failed": failedCount, + "errors": errors, + }) +} + +// BatchRefresh handles batch refreshing account credentials +// POST /api/v1/admin/accounts/batch-refresh +func (h 
*AccountHandler) BatchRefresh(c *gin.Context) { + var req struct { + AccountIDs []int64 `json:"account_ids"` + } + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + if len(req.AccountIDs) == 0 { + response.BadRequest(c, "account_ids is required") + return + } + + ctx := c.Request.Context() + + accounts, err := h.adminService.GetAccountsByIDs(ctx, req.AccountIDs) + if err != nil { + response.ErrorFrom(c, err) + return + } + + // 建立已获取账号的 ID 集合,检测缺失的 ID + foundIDs := make(map[int64]bool, len(accounts)) + for _, acc := range accounts { + if acc != nil { + foundIDs[acc.ID] = true + } + } + + const maxConcurrency = 10 + g, gctx := errgroup.WithContext(ctx) + g.SetLimit(maxConcurrency) + + var mu sync.Mutex + var successCount, failedCount int + var errors []gin.H + var warnings []gin.H + + // 将不存在的账号 ID 标记为失败 + for _, id := range req.AccountIDs { + if !foundIDs[id] { + failedCount++ + errors = append(errors, gin.H{ + "account_id": id, + "error": "account not found", + }) + } + } + + // 注意:所有 goroutine 必须 return nil,避免 errgroup cancel 其他并发任务 + for _, account := range accounts { + acc := account // 闭包捕获 + if acc == nil { + continue + } + g.Go(func() error { + _, warning, err := h.refreshSingleAccount(gctx, acc) + mu.Lock() + if err != nil { + failedCount++ + errors = append(errors, gin.H{ + "account_id": acc.ID, + "error": err.Error(), + }) + } else { + successCount++ + if warning != "" { + warnings = append(warnings, gin.H{ + "account_id": acc.ID, + "warning": warning, + }) + } + } + mu.Unlock() + return nil + }) + } + + if err := g.Wait(); err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{ + "total": len(req.AccountIDs), + "success": successCount, + "failed": failedCount, + "errors": errors, + "warnings": warnings, + }) +} + // BatchCreate handles batch creating accounts // POST /api/v1/admin/accounts/batch func (h *AccountHandler) BatchCreate(c *gin.Context) { @@ 
-1096,6 +1305,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) { req.Concurrency != nil || req.Priority != nil || req.RateMultiplier != nil || + req.LoadFactor != nil || req.Status != "" || req.Schedulable != nil || req.GroupIDs != nil || @@ -1114,6 +1324,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) { Concurrency: req.Concurrency, Priority: req.Priority, RateMultiplier: req.RateMultiplier, + LoadFactor: req.LoadFactor, Status: req.Status, Schedulable: req.Schedulable, GroupIDs: req.GroupIDs, @@ -1127,6 +1338,12 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) { c.JSON(409, gin.H{ "error": "mixed_channel_warning", "message": mixedErr.Error(), + "details": gin.H{ + "group_id": mixedErr.GroupID, + "group_name": mixedErr.GroupName, + "current_platform": mixedErr.CurrentPlatform, + "other_platform": mixedErr.OtherPlatform, + }, }) return } @@ -1323,6 +1540,29 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) { response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) } +// ResetQuota handles resetting account quota usage +// POST /api/v1/admin/accounts/:id/reset-quota +func (h *AccountHandler) ResetQuota(c *gin.Context) { + accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid account ID") + return + } + + if err := h.adminService.ResetAccountQuota(c.Request.Context(), accountID); err != nil { + response.InternalError(c, "Failed to reset account quota: "+err.Error()) + return + } + + account, err := h.adminService.GetAccount(c.Request.Context(), accountID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) +} + // GetTempUnschedulable handles getting temporary unschedulable status // GET /api/v1/admin/accounts/:id/temp-unschedulable func (h *AccountHandler) GetTempUnschedulable(c *gin.Context) { @@ -1398,18 +1638,41 @@ func (h *AccountHandler) 
GetBatchTodayStats(c *gin.Context) { return } - if len(req.AccountIDs) == 0 { + accountIDs := normalizeInt64IDList(req.AccountIDs) + if len(accountIDs) == 0 { response.Success(c, gin.H{"stats": map[string]any{}}) return } - stats, err := h.accountUsageService.GetTodayStatsBatch(c.Request.Context(), req.AccountIDs) + cacheKey := buildAccountTodayStatsBatchCacheKey(accountIDs) + if cached, ok := accountTodayStatsBatchCache.Get(cacheKey); ok { + if cached.ETag != "" { + c.Header("ETag", cached.ETag) + c.Header("Vary", "If-None-Match") + if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) { + c.Status(http.StatusNotModified) + return + } + } + c.Header("X-Snapshot-Cache", "hit") + response.Success(c, cached.Payload) + return + } + + stats, err := h.accountUsageService.GetTodayStatsBatch(c.Request.Context(), accountIDs) if err != nil { response.ErrorFrom(c, err) return } - response.Success(c, gin.H{"stats": stats}) + payload := gin.H{"stats": stats} + cached := accountTodayStatsBatchCache.Set(cacheKey, payload) + if cached.ETag != "" { + c.Header("ETag", cached.ETag) + c.Header("Vary", "If-None-Match") + } + c.Header("X-Snapshot-Cache", "miss") + response.Success(c, payload) } // SetSchedulableRequest represents the request body for setting schedulable status diff --git a/backend/internal/handler/admin/account_handler_mixed_channel_test.go b/backend/internal/handler/admin/account_handler_mixed_channel_test.go index 24ec5bcf..5b81db2a 100644 --- a/backend/internal/handler/admin/account_handler_mixed_channel_test.go +++ b/backend/internal/handler/admin/account_handler_mixed_channel_test.go @@ -111,7 +111,7 @@ func TestAccountHandlerCreateMixedChannelConflictSimplifiedResponse(t *testing.T var resp map[string]any require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) require.Equal(t, "mixed_channel_warning", resp["error"]) - require.Contains(t, resp["message"], "mixed_channel_warning") + require.Contains(t, resp["message"], "claude-max") _, hasDetails := 
resp["details"] _, hasRequireConfirmation := resp["require_confirmation"] require.False(t, hasDetails) @@ -140,7 +140,7 @@ func TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse(t *testing.T var resp map[string]any require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) require.Equal(t, "mixed_channel_warning", resp["error"]) - require.Contains(t, resp["message"], "mixed_channel_warning") + require.Contains(t, resp["message"], "claude-max") _, hasDetails := resp["details"] _, hasRequireConfirmation := resp["require_confirmation"] require.False(t, hasDetails) diff --git a/backend/internal/handler/admin/account_today_stats_cache.go b/backend/internal/handler/admin/account_today_stats_cache.go new file mode 100644 index 00000000..61922f70 --- /dev/null +++ b/backend/internal/handler/admin/account_today_stats_cache.go @@ -0,0 +1,25 @@ +package admin + +import ( + "strconv" + "strings" + "time" +) + +var accountTodayStatsBatchCache = newSnapshotCache(30 * time.Second) + +func buildAccountTodayStatsBatchCacheKey(accountIDs []int64) string { + if len(accountIDs) == 0 { + return "accounts_today_stats_empty" + } + var b strings.Builder + b.Grow(len(accountIDs) * 6) + _, _ = b.WriteString("accounts_today_stats:") + for i, id := range accountIDs { + if i > 0 { + _ = b.WriteByte(',') + } + _, _ = b.WriteString(strconv.FormatInt(id, 10)) + } + return b.String() +} diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index f3b99ddb..b77a2b7f 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -175,6 +175,10 @@ func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, p return s.apiKeys, int64(len(s.apiKeys)), nil } +func (s *stubAdminService) GetGroupRateMultipliers(_ context.Context, _ int64) ([]service.UserGroupRateEntry, error) { + return nil, nil +} + func (s 
*stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) { return s.accounts, int64(len(s.accounts)), nil } @@ -425,5 +429,9 @@ func (s *stubAdminService) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i return nil, service.ErrAPIKeyNotFound } +func (s *stubAdminService) ResetAccountQuota(ctx context.Context, id int64) error { + return nil +} + // Ensure stub implements interface. var _ service.AdminService = (*stubAdminService)(nil) diff --git a/backend/internal/handler/admin/announcement_handler.go b/backend/internal/handler/admin/announcement_handler.go index 0b5d0fbc..d1312bc0 100644 --- a/backend/internal/handler/admin/announcement_handler.go +++ b/backend/internal/handler/admin/announcement_handler.go @@ -27,21 +27,23 @@ func NewAnnouncementHandler(announcementService *service.AnnouncementService) *A } type CreateAnnouncementRequest struct { - Title string `json:"title" binding:"required"` - Content string `json:"content" binding:"required"` - Status string `json:"status" binding:"omitempty,oneof=draft active archived"` - Targeting service.AnnouncementTargeting `json:"targeting"` - StartsAt *int64 `json:"starts_at"` // Unix seconds, 0/empty = immediate - EndsAt *int64 `json:"ends_at"` // Unix seconds, 0/empty = never + Title string `json:"title" binding:"required"` + Content string `json:"content" binding:"required"` + Status string `json:"status" binding:"omitempty,oneof=draft active archived"` + NotifyMode string `json:"notify_mode" binding:"omitempty,oneof=silent popup"` + Targeting service.AnnouncementTargeting `json:"targeting"` + StartsAt *int64 `json:"starts_at"` // Unix seconds, 0/empty = immediate + EndsAt *int64 `json:"ends_at"` // Unix seconds, 0/empty = never } type UpdateAnnouncementRequest struct { - Title *string `json:"title"` - Content *string `json:"content"` - Status *string `json:"status" binding:"omitempty,oneof=draft active 
archived"` - Targeting *service.AnnouncementTargeting `json:"targeting"` - StartsAt *int64 `json:"starts_at"` // Unix seconds, 0 = clear - EndsAt *int64 `json:"ends_at"` // Unix seconds, 0 = clear + Title *string `json:"title"` + Content *string `json:"content"` + Status *string `json:"status" binding:"omitempty,oneof=draft active archived"` + NotifyMode *string `json:"notify_mode" binding:"omitempty,oneof=silent popup"` + Targeting *service.AnnouncementTargeting `json:"targeting"` + StartsAt *int64 `json:"starts_at"` // Unix seconds, 0 = clear + EndsAt *int64 `json:"ends_at"` // Unix seconds, 0 = clear } // List handles listing announcements with filters @@ -110,11 +112,12 @@ func (h *AnnouncementHandler) Create(c *gin.Context) { } input := &service.CreateAnnouncementInput{ - Title: req.Title, - Content: req.Content, - Status: req.Status, - Targeting: req.Targeting, - ActorID: &subject.UserID, + Title: req.Title, + Content: req.Content, + Status: req.Status, + NotifyMode: req.NotifyMode, + Targeting: req.Targeting, + ActorID: &subject.UserID, } if req.StartsAt != nil && *req.StartsAt > 0 { @@ -157,11 +160,12 @@ func (h *AnnouncementHandler) Update(c *gin.Context) { } input := &service.UpdateAnnouncementInput{ - Title: req.Title, - Content: req.Content, - Status: req.Status, - Targeting: req.Targeting, - ActorID: &subject.UserID, + Title: req.Title, + Content: req.Content, + Status: req.Status, + NotifyMode: req.NotifyMode, + Targeting: req.Targeting, + ActorID: &subject.UserID, } if req.StartsAt != nil { diff --git a/backend/internal/handler/admin/dashboard_handler.go b/backend/internal/handler/admin/dashboard_handler.go index 1d48c653..aa82b24f 100644 --- a/backend/internal/handler/admin/dashboard_handler.go +++ b/backend/internal/handler/admin/dashboard_handler.go @@ -1,6 +1,7 @@ package admin import ( + "encoding/json" "errors" "strconv" "strings" @@ -248,11 +249,12 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) { } } - trend, err := 
h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType) + trend, hit, err := h.getUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType) if err != nil { response.Error(c, 500, "Failed to get usage trend") return } + c.Header("X-Snapshot-Cache", cacheStatusValue(hit)) response.Success(c, gin.H{ "trend": trend, @@ -320,11 +322,12 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) { } } - stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) + stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) if err != nil { response.Error(c, 500, "Failed to get model statistics") return } + c.Header("X-Snapshot-Cache", cacheStatusValue(hit)) response.Success(c, gin.H{ "models": stats, @@ -390,11 +393,12 @@ func (h *DashboardHandler) GetGroupStats(c *gin.Context) { } } - stats, err := h.dashboardService.GetGroupStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) + stats, hit, err := h.getGroupStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) if err != nil { response.Error(c, 500, "Failed to get group statistics") return } + c.Header("X-Snapshot-Cache", cacheStatusValue(hit)) response.Success(c, gin.H{ "groups": stats, @@ -415,11 +419,12 @@ func (h *DashboardHandler) GetAPIKeyUsageTrend(c *gin.Context) { limit = 5 } - trend, err := h.dashboardService.GetAPIKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit) + trend, hit, err := h.getAPIKeyUsageTrendCached(c.Request.Context(), 
startTime, endTime, granularity, limit) if err != nil { response.Error(c, 500, "Failed to get API key usage trend") return } + c.Header("X-Snapshot-Cache", cacheStatusValue(hit)) response.Success(c, gin.H{ "trend": trend, @@ -441,11 +446,12 @@ func (h *DashboardHandler) GetUserUsageTrend(c *gin.Context) { limit = 12 } - trend, err := h.dashboardService.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit) + trend, hit, err := h.getUserUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, limit) if err != nil { response.Error(c, 500, "Failed to get user usage trend") return } + c.Header("X-Snapshot-Cache", cacheStatusValue(hit)) response.Success(c, gin.H{ "trend": trend, @@ -460,6 +466,9 @@ type BatchUsersUsageRequest struct { UserIDs []int64 `json:"user_ids" binding:"required"` } +var dashboardBatchUsersUsageCache = newSnapshotCache(30 * time.Second) +var dashboardBatchAPIKeysUsageCache = newSnapshotCache(30 * time.Second) + // GetBatchUsersUsage handles getting usage stats for multiple users // POST /api/v1/admin/dashboard/users-usage func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) { @@ -469,18 +478,34 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) { return } - if len(req.UserIDs) == 0 { + userIDs := normalizeInt64IDList(req.UserIDs) + if len(userIDs) == 0 { response.Success(c, gin.H{"stats": map[string]any{}}) return } - stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs, time.Time{}, time.Time{}) + keyRaw, _ := json.Marshal(struct { + UserIDs []int64 `json:"user_ids"` + }{ + UserIDs: userIDs, + }) + cacheKey := string(keyRaw) + if cached, ok := dashboardBatchUsersUsageCache.Get(cacheKey); ok { + c.Header("X-Snapshot-Cache", "hit") + response.Success(c, cached.Payload) + return + } + + stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), userIDs, time.Time{}, time.Time{}) if err != nil { response.Error(c, 500, "Failed to get user 
usage stats") return } - response.Success(c, gin.H{"stats": stats}) + payload := gin.H{"stats": stats} + dashboardBatchUsersUsageCache.Set(cacheKey, payload) + c.Header("X-Snapshot-Cache", "miss") + response.Success(c, payload) } // BatchAPIKeysUsageRequest represents the request body for batch api key usage stats @@ -497,16 +522,32 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) { return } - if len(req.APIKeyIDs) == 0 { + apiKeyIDs := normalizeInt64IDList(req.APIKeyIDs) + if len(apiKeyIDs) == 0 { response.Success(c, gin.H{"stats": map[string]any{}}) return } - stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs, time.Time{}, time.Time{}) + keyRaw, _ := json.Marshal(struct { + APIKeyIDs []int64 `json:"api_key_ids"` + }{ + APIKeyIDs: apiKeyIDs, + }) + cacheKey := string(keyRaw) + if cached, ok := dashboardBatchAPIKeysUsageCache.Get(cacheKey); ok { + c.Header("X-Snapshot-Cache", "hit") + response.Success(c, cached.Payload) + return + } + + stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), apiKeyIDs, time.Time{}, time.Time{}) if err != nil { response.Error(c, 500, "Failed to get API key usage stats") return } - response.Success(c, gin.H{"stats": stats}) + payload := gin.H{"stats": stats} + dashboardBatchAPIKeysUsageCache.Set(cacheKey, payload) + c.Header("X-Snapshot-Cache", "miss") + response.Success(c, payload) } diff --git a/backend/internal/handler/admin/dashboard_handler_cache_test.go b/backend/internal/handler/admin/dashboard_handler_cache_test.go new file mode 100644 index 00000000..ec888849 --- /dev/null +++ b/backend/internal/handler/admin/dashboard_handler_cache_test.go @@ -0,0 +1,118 @@ +package admin + +import ( + "context" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) 
+ +type dashboardUsageRepoCacheProbe struct { + service.UsageLogRepository + trendCalls atomic.Int32 + usersTrendCalls atomic.Int32 +} + +func (r *dashboardUsageRepoCacheProbe) GetUsageTrendWithFilters( + ctx context.Context, + startTime, endTime time.Time, + granularity string, + userID, apiKeyID, accountID, groupID int64, + model string, + requestType *int16, + stream *bool, + billingType *int8, +) ([]usagestats.TrendDataPoint, error) { + r.trendCalls.Add(1) + return []usagestats.TrendDataPoint{{ + Date: "2026-03-11", + Requests: 1, + TotalTokens: 2, + Cost: 3, + ActualCost: 4, + }}, nil +} + +func (r *dashboardUsageRepoCacheProbe) GetUserUsageTrend( + ctx context.Context, + startTime, endTime time.Time, + granularity string, + limit int, +) ([]usagestats.UserUsageTrendPoint, error) { + r.usersTrendCalls.Add(1) + return []usagestats.UserUsageTrendPoint{{ + Date: "2026-03-11", + UserID: 1, + Email: "cache@test.dev", + Requests: 2, + Tokens: 20, + Cost: 2, + ActualCost: 1, + }}, nil +} + +func resetDashboardReadCachesForTest() { + dashboardTrendCache = newSnapshotCache(30 * time.Second) + dashboardUsersTrendCache = newSnapshotCache(30 * time.Second) + dashboardAPIKeysTrendCache = newSnapshotCache(30 * time.Second) + dashboardModelStatsCache = newSnapshotCache(30 * time.Second) + dashboardGroupStatsCache = newSnapshotCache(30 * time.Second) + dashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second) +} + +func TestDashboardHandler_GetUsageTrend_UsesCache(t *testing.T) { + t.Cleanup(resetDashboardReadCachesForTest) + resetDashboardReadCachesForTest() + + gin.SetMode(gin.TestMode) + repo := &dashboardUsageRepoCacheProbe{} + dashboardSvc := service.NewDashboardService(repo, nil, nil, nil) + handler := NewDashboardHandler(dashboardSvc, nil) + router := gin.New() + router.GET("/admin/dashboard/trend", handler.GetUsageTrend) + + req1 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day", nil) + rec1 
:= httptest.NewRecorder() + router.ServeHTTP(rec1, req1) + require.Equal(t, http.StatusOK, rec1.Code) + require.Equal(t, "miss", rec1.Header().Get("X-Snapshot-Cache")) + + req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day", nil) + rec2 := httptest.NewRecorder() + router.ServeHTTP(rec2, req2) + require.Equal(t, http.StatusOK, rec2.Code) + require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache")) + require.Equal(t, int32(1), repo.trendCalls.Load()) +} + +func TestDashboardHandler_GetUserUsageTrend_UsesCache(t *testing.T) { + t.Cleanup(resetDashboardReadCachesForTest) + resetDashboardReadCachesForTest() + + gin.SetMode(gin.TestMode) + repo := &dashboardUsageRepoCacheProbe{} + dashboardSvc := service.NewDashboardService(repo, nil, nil, nil) + handler := NewDashboardHandler(dashboardSvc, nil) + router := gin.New() + router.GET("/admin/dashboard/users-trend", handler.GetUserUsageTrend) + + req1 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day&limit=8", nil) + rec1 := httptest.NewRecorder() + router.ServeHTTP(rec1, req1) + require.Equal(t, http.StatusOK, rec1.Code) + require.Equal(t, "miss", rec1.Header().Get("X-Snapshot-Cache")) + + req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day&limit=8", nil) + rec2 := httptest.NewRecorder() + router.ServeHTTP(rec2, req2) + require.Equal(t, http.StatusOK, rec2.Code) + require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache")) + require.Equal(t, int32(1), repo.usersTrendCalls.Load()) +} diff --git a/backend/internal/handler/admin/dashboard_query_cache.go b/backend/internal/handler/admin/dashboard_query_cache.go new file mode 100644 index 00000000..47af5117 --- /dev/null +++ b/backend/internal/handler/admin/dashboard_query_cache.go @@ -0,0 +1,200 @@ +package admin + +import ( + "context" + 
"encoding/json" + "fmt" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" +) + +var ( + dashboardTrendCache = newSnapshotCache(30 * time.Second) + dashboardModelStatsCache = newSnapshotCache(30 * time.Second) + dashboardGroupStatsCache = newSnapshotCache(30 * time.Second) + dashboardUsersTrendCache = newSnapshotCache(30 * time.Second) + dashboardAPIKeysTrendCache = newSnapshotCache(30 * time.Second) +) + +type dashboardTrendCacheKey struct { + StartTime string `json:"start_time"` + EndTime string `json:"end_time"` + Granularity string `json:"granularity"` + UserID int64 `json:"user_id"` + APIKeyID int64 `json:"api_key_id"` + AccountID int64 `json:"account_id"` + GroupID int64 `json:"group_id"` + Model string `json:"model"` + RequestType *int16 `json:"request_type"` + Stream *bool `json:"stream"` + BillingType *int8 `json:"billing_type"` +} + +type dashboardModelGroupCacheKey struct { + StartTime string `json:"start_time"` + EndTime string `json:"end_time"` + UserID int64 `json:"user_id"` + APIKeyID int64 `json:"api_key_id"` + AccountID int64 `json:"account_id"` + GroupID int64 `json:"group_id"` + RequestType *int16 `json:"request_type"` + Stream *bool `json:"stream"` + BillingType *int8 `json:"billing_type"` +} + +type dashboardEntityTrendCacheKey struct { + StartTime string `json:"start_time"` + EndTime string `json:"end_time"` + Granularity string `json:"granularity"` + Limit int `json:"limit"` +} + +func cacheStatusValue(hit bool) string { + if hit { + return "hit" + } + return "miss" +} + +func mustMarshalDashboardCacheKey(value any) string { + raw, err := json.Marshal(value) + if err != nil { + return "" + } + return string(raw) +} + +func snapshotPayloadAs[T any](payload any) (T, error) { + typed, ok := payload.(T) + if !ok { + var zero T + return zero, fmt.Errorf("unexpected cache payload type %T", payload) + } + return typed, nil +} + +func (h *DashboardHandler) getUsageTrendCached( + ctx context.Context, + startTime, endTime time.Time, + 
granularity string, + userID, apiKeyID, accountID, groupID int64, + model string, + requestType *int16, + stream *bool, + billingType *int8, +) ([]usagestats.TrendDataPoint, bool, error) { + key := mustMarshalDashboardCacheKey(dashboardTrendCacheKey{ + StartTime: startTime.UTC().Format(time.RFC3339), + EndTime: endTime.UTC().Format(time.RFC3339), + Granularity: granularity, + UserID: userID, + APIKeyID: apiKeyID, + AccountID: accountID, + GroupID: groupID, + Model: model, + RequestType: requestType, + Stream: stream, + BillingType: billingType, + }) + entry, hit, err := dashboardTrendCache.GetOrLoad(key, func() (any, error) { + return h.dashboardService.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType) + }) + if err != nil { + return nil, hit, err + } + trend, err := snapshotPayloadAs[[]usagestats.TrendDataPoint](entry.Payload) + return trend, hit, err +} + +func (h *DashboardHandler) getModelStatsCached( + ctx context.Context, + startTime, endTime time.Time, + userID, apiKeyID, accountID, groupID int64, + requestType *int16, + stream *bool, + billingType *int8, +) ([]usagestats.ModelStat, bool, error) { + key := mustMarshalDashboardCacheKey(dashboardModelGroupCacheKey{ + StartTime: startTime.UTC().Format(time.RFC3339), + EndTime: endTime.UTC().Format(time.RFC3339), + UserID: userID, + APIKeyID: apiKeyID, + AccountID: accountID, + GroupID: groupID, + RequestType: requestType, + Stream: stream, + BillingType: billingType, + }) + entry, hit, err := dashboardModelStatsCache.GetOrLoad(key, func() (any, error) { + return h.dashboardService.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) + }) + if err != nil { + return nil, hit, err + } + stats, err := snapshotPayloadAs[[]usagestats.ModelStat](entry.Payload) + return stats, hit, err +} + +func (h *DashboardHandler) getGroupStatsCached( + ctx context.Context, 
+ startTime, endTime time.Time, + userID, apiKeyID, accountID, groupID int64, + requestType *int16, + stream *bool, + billingType *int8, +) ([]usagestats.GroupStat, bool, error) { + key := mustMarshalDashboardCacheKey(dashboardModelGroupCacheKey{ + StartTime: startTime.UTC().Format(time.RFC3339), + EndTime: endTime.UTC().Format(time.RFC3339), + UserID: userID, + APIKeyID: apiKeyID, + AccountID: accountID, + GroupID: groupID, + RequestType: requestType, + Stream: stream, + BillingType: billingType, + }) + entry, hit, err := dashboardGroupStatsCache.GetOrLoad(key, func() (any, error) { + return h.dashboardService.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) + }) + if err != nil { + return nil, hit, err + } + stats, err := snapshotPayloadAs[[]usagestats.GroupStat](entry.Payload) + return stats, hit, err +} + +func (h *DashboardHandler) getAPIKeyUsageTrendCached(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, bool, error) { + key := mustMarshalDashboardCacheKey(dashboardEntityTrendCacheKey{ + StartTime: startTime.UTC().Format(time.RFC3339), + EndTime: endTime.UTC().Format(time.RFC3339), + Granularity: granularity, + Limit: limit, + }) + entry, hit, err := dashboardAPIKeysTrendCache.GetOrLoad(key, func() (any, error) { + return h.dashboardService.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit) + }) + if err != nil { + return nil, hit, err + } + trend, err := snapshotPayloadAs[[]usagestats.APIKeyUsageTrendPoint](entry.Payload) + return trend, hit, err +} + +func (h *DashboardHandler) getUserUsageTrendCached(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, bool, error) { + key := mustMarshalDashboardCacheKey(dashboardEntityTrendCacheKey{ + StartTime: startTime.UTC().Format(time.RFC3339), + EndTime: endTime.UTC().Format(time.RFC3339), + 
Granularity: granularity, + Limit: limit, + }) + entry, hit, err := dashboardUsersTrendCache.GetOrLoad(key, func() (any, error) { + return h.dashboardService.GetUserUsageTrend(ctx, startTime, endTime, granularity, limit) + }) + if err != nil { + return nil, hit, err + } + trend, err := snapshotPayloadAs[[]usagestats.UserUsageTrendPoint](entry.Payload) + return trend, hit, err +} diff --git a/backend/internal/handler/admin/dashboard_snapshot_v2_handler.go b/backend/internal/handler/admin/dashboard_snapshot_v2_handler.go new file mode 100644 index 00000000..16e10339 --- /dev/null +++ b/backend/internal/handler/admin/dashboard_snapshot_v2_handler.go @@ -0,0 +1,302 @@ +package admin + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +var dashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second) + +type dashboardSnapshotV2Stats struct { + usagestats.DashboardStats + Uptime int64 `json:"uptime"` +} + +type dashboardSnapshotV2Response struct { + GeneratedAt string `json:"generated_at"` + + StartDate string `json:"start_date"` + EndDate string `json:"end_date"` + Granularity string `json:"granularity"` + + Stats *dashboardSnapshotV2Stats `json:"stats,omitempty"` + Trend []usagestats.TrendDataPoint `json:"trend,omitempty"` + Models []usagestats.ModelStat `json:"models,omitempty"` + Groups []usagestats.GroupStat `json:"groups,omitempty"` + UsersTrend []usagestats.UserUsageTrendPoint `json:"users_trend,omitempty"` +} + +type dashboardSnapshotV2Filters struct { + UserID int64 + APIKeyID int64 + AccountID int64 + GroupID int64 + Model string + RequestType *int16 + Stream *bool + BillingType *int8 +} + +type dashboardSnapshotV2CacheKey struct { + StartTime string `json:"start_time"` + EndTime string `json:"end_time"` + Granularity 
string `json:"granularity"` + UserID int64 `json:"user_id"` + APIKeyID int64 `json:"api_key_id"` + AccountID int64 `json:"account_id"` + GroupID int64 `json:"group_id"` + Model string `json:"model"` + RequestType *int16 `json:"request_type"` + Stream *bool `json:"stream"` + BillingType *int8 `json:"billing_type"` + IncludeStats bool `json:"include_stats"` + IncludeTrend bool `json:"include_trend"` + IncludeModels bool `json:"include_models"` + IncludeGroups bool `json:"include_groups"` + IncludeUsersTrend bool `json:"include_users_trend"` + UsersTrendLimit int `json:"users_trend_limit"` +} + +func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) { + startTime, endTime := parseTimeRange(c) + granularity := strings.TrimSpace(c.DefaultQuery("granularity", "day")) + if granularity != "hour" { + granularity = "day" + } + + includeStats := parseBoolQueryWithDefault(c.Query("include_stats"), true) + includeTrend := parseBoolQueryWithDefault(c.Query("include_trend"), true) + includeModels := parseBoolQueryWithDefault(c.Query("include_model_stats"), true) + includeGroups := parseBoolQueryWithDefault(c.Query("include_group_stats"), false) + includeUsersTrend := parseBoolQueryWithDefault(c.Query("include_users_trend"), false) + usersTrendLimit := 12 + if raw := strings.TrimSpace(c.Query("users_trend_limit")); raw != "" { + if parsed, err := strconv.Atoi(raw); err == nil && parsed > 0 && parsed <= 50 { + usersTrendLimit = parsed + } + } + + filters, err := parseDashboardSnapshotV2Filters(c) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + keyRaw, _ := json.Marshal(dashboardSnapshotV2CacheKey{ + StartTime: startTime.UTC().Format(time.RFC3339), + EndTime: endTime.UTC().Format(time.RFC3339), + Granularity: granularity, + UserID: filters.UserID, + APIKeyID: filters.APIKeyID, + AccountID: filters.AccountID, + GroupID: filters.GroupID, + Model: filters.Model, + RequestType: filters.RequestType, + Stream: filters.Stream, + BillingType: filters.BillingType, 
+ IncludeStats: includeStats, + IncludeTrend: includeTrend, + IncludeModels: includeModels, + IncludeGroups: includeGroups, + IncludeUsersTrend: includeUsersTrend, + UsersTrendLimit: usersTrendLimit, + }) + cacheKey := string(keyRaw) + + cached, hit, err := dashboardSnapshotV2Cache.GetOrLoad(cacheKey, func() (any, error) { + return h.buildSnapshotV2Response( + c.Request.Context(), + startTime, + endTime, + granularity, + filters, + includeStats, + includeTrend, + includeModels, + includeGroups, + includeUsersTrend, + usersTrendLimit, + ) + }) + if err != nil { + response.Error(c, 500, err.Error()) + return + } + if cached.ETag != "" { + c.Header("ETag", cached.ETag) + c.Header("Vary", "If-None-Match") + if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) { + c.Status(http.StatusNotModified) + return + } + } + c.Header("X-Snapshot-Cache", cacheStatusValue(hit)) + response.Success(c, cached.Payload) +} + +func (h *DashboardHandler) buildSnapshotV2Response( + ctx context.Context, + startTime, endTime time.Time, + granularity string, + filters *dashboardSnapshotV2Filters, + includeStats, includeTrend, includeModels, includeGroups, includeUsersTrend bool, + usersTrendLimit int, +) (*dashboardSnapshotV2Response, error) { + resp := &dashboardSnapshotV2Response{ + GeneratedAt: time.Now().UTC().Format(time.RFC3339), + StartDate: startTime.Format("2006-01-02"), + EndDate: endTime.Add(-24 * time.Hour).Format("2006-01-02"), + Granularity: granularity, + } + + if includeStats { + stats, err := h.dashboardService.GetDashboardStats(ctx) + if err != nil { + return nil, errors.New("failed to get dashboard statistics") + } + resp.Stats = &dashboardSnapshotV2Stats{ + DashboardStats: *stats, + Uptime: int64(time.Since(h.startTime).Seconds()), + } + } + + if includeTrend { + trend, _, err := h.getUsageTrendCached( + ctx, + startTime, + endTime, + granularity, + filters.UserID, + filters.APIKeyID, + filters.AccountID, + filters.GroupID, + filters.Model, + 
filters.RequestType, + filters.Stream, + filters.BillingType, + ) + if err != nil { + return nil, errors.New("failed to get usage trend") + } + resp.Trend = trend + } + + if includeModels { + models, _, err := h.getModelStatsCached( + ctx, + startTime, + endTime, + filters.UserID, + filters.APIKeyID, + filters.AccountID, + filters.GroupID, + filters.RequestType, + filters.Stream, + filters.BillingType, + ) + if err != nil { + return nil, errors.New("failed to get model statistics") + } + resp.Models = models + } + + if includeGroups { + groups, _, err := h.getGroupStatsCached( + ctx, + startTime, + endTime, + filters.UserID, + filters.APIKeyID, + filters.AccountID, + filters.GroupID, + filters.RequestType, + filters.Stream, + filters.BillingType, + ) + if err != nil { + return nil, errors.New("failed to get group statistics") + } + resp.Groups = groups + } + + if includeUsersTrend { + usersTrend, _, err := h.getUserUsageTrendCached(ctx, startTime, endTime, granularity, usersTrendLimit) + if err != nil { + return nil, errors.New("failed to get user usage trend") + } + resp.UsersTrend = usersTrend + } + + return resp, nil +} + +func parseDashboardSnapshotV2Filters(c *gin.Context) (*dashboardSnapshotV2Filters, error) { + filters := &dashboardSnapshotV2Filters{ + Model: strings.TrimSpace(c.Query("model")), + } + + if userIDStr := strings.TrimSpace(c.Query("user_id")); userIDStr != "" { + id, err := strconv.ParseInt(userIDStr, 10, 64) + if err != nil { + return nil, err + } + filters.UserID = id + } + if apiKeyIDStr := strings.TrimSpace(c.Query("api_key_id")); apiKeyIDStr != "" { + id, err := strconv.ParseInt(apiKeyIDStr, 10, 64) + if err != nil { + return nil, err + } + filters.APIKeyID = id + } + if accountIDStr := strings.TrimSpace(c.Query("account_id")); accountIDStr != "" { + id, err := strconv.ParseInt(accountIDStr, 10, 64) + if err != nil { + return nil, err + } + filters.AccountID = id + } + if groupIDStr := strings.TrimSpace(c.Query("group_id")); groupIDStr != 
"" { + id, err := strconv.ParseInt(groupIDStr, 10, 64) + if err != nil { + return nil, err + } + filters.GroupID = id + } + + if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" { + parsed, err := service.ParseUsageRequestType(requestTypeStr) + if err != nil { + return nil, err + } + value := int16(parsed) + filters.RequestType = &value + } else if streamStr := strings.TrimSpace(c.Query("stream")); streamStr != "" { + streamVal, err := strconv.ParseBool(streamStr) + if err != nil { + return nil, err + } + filters.Stream = &streamVal + } + + if billingTypeStr := strings.TrimSpace(c.Query("billing_type")); billingTypeStr != "" { + v, err := strconv.ParseInt(billingTypeStr, 10, 8) + if err != nil { + return nil, err + } + bt := int8(v) + filters.BillingType = &bt + } + + return filters, nil +} diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index 1edf4dcc..34c94f2a 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -46,13 +46,17 @@ type CreateGroupRequest struct { FallbackGroupID *int64 `json:"fallback_group_id"` FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` // 模型路由配置(仅 anthropic 平台使用) - ModelRouting map[string][]int64 `json:"model_routing"` - ModelRoutingEnabled bool `json:"model_routing_enabled"` - MCPXMLInject *bool `json:"mcp_xml_inject"` + ModelRouting map[string][]int64 `json:"model_routing"` + ModelRoutingEnabled bool `json:"model_routing_enabled"` + MCPXMLInject *bool `json:"mcp_xml_inject"` + SimulateClaudeMaxEnabled *bool `json:"simulate_claude_max_enabled"` // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes []string `json:"supported_model_scopes"` // Sora 存储配额 SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"` + // OpenAI Messages 调度配置(仅 openai 平台使用) + AllowMessagesDispatch bool `json:"allow_messages_dispatch"` + DefaultMappedModel string 
`json:"default_mapped_model"` // 从指定分组复制账号(创建后自动绑定) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` } @@ -81,13 +85,17 @@ type UpdateGroupRequest struct { FallbackGroupID *int64 `json:"fallback_group_id"` FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` // 模型路由配置(仅 anthropic 平台使用) - ModelRouting map[string][]int64 `json:"model_routing"` - ModelRoutingEnabled *bool `json:"model_routing_enabled"` - MCPXMLInject *bool `json:"mcp_xml_inject"` + ModelRouting map[string][]int64 `json:"model_routing"` + ModelRoutingEnabled *bool `json:"model_routing_enabled"` + MCPXMLInject *bool `json:"mcp_xml_inject"` + SimulateClaudeMaxEnabled *bool `json:"simulate_claude_max_enabled"` // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes *[]string `json:"supported_model_scopes"` // Sora 存储配额 SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"` + // OpenAI Messages 调度配置(仅 openai 平台使用) + AllowMessagesDispatch *bool `json:"allow_messages_dispatch"` + DefaultMappedModel *string `json:"default_mapped_model"` // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` } @@ -201,8 +209,11 @@ func (h *GroupHandler) Create(c *gin.Context) { ModelRouting: req.ModelRouting, ModelRoutingEnabled: req.ModelRoutingEnabled, MCPXMLInject: req.MCPXMLInject, + SimulateClaudeMaxEnabled: req.SimulateClaudeMaxEnabled, SupportedModelScopes: req.SupportedModelScopes, SoraStorageQuotaBytes: req.SoraStorageQuotaBytes, + AllowMessagesDispatch: req.AllowMessagesDispatch, + DefaultMappedModel: req.DefaultMappedModel, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) if err != nil { @@ -252,8 +263,11 @@ func (h *GroupHandler) Update(c *gin.Context) { ModelRouting: req.ModelRouting, ModelRoutingEnabled: req.ModelRoutingEnabled, MCPXMLInject: req.MCPXMLInject, + SimulateClaudeMaxEnabled: req.SimulateClaudeMaxEnabled, SupportedModelScopes: req.SupportedModelScopes, SoraStorageQuotaBytes: 
req.SoraStorageQuotaBytes, + AllowMessagesDispatch: req.AllowMessagesDispatch, + DefaultMappedModel: req.DefaultMappedModel, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) if err != nil { @@ -325,6 +339,27 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) { response.Paginated(c, outKeys, total, page, pageSize) } +// GetGroupRateMultipliers handles getting rate multipliers for users in a group +// GET /api/v1/admin/groups/:id/rate-multipliers +func (h *GroupHandler) GetGroupRateMultipliers(c *gin.Context) { + groupID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid group ID") + return + } + + entries, err := h.adminService.GetGroupRateMultipliers(c.Request.Context(), groupID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + if entries == nil { + entries = []service.UserGroupRateEntry{} + } + response.Success(c, entries) +} + // UpdateSortOrderRequest represents the request to update group sort orders type UpdateSortOrderRequest struct { Updates []struct { diff --git a/backend/internal/handler/admin/id_list_utils.go b/backend/internal/handler/admin/id_list_utils.go new file mode 100644 index 00000000..2aeefe38 --- /dev/null +++ b/backend/internal/handler/admin/id_list_utils.go @@ -0,0 +1,25 @@ +package admin + +import "sort" + +func normalizeInt64IDList(ids []int64) []int64 { + if len(ids) == 0 { + return nil + } + + out := make([]int64, 0, len(ids)) + seen := make(map[int64]struct{}, len(ids)) + for _, id := range ids { + if id <= 0 { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + out = append(out, id) + } + + sort.Slice(out, func(i, j int) bool { return out[i] < out[j] }) + return out +} diff --git a/backend/internal/handler/admin/id_list_utils_test.go b/backend/internal/handler/admin/id_list_utils_test.go new file mode 100644 index 00000000..aa65d5c0 --- /dev/null +++ b/backend/internal/handler/admin/id_list_utils_test.go @@ -0,0 +1,57 @@ 
+//go:build unit + +package admin + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNormalizeInt64IDList(t *testing.T) { + tests := []struct { + name string + in []int64 + want []int64 + }{ + {"nil input", nil, nil}, + {"empty input", []int64{}, nil}, + {"single element", []int64{5}, []int64{5}}, + {"already sorted unique", []int64{1, 2, 3}, []int64{1, 2, 3}}, + {"duplicates removed", []int64{3, 1, 3, 2, 1}, []int64{1, 2, 3}}, + {"zero filtered", []int64{0, 1, 2}, []int64{1, 2}}, + {"negative filtered", []int64{-5, -1, 3}, []int64{3}}, + {"all invalid", []int64{0, -1, -2}, []int64{}}, + {"sorted output", []int64{9, 3, 7, 1}, []int64{1, 3, 7, 9}}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := normalizeInt64IDList(tc.in) + if tc.want == nil { + require.Nil(t, got) + } else { + require.Equal(t, tc.want, got) + } + }) + } +} + +func TestBuildAccountTodayStatsBatchCacheKey(t *testing.T) { + tests := []struct { + name string + ids []int64 + want string + }{ + {"empty", nil, "accounts_today_stats_empty"}, + {"single", []int64{42}, "accounts_today_stats:42"}, + {"multiple", []int64{1, 2, 3}, "accounts_today_stats:1,2,3"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := buildAccountTodayStatsBatchCacheKey(tc.ids) + require.Equal(t, tc.want, got) + }) + } +} diff --git a/backend/internal/handler/admin/ops_alerts_handler.go b/backend/internal/handler/admin/ops_alerts_handler.go index c9da19c7..edc8c7f7 100644 --- a/backend/internal/handler/admin/ops_alerts_handler.go +++ b/backend/internal/handler/admin/ops_alerts_handler.go @@ -23,6 +23,13 @@ var validOpsAlertMetricTypes = []string{ "cpu_usage_percent", "memory_usage_percent", "concurrency_queue_depth", + "group_available_accounts", + "group_available_ratio", + "group_rate_limit_ratio", + "account_rate_limited_count", + "account_error_count", + "account_error_ratio", + "overload_account_count", } var 
validOpsAlertMetricTypeSet = func() map[string]struct{} { @@ -82,7 +89,10 @@ func isPercentOrRateMetric(metricType string) bool { "error_rate", "upstream_error_rate", "cpu_usage_percent", - "memory_usage_percent": + "memory_usage_percent", + "group_available_ratio", + "group_rate_limit_ratio", + "account_error_ratio": return true default: return false diff --git a/backend/internal/handler/admin/ops_snapshot_v2_handler.go b/backend/internal/handler/admin/ops_snapshot_v2_handler.go new file mode 100644 index 00000000..5cac00fe --- /dev/null +++ b/backend/internal/handler/admin/ops_snapshot_v2_handler.go @@ -0,0 +1,145 @@ +package admin + +import ( + "encoding/json" + "net/http" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "golang.org/x/sync/errgroup" +) + +var opsDashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second) + +type opsDashboardSnapshotV2Response struct { + GeneratedAt string `json:"generated_at"` + + Overview *service.OpsDashboardOverview `json:"overview"` + ThroughputTrend *service.OpsThroughputTrendResponse `json:"throughput_trend"` + ErrorTrend *service.OpsErrorTrendResponse `json:"error_trend"` +} + +type opsDashboardSnapshotV2CacheKey struct { + StartTime string `json:"start_time"` + EndTime string `json:"end_time"` + Platform string `json:"platform"` + GroupID *int64 `json:"group_id"` + QueryMode service.OpsQueryMode `json:"mode"` + BucketSecond int `json:"bucket_second"` +} + +// GetDashboardSnapshotV2 returns ops dashboard core snapshot in one request. 
+// GET /api/v1/admin/ops/dashboard/snapshot-v2 +func (h *OpsHandler) GetDashboardSnapshotV2(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + startTime, endTime, err := parseOpsTimeRange(c, "1h") + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + filter := &service.OpsDashboardFilter{ + StartTime: startTime, + EndTime: endTime, + Platform: strings.TrimSpace(c.Query("platform")), + QueryMode: parseOpsQueryMode(c), + } + if v := strings.TrimSpace(c.Query("group_id")); v != "" { + id, err := strconv.ParseInt(v, 10, 64) + if err != nil || id <= 0 { + response.BadRequest(c, "Invalid group_id") + return + } + filter.GroupID = &id + } + bucketSeconds := pickThroughputBucketSeconds(endTime.Sub(startTime)) + + keyRaw, _ := json.Marshal(opsDashboardSnapshotV2CacheKey{ + StartTime: startTime.UTC().Format(time.RFC3339), + EndTime: endTime.UTC().Format(time.RFC3339), + Platform: filter.Platform, + GroupID: filter.GroupID, + QueryMode: filter.QueryMode, + BucketSecond: bucketSeconds, + }) + cacheKey := string(keyRaw) + + if cached, ok := opsDashboardSnapshotV2Cache.Get(cacheKey); ok { + if cached.ETag != "" { + c.Header("ETag", cached.ETag) + c.Header("Vary", "If-None-Match") + if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) { + c.Status(http.StatusNotModified) + return + } + } + c.Header("X-Snapshot-Cache", "hit") + response.Success(c, cached.Payload) + return + } + + var ( + overview *service.OpsDashboardOverview + trend *service.OpsThroughputTrendResponse + errTrend *service.OpsErrorTrendResponse + ) + g, gctx := errgroup.WithContext(c.Request.Context()) + g.Go(func() error { + f := *filter + result, err := h.opsService.GetDashboardOverview(gctx, &f) + if err != nil { + return err + } + overview = result + return nil + }) 
+ g.Go(func() error { + f := *filter + result, err := h.opsService.GetThroughputTrend(gctx, &f, bucketSeconds) + if err != nil { + return err + } + trend = result + return nil + }) + g.Go(func() error { + f := *filter + result, err := h.opsService.GetErrorTrend(gctx, &f, bucketSeconds) + if err != nil { + return err + } + errTrend = result + return nil + }) + if err := g.Wait(); err != nil { + response.ErrorFrom(c, err) + return + } + + resp := &opsDashboardSnapshotV2Response{ + GeneratedAt: time.Now().UTC().Format(time.RFC3339), + Overview: overview, + ThroughputTrend: trend, + ErrorTrend: errTrend, + } + + cached := opsDashboardSnapshotV2Cache.Set(cacheKey, resp) + if cached.ETag != "" { + c.Header("ETag", cached.ETag) + c.Header("Vary", "If-None-Match") + } + c.Header("X-Snapshot-Cache", "miss") + response.Success(c, resp) +} diff --git a/backend/internal/handler/admin/scheduled_test_handler.go b/backend/internal/handler/admin/scheduled_test_handler.go new file mode 100644 index 00000000..d9f39737 --- /dev/null +++ b/backend/internal/handler/admin/scheduled_test_handler.go @@ -0,0 +1,163 @@ +package admin + +import ( + "net/http" + "strconv" + + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +// ScheduledTestHandler handles admin scheduled-test-plan management. +type ScheduledTestHandler struct { + scheduledTestSvc *service.ScheduledTestService +} + +// NewScheduledTestHandler creates a new ScheduledTestHandler. 
+func NewScheduledTestHandler(scheduledTestSvc *service.ScheduledTestService) *ScheduledTestHandler { + return &ScheduledTestHandler{scheduledTestSvc: scheduledTestSvc} +} + +type createScheduledTestPlanRequest struct { + AccountID int64 `json:"account_id" binding:"required"` + ModelID string `json:"model_id"` + CronExpression string `json:"cron_expression" binding:"required"` + Enabled *bool `json:"enabled"` + MaxResults int `json:"max_results"` + AutoRecover *bool `json:"auto_recover"` +} + +type updateScheduledTestPlanRequest struct { + ModelID string `json:"model_id"` + CronExpression string `json:"cron_expression"` + Enabled *bool `json:"enabled"` + MaxResults int `json:"max_results"` + AutoRecover *bool `json:"auto_recover"` +} + +// ListByAccount GET /admin/accounts/:id/scheduled-test-plans +func (h *ScheduledTestHandler) ListByAccount(c *gin.Context) { + accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "invalid account id") + return + } + + plans, err := h.scheduledTestSvc.ListPlansByAccount(c.Request.Context(), accountID) + if err != nil { + response.InternalError(c, err.Error()) + return + } + c.JSON(http.StatusOK, plans) +} + +// Create POST /admin/scheduled-test-plans +func (h *ScheduledTestHandler) Create(c *gin.Context) { + var req createScheduledTestPlanRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + plan := &service.ScheduledTestPlan{ + AccountID: req.AccountID, + ModelID: req.ModelID, + CronExpression: req.CronExpression, + Enabled: true, + MaxResults: req.MaxResults, + } + if req.Enabled != nil { + plan.Enabled = *req.Enabled + } + if req.AutoRecover != nil { + plan.AutoRecover = *req.AutoRecover + } + + created, err := h.scheduledTestSvc.CreatePlan(c.Request.Context(), plan) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + c.JSON(http.StatusOK, created) +} + +// Update PUT /admin/scheduled-test-plans/:id +func 
(h *ScheduledTestHandler) Update(c *gin.Context) { + planID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "invalid plan id") + return + } + + existing, err := h.scheduledTestSvc.GetPlan(c.Request.Context(), planID) + if err != nil { + response.NotFound(c, "plan not found") + return + } + + var req updateScheduledTestPlanRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if req.ModelID != "" { + existing.ModelID = req.ModelID + } + if req.CronExpression != "" { + existing.CronExpression = req.CronExpression + } + if req.Enabled != nil { + existing.Enabled = *req.Enabled + } + if req.MaxResults > 0 { + existing.MaxResults = req.MaxResults + } + if req.AutoRecover != nil { + existing.AutoRecover = *req.AutoRecover + } + + updated, err := h.scheduledTestSvc.UpdatePlan(c.Request.Context(), existing) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + c.JSON(http.StatusOK, updated) +} + +// Delete DELETE /admin/scheduled-test-plans/:id +func (h *ScheduledTestHandler) Delete(c *gin.Context) { + planID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "invalid plan id") + return + } + + if err := h.scheduledTestSvc.DeletePlan(c.Request.Context(), planID); err != nil { + response.InternalError(c, err.Error()) + return + } + c.JSON(http.StatusOK, gin.H{"message": "deleted"}) +} + +// ListResults GET /admin/scheduled-test-plans/:id/results +func (h *ScheduledTestHandler) ListResults(c *gin.Context) { + planID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "invalid plan id") + return + } + + limit := 50 + if l, err := strconv.Atoi(c.Query("limit")); err == nil && l > 0 { + limit = l + } + + results, err := h.scheduledTestSvc.ListResults(c.Request.Context(), planID, limit) + if err != nil { + response.InternalError(c, err.Error()) + return + } + c.JSON(http.StatusOK, 
results) +} diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 43339412..8330868d 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -77,6 +77,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { response.Success(c, dto.SystemSettings{ RegistrationEnabled: settings.RegistrationEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled, + RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist, PromoCodeEnabled: settings.PromoCodeEnabled, PasswordResetEnabled: settings.PasswordResetEnabled, InvitationCodeEnabled: settings.InvitationCodeEnabled, @@ -130,12 +131,13 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { // UpdateSettingsRequest 更新设置请求 type UpdateSettingsRequest struct { // 注册设置 - RegistrationEnabled bool `json:"registration_enabled"` - EmailVerifyEnabled bool `json:"email_verify_enabled"` - PromoCodeEnabled bool `json:"promo_code_enabled"` - PasswordResetEnabled bool `json:"password_reset_enabled"` - InvitationCodeEnabled bool `json:"invitation_code_enabled"` - TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 + RegistrationEnabled bool `json:"registration_enabled"` + EmailVerifyEnabled bool `json:"email_verify_enabled"` + RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"` + PromoCodeEnabled bool `json:"promo_code_enabled"` + PasswordResetEnabled bool `json:"password_reset_enabled"` + InvitationCodeEnabled bool `json:"invitation_code_enabled"` + TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 // 邮件服务设置 SMTPHost string `json:"smtp_host"` @@ -426,50 +428,51 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } settings := &service.SystemSettings{ - RegistrationEnabled: req.RegistrationEnabled, - EmailVerifyEnabled: req.EmailVerifyEnabled, - PromoCodeEnabled: req.PromoCodeEnabled, - PasswordResetEnabled: req.PasswordResetEnabled, - 
InvitationCodeEnabled: req.InvitationCodeEnabled, - TotpEnabled: req.TotpEnabled, - SMTPHost: req.SMTPHost, - SMTPPort: req.SMTPPort, - SMTPUsername: req.SMTPUsername, - SMTPPassword: req.SMTPPassword, - SMTPFrom: req.SMTPFrom, - SMTPFromName: req.SMTPFromName, - SMTPUseTLS: req.SMTPUseTLS, - TurnstileEnabled: req.TurnstileEnabled, - TurnstileSiteKey: req.TurnstileSiteKey, - TurnstileSecretKey: req.TurnstileSecretKey, - LinuxDoConnectEnabled: req.LinuxDoConnectEnabled, - LinuxDoConnectClientID: req.LinuxDoConnectClientID, - LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret, - LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL, - SiteName: req.SiteName, - SiteLogo: req.SiteLogo, - SiteSubtitle: req.SiteSubtitle, - APIBaseURL: req.APIBaseURL, - ContactInfo: req.ContactInfo, - DocURL: req.DocURL, - HomeContent: req.HomeContent, - HideCcsImportButton: req.HideCcsImportButton, - PurchaseSubscriptionEnabled: purchaseEnabled, - PurchaseSubscriptionURL: purchaseURL, - SoraClientEnabled: req.SoraClientEnabled, - CustomMenuItems: customMenuJSON, - DefaultConcurrency: req.DefaultConcurrency, - DefaultBalance: req.DefaultBalance, - DefaultSubscriptions: defaultSubscriptions, - EnableModelFallback: req.EnableModelFallback, - FallbackModelAnthropic: req.FallbackModelAnthropic, - FallbackModelOpenAI: req.FallbackModelOpenAI, - FallbackModelGemini: req.FallbackModelGemini, - FallbackModelAntigravity: req.FallbackModelAntigravity, - EnableIdentityPatch: req.EnableIdentityPatch, - IdentityPatchPrompt: req.IdentityPatchPrompt, - MinClaudeCodeVersion: req.MinClaudeCodeVersion, - AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling, + RegistrationEnabled: req.RegistrationEnabled, + EmailVerifyEnabled: req.EmailVerifyEnabled, + RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist, + PromoCodeEnabled: req.PromoCodeEnabled, + PasswordResetEnabled: req.PasswordResetEnabled, + InvitationCodeEnabled: req.InvitationCodeEnabled, + TotpEnabled: 
req.TotpEnabled, + SMTPHost: req.SMTPHost, + SMTPPort: req.SMTPPort, + SMTPUsername: req.SMTPUsername, + SMTPPassword: req.SMTPPassword, + SMTPFrom: req.SMTPFrom, + SMTPFromName: req.SMTPFromName, + SMTPUseTLS: req.SMTPUseTLS, + TurnstileEnabled: req.TurnstileEnabled, + TurnstileSiteKey: req.TurnstileSiteKey, + TurnstileSecretKey: req.TurnstileSecretKey, + LinuxDoConnectEnabled: req.LinuxDoConnectEnabled, + LinuxDoConnectClientID: req.LinuxDoConnectClientID, + LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret, + LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL, + SiteName: req.SiteName, + SiteLogo: req.SiteLogo, + SiteSubtitle: req.SiteSubtitle, + APIBaseURL: req.APIBaseURL, + ContactInfo: req.ContactInfo, + DocURL: req.DocURL, + HomeContent: req.HomeContent, + HideCcsImportButton: req.HideCcsImportButton, + PurchaseSubscriptionEnabled: purchaseEnabled, + PurchaseSubscriptionURL: purchaseURL, + SoraClientEnabled: req.SoraClientEnabled, + CustomMenuItems: customMenuJSON, + DefaultConcurrency: req.DefaultConcurrency, + DefaultBalance: req.DefaultBalance, + DefaultSubscriptions: defaultSubscriptions, + EnableModelFallback: req.EnableModelFallback, + FallbackModelAnthropic: req.FallbackModelAnthropic, + FallbackModelOpenAI: req.FallbackModelOpenAI, + FallbackModelGemini: req.FallbackModelGemini, + FallbackModelAntigravity: req.FallbackModelAntigravity, + EnableIdentityPatch: req.EnableIdentityPatch, + IdentityPatchPrompt: req.IdentityPatchPrompt, + MinClaudeCodeVersion: req.MinClaudeCodeVersion, + AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling, OpsMonitoringEnabled: func() bool { if req.OpsMonitoringEnabled != nil { return *req.OpsMonitoringEnabled @@ -520,6 +523,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { response.Success(c, dto.SystemSettings{ RegistrationEnabled: updatedSettings.RegistrationEnabled, EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled, + RegistrationEmailSuffixWhitelist: 
updatedSettings.RegistrationEmailSuffixWhitelist, PromoCodeEnabled: updatedSettings.PromoCodeEnabled, PasswordResetEnabled: updatedSettings.PasswordResetEnabled, InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled, @@ -598,6 +602,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.EmailVerifyEnabled != after.EmailVerifyEnabled { changed = append(changed, "email_verify_enabled") } + if !equalStringSlice(before.RegistrationEmailSuffixWhitelist, after.RegistrationEmailSuffixWhitelist) { + changed = append(changed, "registration_email_suffix_whitelist") + } if before.PasswordResetEnabled != after.PasswordResetEnabled { changed = append(changed, "password_reset_enabled") } @@ -747,6 +754,18 @@ func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto return normalized } +func equalStringSlice(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool { if len(a) != len(b) { return false @@ -800,7 +819,7 @@ func (h *SettingHandler) TestSMTPConnection(c *gin.Context) { err := h.emailService.TestSMTPConnectionWithConfig(config) if err != nil { - response.ErrorFrom(c, err) + response.BadRequest(c, "SMTP connection test failed: "+err.Error()) return } @@ -886,7 +905,7 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) { ` if err := h.emailService.SendEmailWithConfig(config, req.Email, subject, body); err != nil { - response.ErrorFrom(c, err) + response.BadRequest(c, "Failed to send test email: "+err.Error()) return } @@ -1329,6 +1348,118 @@ func (h *SettingHandler) TestSoraS3Connection(c *gin.Context) { response.Success(c, gin.H{"message": "S3 连接成功"}) } +// GetRectifierSettings 获取请求整流器配置 +// GET /api/v1/admin/settings/rectifier +func (h *SettingHandler) GetRectifierSettings(c *gin.Context) { + settings, err := 
h.settingService.GetRectifierSettings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, dto.RectifierSettings{ + Enabled: settings.Enabled, + ThinkingSignatureEnabled: settings.ThinkingSignatureEnabled, + ThinkingBudgetEnabled: settings.ThinkingBudgetEnabled, + }) +} + +// UpdateRectifierSettingsRequest 更新整流器配置请求 +type UpdateRectifierSettingsRequest struct { + Enabled bool `json:"enabled"` + ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"` + ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"` +} + +// UpdateRectifierSettings 更新请求整流器配置 +// PUT /api/v1/admin/settings/rectifier +func (h *SettingHandler) UpdateRectifierSettings(c *gin.Context) { + var req UpdateRectifierSettingsRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + settings := &service.RectifierSettings{ + Enabled: req.Enabled, + ThinkingSignatureEnabled: req.ThinkingSignatureEnabled, + ThinkingBudgetEnabled: req.ThinkingBudgetEnabled, + } + + if err := h.settingService.SetRectifierSettings(c.Request.Context(), settings); err != nil { + response.BadRequest(c, err.Error()) + return + } + + // 重新获取设置返回 + updatedSettings, err := h.settingService.GetRectifierSettings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, dto.RectifierSettings{ + Enabled: updatedSettings.Enabled, + ThinkingSignatureEnabled: updatedSettings.ThinkingSignatureEnabled, + ThinkingBudgetEnabled: updatedSettings.ThinkingBudgetEnabled, + }) +} + +// GetBetaPolicySettings 获取 Beta 策略配置 +// GET /api/v1/admin/settings/beta-policy +func (h *SettingHandler) GetBetaPolicySettings(c *gin.Context) { + settings, err := h.settingService.GetBetaPolicySettings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + + rules := make([]dto.BetaPolicyRule, len(settings.Rules)) + for i, r := range settings.Rules { + 
rules[i] = dto.BetaPolicyRule(r) + } + response.Success(c, dto.BetaPolicySettings{Rules: rules}) +} + +// UpdateBetaPolicySettingsRequest 更新 Beta 策略配置请求 +type UpdateBetaPolicySettingsRequest struct { + Rules []dto.BetaPolicyRule `json:"rules"` +} + +// UpdateBetaPolicySettings 更新 Beta 策略配置 +// PUT /api/v1/admin/settings/beta-policy +func (h *SettingHandler) UpdateBetaPolicySettings(c *gin.Context) { + var req UpdateBetaPolicySettingsRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + rules := make([]service.BetaPolicyRule, len(req.Rules)) + for i, r := range req.Rules { + rules[i] = service.BetaPolicyRule(r) + } + + settings := &service.BetaPolicySettings{Rules: rules} + if err := h.settingService.SetBetaPolicySettings(c.Request.Context(), settings); err != nil { + response.BadRequest(c, err.Error()) + return + } + + // Re-fetch to return updated settings + updated, err := h.settingService.GetBetaPolicySettings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + + outRules := make([]dto.BetaPolicyRule, len(updated.Rules)) + for i, r := range updated.Rules { + outRules[i] = dto.BetaPolicyRule(r) + } + response.Success(c, dto.BetaPolicySettings{Rules: outRules}) +} + // UpdateStreamTimeoutSettingsRequest 更新流超时配置请求 type UpdateStreamTimeoutSettingsRequest struct { Enabled bool `json:"enabled"` diff --git a/backend/internal/handler/admin/snapshot_cache.go b/backend/internal/handler/admin/snapshot_cache.go new file mode 100644 index 00000000..d6973ff9 --- /dev/null +++ b/backend/internal/handler/admin/snapshot_cache.go @@ -0,0 +1,138 @@ +package admin + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "strings" + "sync" + "time" + + "golang.org/x/sync/singleflight" +) + +type snapshotCacheEntry struct { + ETag string + Payload any + ExpiresAt time.Time +} + +type snapshotCache struct { + mu sync.RWMutex + ttl time.Duration + items 
map[string]snapshotCacheEntry + sf singleflight.Group +} + +type snapshotCacheLoadResult struct { + Entry snapshotCacheEntry + Hit bool +} + +func newSnapshotCache(ttl time.Duration) *snapshotCache { + if ttl <= 0 { + ttl = 30 * time.Second + } + return &snapshotCache{ + ttl: ttl, + items: make(map[string]snapshotCacheEntry), + } +} + +func (c *snapshotCache) Get(key string) (snapshotCacheEntry, bool) { + if c == nil || key == "" { + return snapshotCacheEntry{}, false + } + now := time.Now() + + c.mu.RLock() + entry, ok := c.items[key] + c.mu.RUnlock() + if !ok { + return snapshotCacheEntry{}, false + } + if now.After(entry.ExpiresAt) { + c.mu.Lock() + delete(c.items, key) + c.mu.Unlock() + return snapshotCacheEntry{}, false + } + return entry, true +} + +func (c *snapshotCache) Set(key string, payload any) snapshotCacheEntry { + if c == nil { + return snapshotCacheEntry{} + } + entry := snapshotCacheEntry{ + ETag: buildETagFromAny(payload), + Payload: payload, + ExpiresAt: time.Now().Add(c.ttl), + } + if key == "" { + return entry + } + c.mu.Lock() + c.items[key] = entry + c.mu.Unlock() + return entry +} + +func (c *snapshotCache) GetOrLoad(key string, load func() (any, error)) (snapshotCacheEntry, bool, error) { + if load == nil { + return snapshotCacheEntry{}, false, nil + } + if entry, ok := c.Get(key); ok { + return entry, true, nil + } + if c == nil || key == "" { + payload, err := load() + if err != nil { + return snapshotCacheEntry{}, false, err + } + return c.Set(key, payload), false, nil + } + + value, err, _ := c.sf.Do(key, func() (any, error) { + if entry, ok := c.Get(key); ok { + return snapshotCacheLoadResult{Entry: entry, Hit: true}, nil + } + payload, err := load() + if err != nil { + return nil, err + } + return snapshotCacheLoadResult{Entry: c.Set(key, payload), Hit: false}, nil + }) + if err != nil { + return snapshotCacheEntry{}, false, err + } + result, ok := value.(snapshotCacheLoadResult) + if !ok { + return snapshotCacheEntry{}, false, nil + 
} + return result.Entry, result.Hit, nil +} + +func buildETagFromAny(payload any) string { + raw, err := json.Marshal(payload) + if err != nil { + return "" + } + sum := sha256.Sum256(raw) + return "\"" + hex.EncodeToString(sum[:]) + "\"" +} + +func parseBoolQueryWithDefault(raw string, def bool) bool { + value := strings.TrimSpace(strings.ToLower(raw)) + if value == "" { + return def + } + switch value { + case "1", "true", "yes", "on": + return true + case "0", "false", "no", "off": + return false + default: + return def + } +} diff --git a/backend/internal/handler/admin/snapshot_cache_test.go b/backend/internal/handler/admin/snapshot_cache_test.go new file mode 100644 index 00000000..ee3f72ca --- /dev/null +++ b/backend/internal/handler/admin/snapshot_cache_test.go @@ -0,0 +1,185 @@ +//go:build unit + +package admin + +import ( + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestSnapshotCache_SetAndGet(t *testing.T) { + c := newSnapshotCache(5 * time.Second) + + entry := c.Set("key1", map[string]string{"hello": "world"}) + require.NotEmpty(t, entry.ETag) + require.NotNil(t, entry.Payload) + + got, ok := c.Get("key1") + require.True(t, ok) + require.Equal(t, entry.ETag, got.ETag) +} + +func TestSnapshotCache_Expiration(t *testing.T) { + c := newSnapshotCache(1 * time.Millisecond) + + c.Set("key1", "value") + time.Sleep(5 * time.Millisecond) + + _, ok := c.Get("key1") + require.False(t, ok, "expired entry should not be returned") +} + +func TestSnapshotCache_GetEmptyKey(t *testing.T) { + c := newSnapshotCache(5 * time.Second) + _, ok := c.Get("") + require.False(t, ok) +} + +func TestSnapshotCache_GetMiss(t *testing.T) { + c := newSnapshotCache(5 * time.Second) + _, ok := c.Get("nonexistent") + require.False(t, ok) +} + +func TestSnapshotCache_NilReceiver(t *testing.T) { + var c *snapshotCache + _, ok := c.Get("key") + require.False(t, ok) + + entry := c.Set("key", "value") + require.Empty(t, entry.ETag) +} + 
+func TestSnapshotCache_SetEmptyKey(t *testing.T) { + c := newSnapshotCache(5 * time.Second) + + // Set with empty key should return entry but not store it + entry := c.Set("", "value") + require.NotEmpty(t, entry.ETag) + + _, ok := c.Get("") + require.False(t, ok) +} + +func TestSnapshotCache_DefaultTTL(t *testing.T) { + c := newSnapshotCache(0) + require.Equal(t, 30*time.Second, c.ttl) + + c2 := newSnapshotCache(-1 * time.Second) + require.Equal(t, 30*time.Second, c2.ttl) +} + +func TestSnapshotCache_ETagDeterministic(t *testing.T) { + c := newSnapshotCache(5 * time.Second) + payload := map[string]int{"a": 1, "b": 2} + + entry1 := c.Set("k1", payload) + entry2 := c.Set("k2", payload) + require.Equal(t, entry1.ETag, entry2.ETag, "same payload should produce same ETag") +} + +func TestSnapshotCache_ETagFormat(t *testing.T) { + c := newSnapshotCache(5 * time.Second) + entry := c.Set("k", "test") + // ETag should be quoted hex string: "abcdef..." + require.True(t, len(entry.ETag) > 2) + require.Equal(t, byte('"'), entry.ETag[0]) + require.Equal(t, byte('"'), entry.ETag[len(entry.ETag)-1]) +} + +func TestBuildETagFromAny_UnmarshalablePayload(t *testing.T) { + // channels are not JSON-serializable + etag := buildETagFromAny(make(chan int)) + require.Empty(t, etag) +} + +func TestSnapshotCache_GetOrLoad_MissThenHit(t *testing.T) { + c := newSnapshotCache(5 * time.Second) + var loads atomic.Int32 + + entry, hit, err := c.GetOrLoad("key1", func() (any, error) { + loads.Add(1) + return map[string]string{"hello": "world"}, nil + }) + require.NoError(t, err) + require.False(t, hit) + require.NotEmpty(t, entry.ETag) + require.Equal(t, int32(1), loads.Load()) + + entry2, hit, err := c.GetOrLoad("key1", func() (any, error) { + loads.Add(1) + return map[string]string{"unexpected": "value"}, nil + }) + require.NoError(t, err) + require.True(t, hit) + require.Equal(t, entry.ETag, entry2.ETag) + require.Equal(t, int32(1), loads.Load()) +} + +func 
TestSnapshotCache_GetOrLoad_ConcurrentSingleflight(t *testing.T) { + c := newSnapshotCache(5 * time.Second) + var loads atomic.Int32 + start := make(chan struct{}) + const callers = 8 + errCh := make(chan error, callers) + + var wg sync.WaitGroup + wg.Add(callers) + for range callers { + go func() { + defer wg.Done() + <-start + _, _, err := c.GetOrLoad("shared", func() (any, error) { + loads.Add(1) + time.Sleep(20 * time.Millisecond) + return "value", nil + }) + errCh <- err + }() + } + close(start) + wg.Wait() + close(errCh) + + for err := range errCh { + require.NoError(t, err) + } + + require.Equal(t, int32(1), loads.Load()) +} + +func TestParseBoolQueryWithDefault(t *testing.T) { + tests := []struct { + name string + raw string + def bool + want bool + }{ + {"empty returns default true", "", true, true}, + {"empty returns default false", "", false, false}, + {"1", "1", false, true}, + {"true", "true", false, true}, + {"TRUE", "TRUE", false, true}, + {"yes", "yes", false, true}, + {"on", "on", false, true}, + {"0", "0", true, false}, + {"false", "false", true, false}, + {"FALSE", "FALSE", true, false}, + {"no", "no", true, false}, + {"off", "off", true, false}, + {"whitespace trimmed", " true ", false, true}, + {"unknown returns default true", "maybe", true, true}, + {"unknown returns default false", "maybe", false, false}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := parseBoolQueryWithDefault(tc.raw, tc.def) + require.Equal(t, tc.want, got) + }) + } +} diff --git a/backend/internal/handler/admin/subscription_handler.go b/backend/internal/handler/admin/subscription_handler.go index e5b6db13..d6073551 100644 --- a/backend/internal/handler/admin/subscription_handler.go +++ b/backend/internal/handler/admin/subscription_handler.go @@ -216,6 +216,37 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) { }) } +// ResetSubscriptionQuotaRequest represents the reset quota request +type ResetSubscriptionQuotaRequest struct { + 
Daily bool `json:"daily"` + Weekly bool `json:"weekly"` +} + +// ResetQuota resets daily and/or weekly usage for a subscription. +// POST /api/v1/admin/subscriptions/:id/reset-quota +func (h *SubscriptionHandler) ResetQuota(c *gin.Context) { + subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid subscription ID") + return + } + var req ResetSubscriptionQuotaRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + if !req.Daily && !req.Weekly { + response.BadRequest(c, "At least one of 'daily' or 'weekly' must be true") + return + } + sub, err := h.subscriptionService.AdminResetQuota(c.Request.Context(), subscriptionID, req.Daily, req.Weekly) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, dto.UserSubscriptionFromServiceAdmin(sub)) +} + // Revoke handles revoking a subscription // DELETE /api/v1/admin/subscriptions/:id func (h *SubscriptionHandler) Revoke(c *gin.Context) { diff --git a/backend/internal/handler/admin/usage_handler.go b/backend/internal/handler/admin/usage_handler.go index d0bba773..05fd00f1 100644 --- a/backend/internal/handler/admin/usage_handler.go +++ b/backend/internal/handler/admin/usage_handler.go @@ -61,6 +61,15 @@ type CreateUsageCleanupTaskRequest struct { // GET /api/v1/admin/usage func (h *UsageHandler) List(c *gin.Context) { page, pageSize := response.ParsePagination(c) + exactTotal := false + if exactTotalRaw := strings.TrimSpace(c.Query("exact_total")); exactTotalRaw != "" { + parsed, err := strconv.ParseBool(exactTotalRaw) + if err != nil { + response.BadRequest(c, "Invalid exact_total value, use true or false") + return + } + exactTotal = parsed + } // Parse filters var userID, apiKeyID, accountID, groupID int64 @@ -167,6 +176,7 @@ func (h *UsageHandler) List(c *gin.Context) { BillingType: billingType, StartTime: startTime, EndTime: endTime, + ExactTotal: exactTotal, } 
records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters) diff --git a/backend/internal/handler/admin/usage_handler_request_type_test.go b/backend/internal/handler/admin/usage_handler_request_type_test.go index 21add574..3f158316 100644 --- a/backend/internal/handler/admin/usage_handler_request_type_test.go +++ b/backend/internal/handler/admin/usage_handler_request_type_test.go @@ -80,6 +80,29 @@ func TestAdminUsageListInvalidStream(t *testing.T) { require.Equal(t, http.StatusBadRequest, rec.Code) } +func TestAdminUsageListExactTotalTrue(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage?exact_total=true", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.True(t, repo.listFilters.ExactTotal) +} + +func TestAdminUsageListInvalidExactTotal(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage?exact_total=oops", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + func TestAdminUsageStatsRequestTypePriority(t *testing.T) { repo := &adminUsageRepoCapture{} router := newAdminUsageRequestTypeTestRouter(repo) diff --git a/backend/internal/handler/admin/user_attribute_handler.go b/backend/internal/handler/admin/user_attribute_handler.go index 2f326279..3f84076e 100644 --- a/backend/internal/handler/admin/user_attribute_handler.go +++ b/backend/internal/handler/admin/user_attribute_handler.go @@ -1,7 +1,9 @@ package admin import ( + "encoding/json" "strconv" + "time" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/service" @@ -67,6 +69,8 @@ type BatchUserAttributesResponse struct { Attributes map[int64]map[int64]string 
`json:"attributes"` } +var userAttributesBatchCache = newSnapshotCache(30 * time.Second) + // AttributeDefinitionResponse represents attribute definition response type AttributeDefinitionResponse struct { ID int64 `json:"id"` @@ -327,16 +331,32 @@ func (h *UserAttributeHandler) GetBatchUserAttributes(c *gin.Context) { return } - if len(req.UserIDs) == 0 { + userIDs := normalizeInt64IDList(req.UserIDs) + if len(userIDs) == 0 { response.Success(c, BatchUserAttributesResponse{Attributes: map[int64]map[int64]string{}}) return } - attrs, err := h.attrService.GetBatchUserAttributes(c.Request.Context(), req.UserIDs) + keyRaw, _ := json.Marshal(struct { + UserIDs []int64 `json:"user_ids"` + }{ + UserIDs: userIDs, + }) + cacheKey := string(keyRaw) + if cached, ok := userAttributesBatchCache.Get(cacheKey); ok { + c.Header("X-Snapshot-Cache", "hit") + response.Success(c, cached.Payload) + return + } + + attrs, err := h.attrService.GetBatchUserAttributes(c.Request.Context(), userIDs) if err != nil { response.ErrorFrom(c, err) return } - response.Success(c, BatchUserAttributesResponse{Attributes: attrs}) + payload := BatchUserAttributesResponse{Attributes: attrs} + userAttributesBatchCache.Set(cacheKey, payload) + c.Header("X-Snapshot-Cache", "miss") + response.Success(c, payload) } diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index f85c060e..5a55ab14 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -91,6 +91,10 @@ func (h *UserHandler) List(c *gin.Context) { Search: search, Attributes: parseAttributeFilters(c), } + if raw, ok := c.GetQuery("include_subscriptions"); ok { + includeSubscriptions := parseBoolQueryWithDefault(raw, true) + filters.IncludeSubscriptions = &includeSubscriptions + } users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, filters) if err != nil { diff --git 
a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index 0ccf47e4..0c7c2da7 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -211,8 +211,22 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { email = linuxDoSyntheticEmail(subject) } - tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username) + // 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired + tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") if err != nil { + if errors.Is(err, service.ErrOAuthInvitationRequired) { + pendingToken, tokenErr := h.authService.CreatePendingOAuthToken(email, username) + if tokenErr != nil { + redirectOAuthError(c, frontendCallback, "login_failed", "service_error", "") + return + } + fragment := url.Values{} + fragment.Set("error", "invitation_required") + fragment.Set("pending_oauth_token", pendingToken) + fragment.Set("redirect", redirectTo) + redirectWithFragment(c, frontendCallback, fragment) + return + } // 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。 redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) return @@ -227,6 +241,41 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { redirectWithFragment(c, frontendCallback, fragment) } +type completeLinuxDoOAuthRequest struct { + PendingOAuthToken string `json:"pending_oauth_token" binding:"required"` + InvitationCode string `json:"invitation_code" binding:"required"` +} + +// CompleteLinuxDoOAuthRegistration completes a pending OAuth registration by validating +// the invitation code and creating the user account. 
+// POST /api/v1/auth/oauth/linuxdo/complete-registration +func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { + var req completeLinuxDoOAuthRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "INVALID_REQUEST", "message": err.Error()}) + return + } + + email, username, err := h.authService.VerifyPendingOAuthToken(req.PendingOAuthToken) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "INVALID_TOKEN", "message": "invalid or expired registration token"}) + return + } + + tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) + if err != nil { + response.ErrorFrom(c, err) + return + } + + c.JSON(http.StatusOK, gin.H{ + "access_token": tokenPair.AccessToken, + "refresh_token": tokenPair.RefreshToken, + "expires_in": tokenPair.ExpiresIn, + "token_type": "Bearer", + }) +} + func (h *AuthHandler) getLinuxDoOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) { if h != nil && h.settingSvc != nil { return h.settingSvc.GetLinuxDoConnectOAuthConfig(ctx) diff --git a/backend/internal/handler/dto/announcement.go b/backend/internal/handler/dto/announcement.go index bc0db1b2..16650b8e 100644 --- a/backend/internal/handler/dto/announcement.go +++ b/backend/internal/handler/dto/announcement.go @@ -7,10 +7,11 @@ import ( ) type Announcement struct { - ID int64 `json:"id"` - Title string `json:"title"` - Content string `json:"content"` - Status string `json:"status"` + ID int64 `json:"id"` + Title string `json:"title"` + Content string `json:"content"` + Status string `json:"status"` + NotifyMode string `json:"notify_mode"` Targeting service.AnnouncementTargeting `json:"targeting"` @@ -25,9 +26,10 @@ type Announcement struct { } type UserAnnouncement struct { - ID int64 `json:"id"` - Title string `json:"title"` - Content string `json:"content"` + ID int64 `json:"id"` + Title string `json:"title"` + Content 
string `json:"content"` + NotifyMode string `json:"notify_mode"` StartsAt *time.Time `json:"starts_at,omitempty"` EndsAt *time.Time `json:"ends_at,omitempty"` @@ -43,17 +45,18 @@ func AnnouncementFromService(a *service.Announcement) *Announcement { return nil } return &Announcement{ - ID: a.ID, - Title: a.Title, - Content: a.Content, - Status: a.Status, - Targeting: a.Targeting, - StartsAt: a.StartsAt, - EndsAt: a.EndsAt, - CreatedBy: a.CreatedBy, - UpdatedBy: a.UpdatedBy, - CreatedAt: a.CreatedAt, - UpdatedAt: a.UpdatedAt, + ID: a.ID, + Title: a.Title, + Content: a.Content, + Status: a.Status, + NotifyMode: a.NotifyMode, + Targeting: a.Targeting, + StartsAt: a.StartsAt, + EndsAt: a.EndsAt, + CreatedBy: a.CreatedBy, + UpdatedBy: a.UpdatedBy, + CreatedAt: a.CreatedAt, + UpdatedAt: a.UpdatedAt, } } @@ -62,13 +65,14 @@ func UserAnnouncementFromService(a *service.UserAnnouncement) *UserAnnouncement return nil } return &UserAnnouncement{ - ID: a.Announcement.ID, - Title: a.Announcement.Title, - Content: a.Announcement.Content, - StartsAt: a.Announcement.StartsAt, - EndsAt: a.Announcement.EndsAt, - ReadAt: a.ReadAt, - CreatedAt: a.Announcement.CreatedAt, - UpdatedAt: a.Announcement.UpdatedAt, + ID: a.Announcement.ID, + Title: a.Announcement.Title, + Content: a.Announcement.Content, + NotifyMode: a.Announcement.NotifyMode, + StartsAt: a.Announcement.StartsAt, + EndsAt: a.Announcement.EndsAt, + ReadAt: a.ReadAt, + CreatedAt: a.Announcement.CreatedAt, + UpdatedAt: a.Announcement.UpdatedAt, } } diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index fe2a1d77..402bf095 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -71,7 +71,7 @@ func APIKeyFromService(k *service.APIKey) *APIKey { if k == nil { return nil } - return &APIKey{ + out := &APIKey{ ID: k.ID, UserID: k.UserID, Key: k.Key, @@ -89,15 +89,28 @@ func APIKeyFromService(k *service.APIKey) *APIKey { RateLimit5h: 
k.RateLimit5h, RateLimit1d: k.RateLimit1d, RateLimit7d: k.RateLimit7d, - Usage5h: k.Usage5h, - Usage1d: k.Usage1d, - Usage7d: k.Usage7d, + Usage5h: k.EffectiveUsage5h(), + Usage1d: k.EffectiveUsage1d(), + Usage7d: k.EffectiveUsage7d(), Window5hStart: k.Window5hStart, Window1dStart: k.Window1dStart, Window7dStart: k.Window7dStart, User: UserFromServiceShallow(k.User), Group: GroupFromServiceShallow(k.Group), } + if k.Window5hStart != nil && !service.IsWindowExpired(k.Window5hStart, service.RateLimitWindow5h) { + t := k.Window5hStart.Add(service.RateLimitWindow5h) + out.Reset5hAt = &t + } + if k.Window1dStart != nil && !service.IsWindowExpired(k.Window1dStart, service.RateLimitWindow1d) { + t := k.Window1dStart.Add(service.RateLimitWindow1d) + out.Reset1dAt = &t + } + if k.Window7dStart != nil && !service.IsWindowExpired(k.Window7dStart, service.RateLimitWindow7d) { + t := k.Window7dStart.Add(service.RateLimitWindow7d) + out.Reset7dAt = &t + } + return out } func GroupFromServiceShallow(g *service.Group) *Group { @@ -122,13 +135,15 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup { return nil } out := &AdminGroup{ - Group: groupFromServiceBase(g), - ModelRouting: g.ModelRouting, - ModelRoutingEnabled: g.ModelRoutingEnabled, - MCPXMLInject: g.MCPXMLInject, - SupportedModelScopes: g.SupportedModelScopes, - AccountCount: g.AccountCount, - SortOrder: g.SortOrder, + Group: groupFromServiceBase(g), + ModelRouting: g.ModelRouting, + ModelRoutingEnabled: g.ModelRoutingEnabled, + MCPXMLInject: g.MCPXMLInject, + DefaultMappedModel: g.DefaultMappedModel, + SimulateClaudeMaxEnabled: g.SimulateClaudeMaxEnabled, + SupportedModelScopes: g.SupportedModelScopes, + AccountCount: g.AccountCount, + SortOrder: g.SortOrder, } if len(g.AccountGroups) > 0 { out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups)) @@ -164,6 +179,7 @@ func groupFromServiceBase(g *service.Group) Group { FallbackGroupID: g.FallbackGroupID, FallbackGroupIDOnInvalidRequest: 
g.FallbackGroupIDOnInvalidRequest, SoraStorageQuotaBytes: g.SoraStorageQuotaBytes, + AllowMessagesDispatch: g.AllowMessagesDispatch, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, } @@ -183,6 +199,7 @@ func AccountFromServiceShallow(a *service.Account) *Account { Extra: a.Extra, ProxyID: a.ProxyID, Concurrency: a.Concurrency, + LoadFactor: a.LoadFactor, Priority: a.Priority, RateMultiplier: a.BillingRateMultiplier(), Status: a.Status, @@ -248,6 +265,25 @@ func AccountFromServiceShallow(a *service.Account) *Account { } } + // 提取 API Key 账号配额限制(仅 apikey 类型有效) + if a.Type == service.AccountTypeAPIKey { + if limit := a.GetQuotaLimit(); limit > 0 { + out.QuotaLimit = &limit + used := a.GetQuotaUsed() + out.QuotaUsed = &used + } + if limit := a.GetQuotaDailyLimit(); limit > 0 { + out.QuotaDailyLimit = &limit + used := a.GetQuotaDailyUsed() + out.QuotaDailyUsed = &used + } + if limit := a.GetQuotaWeeklyLimit(); limit > 0 { + out.QuotaWeeklyLimit = &limit + used := a.GetQuotaWeeklyUsed() + out.QuotaWeeklyUsed = &used + } + } + return out } @@ -461,6 +497,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog { AccountID: l.AccountID, RequestID: l.RequestID, Model: l.Model, + ServiceTier: l.ServiceTier, ReasoningEffort: l.ReasoningEffort, GroupID: l.GroupID, SubscriptionID: l.SubscriptionID, diff --git a/backend/internal/handler/dto/mappers_usage_test.go b/backend/internal/handler/dto/mappers_usage_test.go index d716bdc4..ea408ecb 100644 --- a/backend/internal/handler/dto/mappers_usage_test.go +++ b/backend/internal/handler/dto/mappers_usage_test.go @@ -71,3 +71,29 @@ func TestRequestTypeStringPtrNil(t *testing.T) { t.Parallel() require.Nil(t, requestTypeStringPtr(nil)) } + +func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) { + t.Parallel() + + serviceTier := "priority" + log := &service.UsageLog{ + RequestID: "req_3", + Model: "gpt-5.4", + ServiceTier: &serviceTier, + AccountRateMultiplier: f64Ptr(1.5), + } + + userDTO := 
UsageLogFromService(log) + adminDTO := UsageLogFromServiceAdmin(log) + + require.NotNil(t, userDTO.ServiceTier) + require.Equal(t, serviceTier, *userDTO.ServiceTier) + require.NotNil(t, adminDTO.ServiceTier) + require.Equal(t, serviceTier, *adminDTO.ServiceTier) + require.NotNil(t, adminDTO.AccountRateMultiplier) + require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12) +} + +func f64Ptr(value float64) *float64 { + return &value +} diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 1e20c9a2..8a1bba5d 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -17,13 +17,14 @@ type CustomMenuItem struct { // SystemSettings represents the admin settings API response payload. type SystemSettings struct { - RegistrationEnabled bool `json:"registration_enabled"` - EmailVerifyEnabled bool `json:"email_verify_enabled"` - PromoCodeEnabled bool `json:"promo_code_enabled"` - PasswordResetEnabled bool `json:"password_reset_enabled"` - InvitationCodeEnabled bool `json:"invitation_code_enabled"` - TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 - TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置 + RegistrationEnabled bool `json:"registration_enabled"` + EmailVerifyEnabled bool `json:"email_verify_enabled"` + RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"` + PromoCodeEnabled bool `json:"promo_code_enabled"` + PasswordResetEnabled bool `json:"password_reset_enabled"` + InvitationCodeEnabled bool `json:"invitation_code_enabled"` + TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 + TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置 SMTPHost string `json:"smtp_host"` SMTPPort int `json:"smtp_port"` @@ -88,28 +89,29 @@ type DefaultSubscriptionSetting struct { } type PublicSettings struct { - RegistrationEnabled bool `json:"registration_enabled"` - 
EmailVerifyEnabled bool `json:"email_verify_enabled"` - PromoCodeEnabled bool `json:"promo_code_enabled"` - PasswordResetEnabled bool `json:"password_reset_enabled"` - InvitationCodeEnabled bool `json:"invitation_code_enabled"` - TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 - TurnstileEnabled bool `json:"turnstile_enabled"` - TurnstileSiteKey string `json:"turnstile_site_key"` - SiteName string `json:"site_name"` - SiteLogo string `json:"site_logo"` - SiteSubtitle string `json:"site_subtitle"` - APIBaseURL string `json:"api_base_url"` - ContactInfo string `json:"contact_info"` - DocURL string `json:"doc_url"` - HomeContent string `json:"home_content"` - HideCcsImportButton bool `json:"hide_ccs_import_button"` - PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` - PurchaseSubscriptionURL string `json:"purchase_subscription_url"` - CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` - LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` - SoraClientEnabled bool `json:"sora_client_enabled"` - Version string `json:"version"` + RegistrationEnabled bool `json:"registration_enabled"` + EmailVerifyEnabled bool `json:"email_verify_enabled"` + RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"` + PromoCodeEnabled bool `json:"promo_code_enabled"` + PasswordResetEnabled bool `json:"password_reset_enabled"` + InvitationCodeEnabled bool `json:"invitation_code_enabled"` + TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 + TurnstileEnabled bool `json:"turnstile_enabled"` + TurnstileSiteKey string `json:"turnstile_site_key"` + SiteName string `json:"site_name"` + SiteLogo string `json:"site_logo"` + SiteSubtitle string `json:"site_subtitle"` + APIBaseURL string `json:"api_base_url"` + ContactInfo string `json:"contact_info"` + DocURL string `json:"doc_url"` + HomeContent string `json:"home_content"` + HideCcsImportButton bool `json:"hide_ccs_import_button"` + PurchaseSubscriptionEnabled bool 
`json:"purchase_subscription_enabled"` + PurchaseSubscriptionURL string `json:"purchase_subscription_url"` + CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` + LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` + SoraClientEnabled bool `json:"sora_client_enabled"` + Version string `json:"version"` } // SoraS3Settings Sora S3 存储配置 DTO(响应用,不含敏感字段) @@ -159,6 +161,26 @@ type StreamTimeoutSettings struct { ThresholdWindowMinutes int `json:"threshold_window_minutes"` } +// RectifierSettings 请求整流器配置 DTO +type RectifierSettings struct { + Enabled bool `json:"enabled"` + ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"` + ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"` +} + +// BetaPolicyRule Beta 策略规则 DTO +type BetaPolicyRule struct { + BetaToken string `json:"beta_token"` + Action string `json:"action"` + Scope string `json:"scope"` + ErrorMessage string `json:"error_message,omitempty"` +} + +// BetaPolicySettings Beta 策略配置 DTO +type BetaPolicySettings struct { + Rules []BetaPolicyRule `json:"rules"` +} + // ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem. // Returns empty slice on empty/invalid input. 
func ParseCustomMenuItems(raw string) []CustomMenuItem { diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 920615f7..7f1788a1 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -57,6 +57,9 @@ type APIKey struct { Window5hStart *time.Time `json:"window_5h_start"` Window1dStart *time.Time `json:"window_1d_start"` Window7dStart *time.Time `json:"window_7d_start"` + Reset5hAt *time.Time `json:"reset_5h_at,omitempty"` + Reset1dAt *time.Time `json:"reset_1d_at,omitempty"` + Reset7dAt *time.Time `json:"reset_7d_at,omitempty"` User *User `json:"user,omitempty"` Group *Group `json:"group,omitempty"` @@ -96,6 +99,9 @@ type Group struct { // Sora 存储配额 SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"` + // OpenAI Messages 调度开关(用户侧需要此字段判断是否展示 Claude Code 教程) + AllowMessagesDispatch bool `json:"allow_messages_dispatch"` + CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } @@ -111,6 +117,11 @@ type AdminGroup struct { // MCP XML 协议注入(仅 antigravity 平台使用) MCPXMLInject bool `json:"mcp_xml_inject"` + // Claude usage 模拟开关(仅管理员可见) + SimulateClaudeMaxEnabled bool `json:"simulate_claude_max_enabled"` + + // OpenAI Messages 调度配置(仅 openai 平台使用) + DefaultMappedModel string `json:"default_mapped_model"` // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes []string `json:"supported_model_scopes"` @@ -131,6 +142,7 @@ type Account struct { Extra map[string]any `json:"extra"` ProxyID *int64 `json:"proxy_id"` Concurrency int `json:"concurrency"` + LoadFactor *int `json:"load_factor,omitempty"` Priority int `json:"priority"` RateMultiplier float64 `json:"rate_multiplier"` Status string `json:"status"` @@ -185,6 +197,14 @@ type Account struct { CacheTTLOverrideEnabled *bool `json:"cache_ttl_override_enabled,omitempty"` CacheTTLOverrideTarget *string `json:"cache_ttl_override_target,omitempty"` + // API Key 账号配额限制 + QuotaLimit *float64 `json:"quota_limit,omitempty"` 
+ QuotaUsed *float64 `json:"quota_used,omitempty"` + QuotaDailyLimit *float64 `json:"quota_daily_limit,omitempty"` + QuotaDailyUsed *float64 `json:"quota_daily_used,omitempty"` + QuotaWeeklyLimit *float64 `json:"quota_weekly_limit,omitempty"` + QuotaWeeklyUsed *float64 `json:"quota_weekly_used,omitempty"` + Proxy *Proxy `json:"proxy,omitempty"` AccountGroups []AccountGroup `json:"account_groups,omitempty"` @@ -304,6 +324,8 @@ type UsageLog struct { AccountID int64 `json:"account_id"` RequestID string `json:"request_id"` Model string `json:"model"` + // ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex". + ServiceTier *string `json:"service_tier,omitempty"` // ReasoningEffort is the request's reasoning effort level (OpenAI Responses API). // nil means not provided / not applicable. ReasoningEffort *string `json:"reasoning_effort,omitempty"` diff --git a/backend/internal/handler/failover_loop.go b/backend/internal/handler/failover_loop.go index b2583301..6d8ddc72 100644 --- a/backend/internal/handler/failover_loop.go +++ b/backend/internal/handler/failover_loop.go @@ -30,7 +30,7 @@ const ( const ( // maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误) - maxSameAccountRetries = 2 + maxSameAccountRetries = 3 // sameAccountRetryDelay 同账号重试间隔 sameAccountRetryDelay = 500 * time.Millisecond // singleAccountBackoffDelay 单账号分组 503 退避重试固定延时。 diff --git a/backend/internal/handler/failover_loop_test.go b/backend/internal/handler/failover_loop_test.go index 5a41b2dd..2c65ebc2 100644 --- a/backend/internal/handler/failover_loop_test.go +++ b/backend/internal/handler/failover_loop_test.go @@ -291,35 +291,31 @@ func TestHandleFailoverError_SameAccountRetry(t *testing.T) { require.Less(t, elapsed, 2*time.Second) }) - t.Run("第二次重试仍返回FailoverContinue", func(t *testing.T) { + t.Run("达到最大重试次数前均返回FailoverContinue", func(t *testing.T) { mock := &mockTempUnscheduler{} fs := NewFailoverState(3, false) err := newTestFailoverErr(400, true, false) 
- // 第一次 - action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) - require.Equal(t, FailoverContinue, action) - require.Equal(t, 1, fs.SameAccountRetryCount[100]) + for i := 1; i <= maxSameAccountRetries; i++ { + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, FailoverContinue, action) + require.Equal(t, i, fs.SameAccountRetryCount[100]) + } - // 第二次 - action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) - require.Equal(t, FailoverContinue, action) - require.Equal(t, 2, fs.SameAccountRetryCount[100]) - - require.Empty(t, mock.calls, "两次重试期间均不应调用 TempUnschedule") + require.Empty(t, mock.calls, "达到最大重试次数前均不应调用 TempUnschedule") }) - t.Run("第三次重试耗尽_触发TempUnschedule并切换", func(t *testing.T) { + t.Run("超过最大重试次数后触发TempUnschedule并切换", func(t *testing.T) { mock := &mockTempUnscheduler{} fs := NewFailoverState(3, false) err := newTestFailoverErr(400, true, false) - // 第一次、第二次重试 - fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) - fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) - require.Equal(t, 2, fs.SameAccountRetryCount[100]) + for i := 0; i < maxSameAccountRetries; i++ { + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + } + require.Equal(t, maxSameAccountRetries, fs.SameAccountRetryCount[100]) - // 第三次:重试已达到 maxSameAccountRetries(2),应切换账号 + // 第 maxSameAccountRetries+1 次:重试耗尽,应切换账号 action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) require.Equal(t, FailoverContinue, action) require.Equal(t, 1, fs.SwitchCount) @@ -354,13 +350,14 @@ func TestHandleFailoverError_SameAccountRetry(t *testing.T) { err := newTestFailoverErr(400, true, false) // 耗尽账号 100 的重试 - fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) - fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) - // 第三次: 重试耗尽 → 切换 + for i := 0; i < maxSameAccountRetries; i++ { + 
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + } + // 第 maxSameAccountRetries+1 次: 重试耗尽 → 切换 action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) require.Equal(t, FailoverContinue, action) - // 再次遇到账号 100,计数仍为 2,条件不满足 → 直接切换 + // 再次遇到账号 100,计数仍为 maxSameAccountRetries,条件不满足 → 直接切换 action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) require.Equal(t, FailoverContinue, action) require.Len(t, mock.calls, 2, "第二次耗尽也应调用 TempUnschedule") @@ -386,9 +383,10 @@ func TestHandleFailoverError_TempUnschedule(t *testing.T) { fs := NewFailoverState(3, false) err := newTestFailoverErr(502, true, false) - // 耗尽重试 - fs.HandleFailoverError(context.Background(), mock, 42, "openai", err) - fs.HandleFailoverError(context.Background(), mock, 42, "openai", err) + for i := 0; i < maxSameAccountRetries; i++ { + fs.HandleFailoverError(context.Background(), mock, 42, "openai", err) + } + // 再次触发时才会执行 TempUnschedule + 切换 fs.HandleFailoverError(context.Background(), mock, 42, "openai", err) require.Len(t, mock.calls, 1) @@ -521,17 +519,16 @@ func TestHandleFailoverError_IntegrationScenario(t *testing.T) { mock := &mockTempUnscheduler{} fs := NewFailoverState(3, true) // hasBoundSession=true - // 1. 账号 100 遇到可重试错误,同账号重试 2 次 + // 1. 账号 100 遇到可重试错误,同账号重试 maxSameAccountRetries 次 retryErr := newTestFailoverErr(400, true, false) - action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr) - require.Equal(t, FailoverContinue, action) + for i := 0; i < maxSameAccountRetries; i++ { + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr) + require.Equal(t, FailoverContinue, action) + } require.True(t, fs.ForceCacheBilling, "hasBoundSession=true 应设置 ForceCacheBilling") - action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr) - require.Equal(t, FailoverContinue, action) - - // 2. 
账号 100 重试耗尽 → TempUnschedule + 切换 - action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr) + // 2. 账号 100 超过重试上限 → TempUnschedule + 切换 + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr) require.Equal(t, FailoverContinue, action) require.Equal(t, 1, fs.SwitchCount) require.Len(t, mock.calls, 1) diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 1c0ef8e6..743624a2 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -439,6 +439,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ Result: result, + ParsedRequest: parsedReq, APIKey: apiKey, User: apiKey.User, Account: account, @@ -630,6 +631,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // ===== 用户消息串行队列 END ===== // 转发请求 - 根据账号平台分流 + c.Set("parsed_request", parsedReq) var result *service.ForwardResult requestCtx := c.Request.Context() if fs.SwitchCount > 0 { @@ -652,6 +654,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) { accountReleaseFunc() } if err != nil { + // Beta policy block: return 400 immediately, no failover + var betaBlockedErr *service.BetaBlockedError + if errors.As(err, &betaBlockedErr) { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", betaBlockedErr.Message) + return + } + var promptTooLongErr *service.PromptTooLongError if errors.As(err, &promptTooLongErr) { reqLog.Warn("gateway.prompt_too_long_from_antigravity", @@ -734,6 +743,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ Result: result, + ParsedRequest: parsedReq, APIKey: currentAPIKey, User: currentAPIKey.User, Account: account, @@ -971,34 +981,46 @@ func (h *GatewayHandler) 
usageQuotaLimited(c *gin.Context, ctx context.Context, if err == nil && rateLimitData != nil { var rateLimits []gin.H if apiKey.RateLimit5h > 0 { - used := rateLimitData.Usage5h - rateLimits = append(rateLimits, gin.H{ + used := rateLimitData.EffectiveUsage5h() + entry := gin.H{ "window": "5h", "limit": apiKey.RateLimit5h, "used": used, "remaining": max(0, apiKey.RateLimit5h-used), "window_start": rateLimitData.Window5hStart, - }) + } + if rateLimitData.Window5hStart != nil && !service.IsWindowExpired(rateLimitData.Window5hStart, service.RateLimitWindow5h) { + entry["reset_at"] = rateLimitData.Window5hStart.Add(service.RateLimitWindow5h) + } + rateLimits = append(rateLimits, entry) } if apiKey.RateLimit1d > 0 { - used := rateLimitData.Usage1d - rateLimits = append(rateLimits, gin.H{ + used := rateLimitData.EffectiveUsage1d() + entry := gin.H{ "window": "1d", "limit": apiKey.RateLimit1d, "used": used, "remaining": max(0, apiKey.RateLimit1d-used), "window_start": rateLimitData.Window1dStart, - }) + } + if rateLimitData.Window1dStart != nil && !service.IsWindowExpired(rateLimitData.Window1dStart, service.RateLimitWindow1d) { + entry["reset_at"] = rateLimitData.Window1dStart.Add(service.RateLimitWindow1d) + } + rateLimits = append(rateLimits, entry) } if apiKey.RateLimit7d > 0 { - used := rateLimitData.Usage7d - rateLimits = append(rateLimits, gin.H{ + used := rateLimitData.EffectiveUsage7d() + entry := gin.H{ "window": "7d", "limit": apiKey.RateLimit7d, "used": used, "remaining": max(0, apiKey.RateLimit7d-used), "window_start": rateLimitData.Window7dStart, - }) + } + if rateLimitData.Window7dStart != nil && !service.IsWindowExpired(rateLimitData.Window7dStart, service.RateLimitWindow7d) { + entry["reset_at"] = rateLimitData.Window7dStart.Add(service.RateLimitWindow7d) + } + rateLimits = append(rateLimits, entry) } if len(rateLimits) > 0 { resp["rate_limits"] = rateLimits diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go 
b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go index c07c568d..0c94d50b 100644 --- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go +++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go @@ -127,6 +127,7 @@ func (f *fakeConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, acc return result, nil } func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil } +func (f *fakeConcurrencyCache) CleanupStaleProcessSlots(context.Context, string) error { return nil } func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) { t.Helper() @@ -155,6 +156,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi nil, // sessionLimitCache nil, // rpmCache nil, // digestStore + nil, // settingService ) // RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。 diff --git a/backend/internal/handler/gateway_helper_fastpath_test.go b/backend/internal/handler/gateway_helper_fastpath_test.go index 31d489f0..c7c0fb6c 100644 --- a/backend/internal/handler/gateway_helper_fastpath_test.go +++ b/backend/internal/handler/gateway_helper_fastpath_test.go @@ -89,6 +89,10 @@ func (m *concurrencyCacheMock) CleanupExpiredAccountSlots(ctx context.Context, a return nil } +func (m *concurrencyCacheMock) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error { + return nil +} + func TestConcurrencyHelper_TryAcquireUserSlot(t *testing.T) { cache := &concurrencyCacheMock{ acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { diff --git a/backend/internal/handler/gateway_helper_hotpath_test.go b/backend/internal/handler/gateway_helper_hotpath_test.go index f8f7eaca..9e904107 100644 --- a/backend/internal/handler/gateway_helper_hotpath_test.go +++ b/backend/internal/handler/gateway_helper_hotpath_test.go @@ -120,6 +120,10 @@ 
func (s *helperConcurrencyCacheStub) CleanupExpiredAccountSlots(ctx context.Cont return nil } +func (s *helperConcurrencyCacheStub) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error { + return nil +} + func newHelperTestContext(method, path string) (*gin.Context, *httptest.ResponseRecorder) { gin.SetMode(gin.TestMode) rec := httptest.NewRecorder() diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index 1e1247fc..3f1d73ca 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -27,6 +27,7 @@ type AdminHandlers struct { UserAttribute *admin.UserAttributeHandler ErrorPassthrough *admin.ErrorPassthroughHandler APIKey *admin.AdminAPIKeyHandler + ScheduledTest *admin.ScheduledTestHandler } // Handlers contains all HTTP handlers diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go new file mode 100644 index 00000000..6900e7cd --- /dev/null +++ b/backend/internal/handler/openai_chat_completions.go @@ -0,0 +1,290 @@ +package handler + +import ( + "context" + "errors" + "net/http" + "time" + + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "go.uber.org/zap" +) + +// ChatCompletions handles OpenAI Chat Completions API requests. 
+// POST /v1/chat/completions +func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { + streamStarted := false + defer h.recoverResponsesPanic(c, &streamStarted) + + requestStart := time.Now() + + apiKey, ok := middleware2.GetAPIKeyFromContext(c) + if !ok { + h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") + return + } + + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") + return + } + reqLog := requestLogger( + c, + "handler.openai_gateway.chat_completions", + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + ) + + if !h.ensureResponsesDependencies(c, reqLog) { + return + } + + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) + if err != nil { + if maxErr, ok := extractMaxBytesError(err); ok { + h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) + return + } + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") + return + } + if len(body) == 0 { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") + return + } + + if !gjson.ValidBytes(body) { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") + return + } + + modelResult := gjson.GetBytes(body, "model") + if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") + return + } + reqModel := modelResult.String() + reqStream := gjson.GetBytes(body, "stream").Bool() + + reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) + + setOpsRequestContext(c, reqModel, reqStream, body) + + if h.errorPassthroughService != nil { + 
service.BindErrorPassthroughService(c, h.errorPassthroughService) + } + + subscription, _ := middleware2.GetSubscriptionFromContext(c) + + service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds()) + routingStart := time.Now() + + userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog) + if !acquired { + return + } + if userReleaseFunc != nil { + defer userReleaseFunc() + } + + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err)) + status, code, message := billingErrorDetails(err) + h.handleStreamingAwareError(c, status, code, message, streamStarted) + return + } + + sessionHash := h.gatewayService.GenerateSessionHash(c, body) + promptCacheKey := h.gatewayService.ExtractSessionID(c, body) + + maxAccountSwitches := h.maxAccountSwitches + switchCount := 0 + failedAccountIDs := make(map[int64]struct{}) + sameAccountRetryCount := make(map[int64]int) + var lastFailoverErr *service.UpstreamFailoverError + + for { + c.Set("openai_chat_completions_fallback_model", "") + reqLog.Debug("openai_chat_completions.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs))) + selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( + c.Request.Context(), + apiKey.GroupID, + "", + sessionHash, + reqModel, + failedAccountIDs, + service.OpenAIUpstreamTransportAny, + ) + if err != nil { + reqLog.Warn("openai_chat_completions.account_select_failed", + zap.Error(err), + zap.Int("excluded_account_count", len(failedAccountIDs)), + ) + if len(failedAccountIDs) == 0 { + defaultModel := "" + if apiKey.Group != nil { + defaultModel = apiKey.Group.DefaultMappedModel + } + if defaultModel != "" && defaultModel != reqModel { + 
reqLog.Info("openai_chat_completions.fallback_to_default_model", + zap.String("default_mapped_model", defaultModel), + ) + selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler( + c.Request.Context(), + apiKey.GroupID, + "", + sessionHash, + defaultModel, + failedAccountIDs, + service.OpenAIUpstreamTransportAny, + ) + if err == nil && selection != nil { + c.Set("openai_chat_completions_fallback_model", defaultModel) + } + } + if err != nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted) + return + } + } else { + if lastFailoverErr != nil { + h.handleFailoverExhausted(c, lastFailoverErr, streamStarted) + } else { + h.handleStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted) + } + return + } + } + if selection == nil || selection.Account == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) + return + } + account := selection.Account + sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account) + reqLog.Debug("openai_chat_completions.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name)) + _ = scheduleDecision + setOpsSelectedAccount(c, account.ID, account.Platform) + + accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog) + if !acquired { + return + } + + service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) + forwardStart := time.Now() + + defaultMappedModel := "" + if apiKey.Group != nil { + defaultMappedModel = apiKey.Group.DefaultMappedModel + } + if fallbackModel := c.GetString("openai_chat_completions_fallback_model"); fallbackModel != "" { + defaultMappedModel = fallbackModel + } + result, err := 
h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel) + + forwardDurationMs := time.Since(forwardStart).Milliseconds() + if accountReleaseFunc != nil { + accountReleaseFunc() + } + upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey) + responseLatencyMs := forwardDurationMs + if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs { + responseLatencyMs = forwardDurationMs - upstreamLatencyMs + } + service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs) + if err == nil && result != nil && result.FirstTokenMs != nil { + service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs)) + } + if err != nil { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + // Pool mode: retry on the same account + if failoverErr.RetryableOnSameAccount { + retryLimit := account.GetPoolModeRetryCount() + if sameAccountRetryCount[account.ID] < retryLimit { + sameAccountRetryCount[account.ID]++ + reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("retry_limit", retryLimit), + zap.Int("retry_count", sameAccountRetryCount[account.ID]), + ) + select { + case <-c.Request.Context().Done(): + return + case <-time.After(sameAccountRetryDelay): + } + continue + } + } + h.gatewayService.RecordOpenAIAccountSwitch() + failedAccountIDs[account.ID] = struct{}{} + lastFailoverErr = failoverErr + if switchCount >= maxAccountSwitches { + h.handleFailoverExhausted(c, failoverErr, streamStarted) + return + } + switchCount++ + reqLog.Warn("openai_chat_completions.upstream_failover_switching", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("switch_count", switchCount), + 
zap.Int("max_switches", maxAccountSwitches), + ) + continue + } + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) + reqLog.Warn("openai_chat_completions.forward_failed", + zap.Int64("account_id", account.ID), + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Error(err), + ) + return + } + if result != nil { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) + } else { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil) + } + + userAgent := c.GetHeader("User-Agent") + clientIP := ip.GetClientIP(c) + + h.submitUsageRecordTask(func(ctx context.Context) { + if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + UserAgent: userAgent, + IPAddress: clientIP, + APIKeyService: h.apiKeyService, + }); err != nil { + logger.L().With( + zap.String("component", "handler.openai_gateway.chat_completions"), + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + zap.String("model", reqModel), + zap.Int64("account_id", account.ID), + ).Error("openai_chat_completions.record_usage_failed", zap.Error(err)) + } + }) + reqLog.Debug("openai_chat_completions.request_completed", + zap.Int64("account_id", account.ID), + zap.Int("switch_count", switchCount), + ) + return + } +} diff --git a/backend/internal/handler/openai_gateway_compact_log_test.go b/backend/internal/handler/openai_gateway_compact_log_test.go new file mode 100644 index 00000000..062f318b --- /dev/null +++ b/backend/internal/handler/openai_gateway_compact_log_test.go @@ -0,0 +1,192 @@ +package handler + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + 
"github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +var handlerStructuredLogCaptureMu sync.Mutex + +type handlerInMemoryLogSink struct { + mu sync.Mutex + events []*logger.LogEvent +} + +func (s *handlerInMemoryLogSink) WriteLogEvent(event *logger.LogEvent) { + if event == nil { + return + } + cloned := *event + if event.Fields != nil { + cloned.Fields = make(map[string]any, len(event.Fields)) + for k, v := range event.Fields { + cloned.Fields[k] = v + } + } + s.mu.Lock() + s.events = append(s.events, &cloned) + s.mu.Unlock() +} + +func (s *handlerInMemoryLogSink) ContainsMessageAtLevel(substr, level string) bool { + s.mu.Lock() + defer s.mu.Unlock() + wantLevel := strings.ToLower(strings.TrimSpace(level)) + for _, ev := range s.events { + if ev == nil { + continue + } + if strings.Contains(ev.Message, substr) && strings.ToLower(strings.TrimSpace(ev.Level)) == wantLevel { + return true + } + } + return false +} + +func (s *handlerInMemoryLogSink) ContainsFieldValue(field, substr string) bool { + s.mu.Lock() + defer s.mu.Unlock() + for _, ev := range s.events { + if ev == nil || ev.Fields == nil { + continue + } + if v, ok := ev.Fields[field]; ok && strings.Contains(fmt.Sprint(v), substr) { + return true + } + } + return false +} + +func captureHandlerStructuredLog(t *testing.T) (*handlerInMemoryLogSink, func()) { + t.Helper() + handlerStructuredLogCaptureMu.Lock() + + err := logger.Init(logger.InitOptions{ + Level: "debug", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: logger.OutputOptions{ + ToStdout: true, + ToFile: false, + }, + Sampling: logger.SamplingOptions{Enabled: false}, + }) + require.NoError(t, err) + + sink := &handlerInMemoryLogSink{} + logger.SetSink(sink) + return sink, func() { + logger.SetSink(nil) + handlerStructuredLogCaptureMu.Unlock() + } +} + +func TestIsOpenAIRemoteCompactPath(t *testing.T) { + require.False(t, isOpenAIRemoteCompactPath(nil)) + + gin.SetMode(gin.TestMode) + rec := 
httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", nil) + require.True(t, isOpenAIRemoteCompactPath(c)) + + c.Request = httptest.NewRequest(http.MethodPost, "/responses/compact/", nil) + require.True(t, isOpenAIRemoteCompactPath(c)) + + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + require.False(t, isOpenAIRemoteCompactPath(c)) +} + +func TestLogOpenAIRemoteCompactOutcome_Succeeded(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureHandlerStructuredLog(t) + defer restore() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", nil) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0") + c.Set(opsModelKey, "gpt-5.3-codex") + c.Set(opsAccountIDKey, int64(123)) + c.Header("x-request-id", "rid-compact-ok") + c.Status(http.StatusOK) + + h := &OpenAIGatewayHandler{} + h.logOpenAIRemoteCompactOutcome(c, time.Now().Add(-8*time.Millisecond)) + + require.True(t, logSink.ContainsMessageAtLevel("codex.remote_compact.succeeded", "info")) + require.True(t, logSink.ContainsFieldValue("compact_outcome", "succeeded")) + require.True(t, logSink.ContainsFieldValue("status_code", "200")) + require.True(t, logSink.ContainsFieldValue("path", "/v1/responses/compact")) + require.True(t, logSink.ContainsFieldValue("request_model", "gpt-5.3-codex")) + require.True(t, logSink.ContainsFieldValue("account_id", "123")) + require.True(t, logSink.ContainsFieldValue("upstream_request_id", "rid-compact-ok")) +} + +func TestLogOpenAIRemoteCompactOutcome_Failed(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureHandlerStructuredLog(t) + defer restore() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/responses/compact", nil) + c.Request.Header.Set("User-Agent", 
"codex_cli_rs/0.104.0") + c.Status(http.StatusBadGateway) + + h := &OpenAIGatewayHandler{} + h.logOpenAIRemoteCompactOutcome(c, time.Now()) + + require.True(t, logSink.ContainsMessageAtLevel("codex.remote_compact.failed", "warn")) + require.True(t, logSink.ContainsFieldValue("compact_outcome", "failed")) + require.True(t, logSink.ContainsFieldValue("status_code", "502")) + require.True(t, logSink.ContainsFieldValue("path", "/responses/compact")) +} + +func TestLogOpenAIRemoteCompactOutcome_NonCompactSkips(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureHandlerStructuredLog(t) + defer restore() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + c.Status(http.StatusOK) + + h := &OpenAIGatewayHandler{} + h.logOpenAIRemoteCompactOutcome(c, time.Now()) + + require.False(t, logSink.ContainsMessageAtLevel("codex.remote_compact.succeeded", "info")) + require.False(t, logSink.ContainsMessageAtLevel("codex.remote_compact.failed", "warn")) +} + +func TestOpenAIResponses_CompactUnauthorizedLogsFailed(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureHandlerStructuredLog(t) + defer restore() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", strings.NewReader(`{"model":"gpt-5.3-codex"}`)) + c.Request.Header.Set("Content-Type", "application/json") + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0") + + h := &OpenAIGatewayHandler{} + h.Responses(c) + + require.Equal(t, http.StatusUnauthorized, rec.Code) + require.True(t, logSink.ContainsMessageAtLevel("codex.remote_compact.failed", "warn")) + require.True(t, logSink.ContainsFieldValue("status_code", "401")) + require.True(t, logSink.ContainsFieldValue("path", "/v1/responses/compact")) +} diff --git a/backend/internal/handler/openai_gateway_handler.go 
b/backend/internal/handler/openai_gateway_handler.go index 4bbd17ba..8567b52b 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -20,6 +20,7 @@ import ( coderws "github.com/coder/websocket" "github.com/gin-gonic/gin" + "github.com/google/uuid" "github.com/tidwall/gjson" "go.uber.org/zap" ) @@ -33,6 +34,7 @@ type OpenAIGatewayHandler struct { errorPassthroughService *service.ErrorPassthroughService concurrencyHelper *ConcurrencyHelper maxAccountSwitches int + cfg *config.Config } // NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler @@ -61,6 +63,7 @@ func NewOpenAIGatewayHandler( errorPassthroughService: errorPassthroughService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), maxAccountSwitches: maxAccountSwitches, + cfg: cfg, } } @@ -70,6 +73,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { // 局部兜底:确保该 handler 内部任何 panic 都不会击穿到进程级。 streamStarted := false defer h.recoverResponsesPanic(c, &streamStarted) + compactStartedAt := time.Now() + defer h.logOpenAIRemoteCompactOutcome(c, compactStartedAt) setOpenAIClientTransportHTTP(c) requestStart := time.Now() @@ -114,6 +119,20 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { } setOpsRequestContext(c, "", false, body) + sessionHashBody := body + if service.IsOpenAIResponsesCompactPathForTest(c) { + if compactSeed := strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()); compactSeed != "" { + c.Set(service.OpenAICompactSessionSeedKeyForTest(), compactSeed) + } + normalizedCompactBody, normalizedCompact, compactErr := service.NormalizeOpenAICompactRequestBodyForTest(body) + if compactErr != nil { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to normalize compact request body") + return + } + if normalizedCompact { + body = normalizedCompactBody + } + } // 校验请求体 JSON 合法性 if !gjson.ValidBytes(body) { @@ -189,11 +208,12 @@ func 
(h *OpenAIGatewayHandler) Responses(c *gin.Context) { } // Generate session hash (header first; fallback to prompt_cache_key) - sessionHash := h.gatewayService.GenerateSessionHash(c, body) + sessionHash := h.gatewayService.GenerateSessionHash(c, sessionHashBody) maxAccountSwitches := h.maxAccountSwitches switchCount := 0 failedAccountIDs := make(map[int64]struct{}) + sameAccountRetryCount := make(map[int64]int) var lastFailoverErr *service.UpstreamFailoverError for { @@ -241,6 +261,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { zap.Float64("load_skew", scheduleDecision.LoadSkew), ) account := selection.Account + sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account) reqLog.Debug("openai.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name)) setOpsSelectedAccount(c, account.ID, account.Platform) @@ -270,6 +291,25 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + // 池模式:同账号重试 + if failoverErr.RetryableOnSameAccount { + retryLimit := account.GetPoolModeRetryCount() + if sameAccountRetryCount[account.ID] < retryLimit { + sameAccountRetryCount[account.ID]++ + reqLog.Warn("openai.pool_mode_same_account_retry", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("retry_limit", retryLimit), + zap.Int("retry_count", sameAccountRetryCount[account.ID]), + ) + select { + case <-c.Request.Context().Done(): + return + case <-time.After(sameAccountRetryDelay): + } + continue + } + } h.gatewayService.RecordOpenAIAccountSwitch() failedAccountIDs[account.ID] = struct{}{} lastFailoverErr = failoverErr @@ -301,6 +341,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { return } if result != nil { + if account.Type == service.AccountTypeOAuth { + 
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(c.Request.Context(), account.ID, result.ResponseHeaders) + } h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) } else { h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil) @@ -340,6 +383,432 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { } } +func isOpenAIRemoteCompactPath(c *gin.Context) bool { + if c == nil || c.Request == nil || c.Request.URL == nil { + return false + } + normalizedPath := strings.TrimRight(strings.TrimSpace(c.Request.URL.Path), "/") + return strings.HasSuffix(normalizedPath, "/responses/compact") +} + +func (h *OpenAIGatewayHandler) logOpenAIRemoteCompactOutcome(c *gin.Context, startedAt time.Time) { + if !isOpenAIRemoteCompactPath(c) { + return + } + + var ( + ctx = context.Background() + path string + status int + ) + if c != nil { + if c.Request != nil { + ctx = c.Request.Context() + if c.Request.URL != nil { + path = strings.TrimSpace(c.Request.URL.Path) + } + } + if c.Writer != nil { + status = c.Writer.Status() + } + } + + outcome := "failed" + if status >= 200 && status < 300 { + outcome = "succeeded" + } + latencyMs := time.Since(startedAt).Milliseconds() + if latencyMs < 0 { + latencyMs = 0 + } + + fields := []zap.Field{ + zap.String("component", "handler.openai_gateway.responses"), + zap.Bool("remote_compact", true), + zap.String("compact_outcome", outcome), + zap.Int("status_code", status), + zap.Int64("latency_ms", latencyMs), + zap.String("path", path), + zap.Bool("force_codex_cli", h != nil && h.cfg != nil && h.cfg.Gateway.ForceCodexCLI), + } + + if c != nil { + if userAgent := strings.TrimSpace(c.GetHeader("User-Agent")); userAgent != "" { + fields = append(fields, zap.String("request_user_agent", userAgent)) + } + if v, ok := c.Get(opsModelKey); ok { + if model, ok := v.(string); ok && strings.TrimSpace(model) != "" { + fields = append(fields, zap.String("request_model", strings.TrimSpace(model))) + } 
+ } + if v, ok := c.Get(opsAccountIDKey); ok { + if accountID, ok := v.(int64); ok && accountID > 0 { + fields = append(fields, zap.Int64("account_id", accountID)) + } + } + if c.Writer != nil { + if upstreamRequestID := strings.TrimSpace(c.Writer.Header().Get("x-request-id")); upstreamRequestID != "" { + fields = append(fields, zap.String("upstream_request_id", upstreamRequestID)) + } else if upstreamRequestID := strings.TrimSpace(c.Writer.Header().Get("X-Request-Id")); upstreamRequestID != "" { + fields = append(fields, zap.String("upstream_request_id", upstreamRequestID)) + } + } + } + + log := logger.FromContext(ctx).With(fields...) + if outcome == "succeeded" { + log.Info("codex.remote_compact.succeeded") + return + } + log.Warn("codex.remote_compact.failed") +} + +// Messages handles Anthropic Messages API requests routed to OpenAI platform. +// POST /v1/messages (when group platform is OpenAI) +func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { + streamStarted := false + defer h.recoverAnthropicMessagesPanic(c, &streamStarted) + + requestStart := time.Now() + + apiKey, ok := middleware2.GetAPIKeyFromContext(c) + if !ok { + h.anthropicErrorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") + return + } + + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + h.anthropicErrorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") + return + } + reqLog := requestLogger( + c, + "handler.openai_gateway.messages", + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + ) + + // 检查分组是否允许 /v1/messages 调度 + if apiKey.Group != nil && !apiKey.Group.AllowMessagesDispatch { + h.anthropicErrorResponse(c, http.StatusForbidden, "permission_error", + "This group does not allow /v1/messages dispatch") + return + } + + if !h.ensureResponsesDependencies(c, reqLog) { + return + } + + body, err := 
pkghttputil.ReadRequestBodyWithPrealloc(c.Request) + if err != nil { + if maxErr, ok := extractMaxBytesError(err); ok { + h.anthropicErrorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) + return + } + h.anthropicErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") + return + } + if len(body) == 0 { + h.anthropicErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") + return + } + + if !gjson.ValidBytes(body) { + h.anthropicErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") + return + } + + modelResult := gjson.GetBytes(body, "model") + if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" { + h.anthropicErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") + return + } + reqModel := modelResult.String() + reqStream := gjson.GetBytes(body, "stream").Bool() + + reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) + + setOpsRequestContext(c, reqModel, reqStream, body) + + // 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。 + if h.errorPassthroughService != nil { + service.BindErrorPassthroughService(c, h.errorPassthroughService) + } + + subscription, _ := middleware2.GetSubscriptionFromContext(c) + + service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds()) + routingStart := time.Now() + + userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog) + if !acquired { + return + } + if userReleaseFunc != nil { + defer userReleaseFunc() + } + + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + reqLog.Info("openai_messages.billing_eligibility_check_failed", zap.Error(err)) + status, code, message := 
billingErrorDetails(err) + h.anthropicStreamingAwareError(c, status, code, message, streamStarted) + return + } + + sessionHash := h.gatewayService.GenerateSessionHash(c, body) + promptCacheKey := h.gatewayService.ExtractSessionID(c, body) + + // Anthropic 格式的请求在 metadata.user_id 中携带 session 标识, + // 而非 OpenAI 的 session_id/conversation_id headers。 + // 从中派生 sessionHash(sticky session)和 promptCacheKey(upstream cache)。 + if sessionHash == "" || promptCacheKey == "" { + if userID := strings.TrimSpace(gjson.GetBytes(body, "metadata.user_id").String()); userID != "" { + seed := reqModel + "-" + userID + if promptCacheKey == "" { + promptCacheKey = service.GenerateSessionUUID(seed) + } + if sessionHash == "" { + sessionHash = service.DeriveSessionHashFromSeed(seed) + } + } + } + + maxAccountSwitches := h.maxAccountSwitches + switchCount := 0 + failedAccountIDs := make(map[int64]struct{}) + sameAccountRetryCount := make(map[int64]int) + var lastFailoverErr *service.UpstreamFailoverError + + for { + // 清除上一次迭代的降级模型标记,避免残留影响本次迭代 + c.Set("openai_messages_fallback_model", "") + reqLog.Debug("openai_messages.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs))) + selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( + c.Request.Context(), + apiKey.GroupID, + "", // no previous_response_id + sessionHash, + reqModel, + failedAccountIDs, + service.OpenAIUpstreamTransportAny, + ) + if err != nil { + reqLog.Warn("openai_messages.account_select_failed", + zap.Error(err), + zap.Int("excluded_account_count", len(failedAccountIDs)), + ) + // 首次调度失败 + 有默认映射模型 → 用默认模型重试 + if len(failedAccountIDs) == 0 { + defaultModel := "" + if apiKey.Group != nil { + defaultModel = apiKey.Group.DefaultMappedModel + } + if defaultModel != "" && defaultModel != reqModel { + reqLog.Info("openai_messages.fallback_to_default_model", + zap.String("default_mapped_model", defaultModel), + ) + selection, scheduleDecision, err = 
h.gatewayService.SelectAccountWithScheduler( + c.Request.Context(), + apiKey.GroupID, + "", + sessionHash, + defaultModel, + failedAccountIDs, + service.OpenAIUpstreamTransportAny, + ) + if err == nil && selection != nil { + c.Set("openai_messages_fallback_model", defaultModel) + } + } + if err != nil { + h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted) + return + } + } else { + if lastFailoverErr != nil { + h.handleAnthropicFailoverExhausted(c, lastFailoverErr, streamStarted) + } else { + h.anthropicStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted) + } + return + } + } + if selection == nil || selection.Account == nil { + h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) + return + } + account := selection.Account + sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account) + reqLog.Debug("openai_messages.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name)) + _ = scheduleDecision + setOpsSelectedAccount(c, account.ID, account.Platform) + + accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog) + if !acquired { + return + } + + service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) + forwardStart := time.Now() + + defaultMappedModel := "" + if apiKey.Group != nil { + defaultMappedModel = apiKey.Group.DefaultMappedModel + } + // 如果使用了降级模型调度,强制使用降级模型 + if fallbackModel := c.GetString("openai_messages_fallback_model"); fallbackModel != "" { + defaultMappedModel = fallbackModel + } + result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel) + + forwardDurationMs := time.Since(forwardStart).Milliseconds() + if accountReleaseFunc 
!= nil { + accountReleaseFunc() + } + upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey) + responseLatencyMs := forwardDurationMs + if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs { + responseLatencyMs = forwardDurationMs - upstreamLatencyMs + } + service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs) + if err == nil && result != nil && result.FirstTokenMs != nil { + service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs)) + } + if err != nil { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + // 池模式:同账号重试 + if failoverErr.RetryableOnSameAccount { + retryLimit := account.GetPoolModeRetryCount() + if sameAccountRetryCount[account.ID] < retryLimit { + sameAccountRetryCount[account.ID]++ + reqLog.Warn("openai_messages.pool_mode_same_account_retry", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("retry_limit", retryLimit), + zap.Int("retry_count", sameAccountRetryCount[account.ID]), + ) + select { + case <-c.Request.Context().Done(): + return + case <-time.After(sameAccountRetryDelay): + } + continue + } + } + h.gatewayService.RecordOpenAIAccountSwitch() + failedAccountIDs[account.ID] = struct{}{} + lastFailoverErr = failoverErr + if switchCount >= maxAccountSwitches { + h.handleAnthropicFailoverExhausted(c, failoverErr, streamStarted) + return + } + switchCount++ + reqLog.Warn("openai_messages.upstream_failover_switching", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("switch_count", switchCount), + zap.Int("max_switches", maxAccountSwitches), + ) + continue + } + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + wroteFallback := h.ensureAnthropicErrorResponse(c, streamStarted) + 
reqLog.Warn("openai_messages.forward_failed", + zap.Int64("account_id", account.ID), + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Error(err), + ) + return + } + if result != nil { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) + } else { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil) + } + + userAgent := c.GetHeader("User-Agent") + clientIP := ip.GetClientIP(c) + + h.submitUsageRecordTask(func(ctx context.Context) { + if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + UserAgent: userAgent, + IPAddress: clientIP, + APIKeyService: h.apiKeyService, + }); err != nil { + logger.L().With( + zap.String("component", "handler.openai_gateway.messages"), + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + zap.String("model", reqModel), + zap.Int64("account_id", account.ID), + ).Error("openai_messages.record_usage_failed", zap.Error(err)) + } + }) + reqLog.Debug("openai_messages.request_completed", + zap.Int64("account_id", account.ID), + zap.Int("switch_count", switchCount), + ) + return + } +} + +// anthropicErrorResponse writes an error in Anthropic Messages API format. +func (h *OpenAIGatewayHandler) anthropicErrorResponse(c *gin.Context, status int, errType, message string) { + c.JSON(status, gin.H{ + "type": "error", + "error": gin.H{ + "type": errType, + "message": message, + }, + }) +} + +// anthropicStreamingAwareError handles errors that may occur during streaming, +// using Anthropic SSE error format. 
+func (h *OpenAIGatewayHandler) anthropicStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) { + if streamStarted { + flusher, ok := c.Writer.(http.Flusher) + if ok { + errPayload, _ := json.Marshal(gin.H{ + "type": "error", + "error": gin.H{ + "type": errType, + "message": message, + }, + }) + fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", errPayload) //nolint:errcheck + flusher.Flush() + } + return + } + h.anthropicErrorResponse(c, status, errType, message) +} + +// handleAnthropicFailoverExhausted maps upstream failover errors to Anthropic format. +func (h *OpenAIGatewayHandler) handleAnthropicFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, streamStarted bool) { + status, errType, errMsg := h.mapUpstreamError(failoverErr.StatusCode) + h.anthropicStreamingAwareError(c, status, errType, errMsg, streamStarted) +} + +// ensureAnthropicErrorResponse writes a fallback Anthropic error if no response was written. +func (h *OpenAIGatewayHandler) ensureAnthropicErrorResponse(c *gin.Context, streamStarted bool) bool { + if c == nil || c.Writer == nil || c.Writer.Written() { + return false + } + h.anthropicStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted) + return true +} + func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context, body []byte, reqLog *zap.Logger) bool { if !gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() { return true @@ -756,6 +1225,9 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { if turnErr != nil || result == nil { return } + if account.Type == service.AccountTypeOAuth { + h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(ctx, account.ID, result.ResponseHeaders) + } h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) h.submitUsageRecordTask(func(taskCtx context.Context) { if err := h.gatewayService.RecordUsage(taskCtx, 
&service.OpenAIRecordUsageInput{ @@ -817,6 +1289,26 @@ func (h *OpenAIGatewayHandler) recoverResponsesPanic(c *gin.Context, streamStart ) } +// recoverAnthropicMessagesPanic recovers from panics in the Anthropic Messages +// handler and returns an Anthropic-formatted error response. +func (h *OpenAIGatewayHandler) recoverAnthropicMessagesPanic(c *gin.Context, streamStarted *bool) { + recovered := recover() + if recovered == nil { + return + } + + started := streamStarted != nil && *streamStarted + requestLogger(c, "handler.openai_gateway.messages").Error( + "openai.messages_panic_recovered", + zap.Bool("stream_started", started), + zap.Any("panic", recovered), + zap.ByteString("stack", debug.Stack()), + ) + if !started { + h.anthropicErrorResponse(c, http.StatusInternalServerError, "api_error", "Internal server error") + } +} + func (h *OpenAIGatewayHandler) ensureResponsesDependencies(c *gin.Context, reqLog *zap.Logger) bool { missing := h.missingResponsesDependencies() if len(missing) == 0 { @@ -1022,6 +1514,14 @@ func setOpenAIClientTransportWS(c *gin.Context) { service.SetOpenAIClientTransport(c, service.OpenAIClientTransportWS) } +func ensureOpenAIPoolModeSessionHash(sessionHash string, account *service.Account) string { + if sessionHash != "" || account == nil || !account.IsPoolMode() { + return sessionHash + } + // 为当前请求生成一次性粘性会话键,确保同账号重试不会重新负载均衡到其他账号。 + return "openai-pool-retry-" + uuid.NewString() +} + func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64) string { gid := int64(0) if groupID != nil { diff --git a/backend/internal/handler/ops_error_logger.go b/backend/internal/handler/ops_error_logger.go index 2f53d655..cb2fad5d 100644 --- a/backend/internal/handler/ops_error_logger.go +++ b/backend/internal/handler/ops_error_logger.go @@ -31,6 +31,7 @@ const ( const ( opsErrorLogTimeout = 5 * time.Second opsErrorLogDrainTimeout = 10 * time.Second + opsErrorLogBatchWindow = 200 * time.Millisecond opsErrorLogMinWorkerCount = 4 
opsErrorLogMaxWorkerCount = 32 @@ -38,6 +39,7 @@ const ( opsErrorLogQueueSizePerWorker = 128 opsErrorLogMinQueueSize = 256 opsErrorLogMaxQueueSize = 8192 + opsErrorLogBatchSize = 32 ) type opsErrorLogJob struct { @@ -82,27 +84,82 @@ func startOpsErrorLogWorkers() { for i := 0; i < workerCount; i++ { go func() { defer opsErrorLogWorkersWg.Done() - for job := range opsErrorLogQueue { - opsErrorLogQueueLen.Add(-1) - if job.ops == nil || job.entry == nil { - continue + for { + job, ok := <-opsErrorLogQueue + if !ok { + return } - func() { - defer func() { - if r := recover(); r != nil { - log.Printf("[OpsErrorLogger] worker panic: %v\n%s", r, debug.Stack()) + opsErrorLogQueueLen.Add(-1) + batch := make([]opsErrorLogJob, 0, opsErrorLogBatchSize) + batch = append(batch, job) + + timer := time.NewTimer(opsErrorLogBatchWindow) + batchLoop: + for len(batch) < opsErrorLogBatchSize { + select { + case nextJob, ok := <-opsErrorLogQueue: + if !ok { + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + flushOpsErrorLogBatch(batch) + return } - }() - ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout) - _ = job.ops.RecordError(ctx, job.entry, nil) - cancel() - opsErrorLogProcessed.Add(1) - }() + opsErrorLogQueueLen.Add(-1) + batch = append(batch, nextJob) + case <-timer.C: + break batchLoop + } + } + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + flushOpsErrorLogBatch(batch) } }() } } +func flushOpsErrorLogBatch(batch []opsErrorLogJob) { + if len(batch) == 0 { + return + } + defer func() { + if r := recover(); r != nil { + log.Printf("[OpsErrorLogger] worker panic: %v\n%s", r, debug.Stack()) + } + }() + + grouped := make(map[*service.OpsService][]*service.OpsInsertErrorLogInput, len(batch)) + var processed int64 + for _, job := range batch { + if job.ops == nil || job.entry == nil { + continue + } + grouped[job.ops] = append(grouped[job.ops], job.entry) + processed++ + } + if processed == 0 { + return + } + + for 
opsSvc, entries := range grouped { + if opsSvc == nil || len(entries) == 0 { + continue + } + ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout) + _ = opsSvc.RecordErrorBatch(ctx, entries) + cancel() + } + opsErrorLogProcessed.Add(processed) +} + func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput) { if ops == nil || entry == nil { return diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index a48eaf31..1188d55e 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -32,27 +32,28 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { } response.Success(c, dto.PublicSettings{ - RegistrationEnabled: settings.RegistrationEnabled, - EmailVerifyEnabled: settings.EmailVerifyEnabled, - PromoCodeEnabled: settings.PromoCodeEnabled, - PasswordResetEnabled: settings.PasswordResetEnabled, - InvitationCodeEnabled: settings.InvitationCodeEnabled, - TotpEnabled: settings.TotpEnabled, - TurnstileEnabled: settings.TurnstileEnabled, - TurnstileSiteKey: settings.TurnstileSiteKey, - SiteName: settings.SiteName, - SiteLogo: settings.SiteLogo, - SiteSubtitle: settings.SiteSubtitle, - APIBaseURL: settings.APIBaseURL, - ContactInfo: settings.ContactInfo, - DocURL: settings.DocURL, - HomeContent: settings.HomeContent, - HideCcsImportButton: settings.HideCcsImportButton, - PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, - PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, - CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems), - LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, - SoraClientEnabled: settings.SoraClientEnabled, - Version: h.version, + RegistrationEnabled: settings.RegistrationEnabled, + EmailVerifyEnabled: settings.EmailVerifyEnabled, + RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist, + PromoCodeEnabled: 
settings.PromoCodeEnabled, + PasswordResetEnabled: settings.PasswordResetEnabled, + InvitationCodeEnabled: settings.InvitationCodeEnabled, + TotpEnabled: settings.TotpEnabled, + TurnstileEnabled: settings.TurnstileEnabled, + TurnstileSiteKey: settings.TurnstileSiteKey, + SiteName: settings.SiteName, + SiteLogo: settings.SiteLogo, + SiteSubtitle: settings.SiteSubtitle, + APIBaseURL: settings.APIBaseURL, + ContactInfo: settings.ContactInfo, + DocURL: settings.DocURL, + HomeContent: settings.HomeContent, + HideCcsImportButton: settings.HideCcsImportButton, + PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, + PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, + CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems), + LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, + SoraClientEnabled: settings.SoraClientEnabled, + Version: h.version, }) } diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go index d2d9790d..30a761bd 100644 --- a/backend/internal/handler/sora_client_handler_test.go +++ b/backend/internal/handler/sora_client_handler_test.go @@ -2132,6 +2132,14 @@ func (r *stubAccountRepoForHandler) BulkUpdate(context.Context, []int64, service return 0, nil } +func (r *stubAccountRepoForHandler) IncrementQuotaUsed(context.Context, int64, float64) error { + return nil +} + +func (r *stubAccountRepoForHandler) ResetQuotaUsed(context.Context, int64) error { + return nil +} + // ==================== Stub: SoraClient (用于 SoraGatewayService) ==================== var _ service.SoraClient = (*stubSoraClientForHandler)(nil) @@ -2199,7 +2207,7 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService { return service.NewGatewayService( accountRepo, nil, nil, nil, nil, nil, nil, nil, - nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, + nil, 
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, ) } diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go index b76ab67d..688c5d12 100644 --- a/backend/internal/handler/sora_gateway_handler_test.go +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -216,6 +216,14 @@ func (r *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates s return 0, nil } +func (r *stubAccountRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error { + return nil +} + +func (r *stubAccountRepo) ResetQuotaUsed(ctx context.Context, id int64) error { + return nil +} + func (r *stubAccountRepo) listSchedulable() []service.Account { var result []service.Account for _, acc := range r.accounts { @@ -437,6 +445,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) { testutil.StubSessionLimitCache{}, nil, // rpmCache nil, // digestStore + nil, // settingService ) soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}} diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index 76f5a979..d1e12e03 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -30,6 +30,7 @@ func ProvideAdminHandlers( userAttributeHandler *admin.UserAttributeHandler, errorPassthroughHandler *admin.ErrorPassthroughHandler, apiKeyHandler *admin.AdminAPIKeyHandler, + scheduledTestHandler *admin.ScheduledTestHandler, ) *AdminHandlers { return &AdminHandlers{ Dashboard: dashboardHandler, @@ -53,6 +54,7 @@ func ProvideAdminHandlers( UserAttribute: userAttributeHandler, ErrorPassthrough: errorPassthroughHandler, APIKey: apiKeyHandler, + ScheduledTest: scheduledTestHandler, } } @@ -141,6 +143,7 @@ var ProviderSet = wire.NewSet( admin.NewUserAttributeHandler, admin.NewErrorPassthroughHandler, admin.NewAdminAPIKeyHandler, + admin.NewScheduledTestHandler, // AdminHandlers and Handlers constructors ProvideAdminHandlers, 
diff --git a/backend/internal/pkg/antigravity/claude_types.go b/backend/internal/pkg/antigravity/claude_types.go index 7cc68060..8ea87f18 100644 --- a/backend/internal/pkg/antigravity/claude_types.go +++ b/backend/internal/pkg/antigravity/claude_types.go @@ -159,6 +159,8 @@ var claudeModels = []modelDef{ // Antigravity 支持的 Gemini 模型 var geminiModels = []modelDef{ {ID: "gemini-2.5-flash", DisplayName: "Gemini 2.5 Flash", CreatedAt: "2025-01-01T00:00:00Z"}, + {ID: "gemini-2.5-flash-image", DisplayName: "Gemini 2.5 Flash Image", CreatedAt: "2025-01-01T00:00:00Z"}, + {ID: "gemini-2.5-flash-image-preview", DisplayName: "Gemini 2.5 Flash Image Preview", CreatedAt: "2025-01-01T00:00:00Z"}, {ID: "gemini-2.5-flash-lite", DisplayName: "Gemini 2.5 Flash Lite", CreatedAt: "2025-01-01T00:00:00Z"}, {ID: "gemini-2.5-flash-thinking", DisplayName: "Gemini 2.5 Flash Thinking", CreatedAt: "2025-01-01T00:00:00Z"}, {ID: "gemini-3-flash", DisplayName: "Gemini 3 Flash", CreatedAt: "2025-06-01T00:00:00Z"}, diff --git a/backend/internal/pkg/antigravity/claude_types_test.go b/backend/internal/pkg/antigravity/claude_types_test.go index f7cb0a24..9fc09b1b 100644 --- a/backend/internal/pkg/antigravity/claude_types_test.go +++ b/backend/internal/pkg/antigravity/claude_types_test.go @@ -13,6 +13,8 @@ func TestDefaultModels_ContainsNewAndLegacyImageModels(t *testing.T) { requiredIDs := []string{ "claude-opus-4-6-thinking", + "gemini-2.5-flash-image", + "gemini-2.5-flash-image-preview", "gemini-3.1-flash-image", "gemini-3.1-flash-image-preview", "gemini-3-pro-image", // legacy compatibility diff --git a/backend/internal/pkg/antigravity/oauth.go b/backend/internal/pkg/antigravity/oauth.go index afffe9b1..5bda31ac 100644 --- a/backend/internal/pkg/antigravity/oauth.go +++ b/backend/internal/pkg/antigravity/oauth.go @@ -49,8 +49,8 @@ const ( antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com" ) -// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.19.6 
-var defaultUserAgentVersion = "1.19.6" +// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.4 +var defaultUserAgentVersion = "1.20.4" // defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置 var defaultClientSecret = "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" diff --git a/backend/internal/pkg/antigravity/oauth_test.go b/backend/internal/pkg/antigravity/oauth_test.go index 743e2a33..f4630b09 100644 --- a/backend/internal/pkg/antigravity/oauth_test.go +++ b/backend/internal/pkg/antigravity/oauth_test.go @@ -690,7 +690,7 @@ func TestConstants_值正确(t *testing.T) { if RedirectURI != "http://localhost:8085/callback" { t.Errorf("RedirectURI 不匹配: got %s", RedirectURI) } - if GetUserAgent() != "antigravity/1.19.6 windows/amd64" { + if GetUserAgent() != "antigravity/1.20.4 windows/amd64" { t.Errorf("UserAgent 不匹配: got %s", GetUserAgent()) } if SessionTTL != 30*time.Minute { diff --git a/backend/internal/pkg/antigravity/stream_transformer.go b/backend/internal/pkg/antigravity/stream_transformer.go index 677435ad..ee600c8b 100644 --- a/backend/internal/pkg/antigravity/stream_transformer.go +++ b/backend/internal/pkg/antigravity/stream_transformer.go @@ -18,6 +18,9 @@ const ( BlockTypeFunction ) +// UsageMapHook is a callback that can modify usage data before it's emitted in SSE events. +type UsageMapHook func(usageMap map[string]any) + // StreamingProcessor 流式响应处理器 type StreamingProcessor struct { blockType BlockType @@ -30,6 +33,7 @@ type StreamingProcessor struct { originalModel string webSearchQueries []string groundingChunks []GeminiGroundingChunk + usageMapHook UsageMapHook // 累计 usage inputTokens int @@ -45,6 +49,25 @@ func NewStreamingProcessor(originalModel string) *StreamingProcessor { } } +// SetUsageMapHook sets an optional hook that modifies usage maps before they are emitted. 
+func (p *StreamingProcessor) SetUsageMapHook(fn UsageMapHook) { + p.usageMapHook = fn +} + +func usageToMap(u ClaudeUsage) map[string]any { + m := map[string]any{ + "input_tokens": u.InputTokens, + "output_tokens": u.OutputTokens, + } + if u.CacheCreationInputTokens > 0 { + m["cache_creation_input_tokens"] = u.CacheCreationInputTokens + } + if u.CacheReadInputTokens > 0 { + m["cache_read_input_tokens"] = u.CacheReadInputTokens + } + return m +} + // ProcessLine 处理 SSE 行,返回 Claude SSE 事件 func (p *StreamingProcessor) ProcessLine(line string) []byte { line = strings.TrimSpace(line) @@ -119,23 +142,33 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte { return result.Bytes() } -// Finish 结束处理,返回最终事件和用量 +// Finish 结束处理,返回最终事件和用量。 +// 若整个流未收到任何可解析的上游数据(messageStartSent == false), +// 则不补发任何结束事件,防止客户端收到没有 message_start 的残缺流。 func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) { - var result bytes.Buffer - - if !p.messageStopSent { - _, _ = result.Write(p.emitFinish("")) - } - usage := &ClaudeUsage{ InputTokens: p.inputTokens, OutputTokens: p.outputTokens, CacheReadInputTokens: p.cacheReadTokens, } + if !p.messageStartSent { + return nil, usage + } + + var result bytes.Buffer + if !p.messageStopSent { + _, _ = result.Write(p.emitFinish("")) + } + return result.Bytes(), usage } +// MessageStartSent 报告流中是否已发出过 message_start 事件(即是否收到过有效的上游数据) +func (p *StreamingProcessor) MessageStartSent() bool { + return p.messageStartSent +} + // emitMessageStart 发送 message_start 事件 func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte { if p.messageStartSent { @@ -158,6 +191,13 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte responseID = "msg_" + generateRandomID() } + var usageValue any = usage + if p.usageMapHook != nil { + usageMap := usageToMap(usage) + p.usageMapHook(usageMap) + usageValue = usageMap + } + message := map[string]any{ "id": responseID, "type": "message", @@ -166,7 +206,7 @@ func (p 
*StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte "model": p.originalModel, "stop_reason": nil, "stop_sequence": nil, - "usage": usage, + "usage": usageValue, } event := map[string]any{ @@ -477,13 +517,20 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte { CacheReadInputTokens: p.cacheReadTokens, } + var usageValue any = usage + if p.usageMapHook != nil { + usageMap := usageToMap(usage) + p.usageMapHook(usageMap) + usageValue = usageMap + } + deltaEvent := map[string]any{ "type": "message_delta", "delta": map[string]any{ "stop_reason": stopReason, "stop_sequence": nil, }, - "usage": usage, + "usage": usageValue, } _, _ = result.Write(p.formatSSE("message_delta", deltaEvent)) diff --git a/backend/internal/pkg/apicompat/anthropic_responses_test.go b/backend/internal/pkg/apicompat/anthropic_responses_test.go new file mode 100644 index 00000000..1c1d39bb --- /dev/null +++ b/backend/internal/pkg/apicompat/anthropic_responses_test.go @@ -0,0 +1,1009 @@ +package apicompat + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// AnthropicToResponses tests +// --------------------------------------------------------------------------- + +func TestAnthropicToResponses_BasicText(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Stream: true, + Messages: []AnthropicMessage{ + {Role: "user", Content: json.RawMessage(`"Hello"`)}, + }, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + assert.Equal(t, "gpt-5.2", resp.Model) + assert.True(t, resp.Stream) + assert.Equal(t, 1024, *resp.MaxOutputTokens) + assert.False(t, *resp.Store) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 1) + assert.Equal(t, "user", items[0].Role) +} + +func 
TestAnthropicToResponses_SystemPrompt(t *testing.T) { + t.Run("string", func(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 100, + System: json.RawMessage(`"You are helpful."`), + Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}}, + } + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 2) + assert.Equal(t, "system", items[0].Role) + }) + + t.Run("array", func(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 100, + System: json.RawMessage(`[{"type":"text","text":"Part 1"},{"type":"text","text":"Part 2"}]`), + Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}}, + } + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 2) + assert.Equal(t, "system", items[0].Role) + // System text should be joined with double newline. 
+ var text string + require.NoError(t, json.Unmarshal(items[0].Content, &text)) + assert.Equal(t, "Part 1\n\nPart 2", text) + }) +} + +func TestAnthropicToResponses_ToolUse(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{ + {Role: "user", Content: json.RawMessage(`"What is the weather?"`)}, + {Role: "assistant", Content: json.RawMessage(`[{"type":"text","text":"Let me check."},{"type":"tool_use","id":"call_1","name":"get_weather","input":{"city":"NYC"}}]`)}, + {Role: "user", Content: json.RawMessage(`[{"type":"tool_result","tool_use_id":"call_1","content":"Sunny, 72°F"}]`)}, + }, + Tools: []AnthropicTool{ + {Name: "get_weather", Description: "Get weather", InputSchema: json.RawMessage(`{"type":"object","properties":{"city":{"type":"string"}}}`)}, + }, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + + // Check tools + require.Len(t, resp.Tools, 1) + assert.Equal(t, "function", resp.Tools[0].Type) + assert.Equal(t, "get_weather", resp.Tools[0].Name) + + // Check input items + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + // user + assistant + function_call + function_call_output = 4 + require.Len(t, items, 4) + + assert.Equal(t, "user", items[0].Role) + assert.Equal(t, "assistant", items[1].Role) + assert.Equal(t, "function_call", items[2].Type) + assert.Equal(t, "fc_call_1", items[2].CallID) + assert.Equal(t, "function_call_output", items[3].Type) + assert.Equal(t, "fc_call_1", items[3].CallID) + assert.Equal(t, "Sunny, 72°F", items[3].Output) +} + +func TestAnthropicToResponses_ThinkingIgnored(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{ + {Role: "user", Content: json.RawMessage(`"Hello"`)}, + {Role: "assistant", Content: json.RawMessage(`[{"type":"thinking","thinking":"deep thought"},{"type":"text","text":"Hi!"}]`)}, + {Role: "user", Content: 
json.RawMessage(`"More"`)}, + }, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + // user + assistant(text only, thinking ignored) + user = 3 + require.Len(t, items, 3) + assert.Equal(t, "assistant", items[1].Role) + // Assistant content should only have text, not thinking. + var parts []ResponsesContentPart + require.NoError(t, json.Unmarshal(items[1].Content, &parts)) + require.Len(t, parts, 1) + assert.Equal(t, "output_text", parts[0].Type) + assert.Equal(t, "Hi!", parts[0].Text) +} + +func TestAnthropicToResponses_MaxTokensFloor(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 10, // below minMaxOutputTokens (128) + Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}}, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + assert.Equal(t, 128, *resp.MaxOutputTokens) +} + +// --------------------------------------------------------------------------- +// ResponsesToAnthropic (non-streaming) tests +// --------------------------------------------------------------------------- + +func TestResponsesToAnthropic_TextOnly(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_123", + Model: "gpt-5.2", + Status: "completed", + Output: []ResponsesOutput{ + { + Type: "message", + Content: []ResponsesContentPart{ + {Type: "output_text", Text: "Hello there!"}, + }, + }, + }, + Usage: &ResponsesUsage{InputTokens: 10, OutputTokens: 5, TotalTokens: 15}, + } + + anth := ResponsesToAnthropic(resp, "claude-opus-4-6") + assert.Equal(t, "resp_123", anth.ID) + assert.Equal(t, "claude-opus-4-6", anth.Model) + assert.Equal(t, "end_turn", anth.StopReason) + require.Len(t, anth.Content, 1) + assert.Equal(t, "text", anth.Content[0].Type) + assert.Equal(t, "Hello there!", anth.Content[0].Text) + assert.Equal(t, 10, anth.Usage.InputTokens) + assert.Equal(t, 5, anth.Usage.OutputTokens) +} + +func 
TestResponsesToAnthropic_ToolUse(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_456", + Model: "gpt-5.2", + Status: "completed", + Output: []ResponsesOutput{ + { + Type: "message", + Content: []ResponsesContentPart{ + {Type: "output_text", Text: "Let me check."}, + }, + }, + { + Type: "function_call", + CallID: "call_1", + Name: "get_weather", + Arguments: `{"city":"NYC"}`, + }, + }, + } + + anth := ResponsesToAnthropic(resp, "claude-opus-4-6") + assert.Equal(t, "tool_use", anth.StopReason) + require.Len(t, anth.Content, 2) + assert.Equal(t, "text", anth.Content[0].Type) + assert.Equal(t, "tool_use", anth.Content[1].Type) + assert.Equal(t, "call_1", anth.Content[1].ID) + assert.Equal(t, "get_weather", anth.Content[1].Name) +} + +func TestResponsesToAnthropic_Reasoning(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_789", + Model: "gpt-5.2", + Status: "completed", + Output: []ResponsesOutput{ + { + Type: "reasoning", + Summary: []ResponsesSummary{ + {Type: "summary_text", Text: "Thinking about the answer..."}, + }, + }, + { + Type: "message", + Content: []ResponsesContentPart{ + {Type: "output_text", Text: "42"}, + }, + }, + }, + } + + anth := ResponsesToAnthropic(resp, "claude-opus-4-6") + require.Len(t, anth.Content, 2) + assert.Equal(t, "thinking", anth.Content[0].Type) + assert.Equal(t, "Thinking about the answer...", anth.Content[0].Thinking) + assert.Equal(t, "text", anth.Content[1].Type) + assert.Equal(t, "42", anth.Content[1].Text) +} + +func TestResponsesToAnthropic_Incomplete(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_inc", + Model: "gpt-5.2", + Status: "incomplete", + IncompleteDetails: &ResponsesIncompleteDetails{ + Reason: "max_output_tokens", + }, + Output: []ResponsesOutput{ + { + Type: "message", + Content: []ResponsesContentPart{{Type: "output_text", Text: "Partial..."}}, + }, + }, + } + + anth := ResponsesToAnthropic(resp, "claude-opus-4-6") + assert.Equal(t, "max_tokens", anth.StopReason) +} + +func 
TestResponsesToAnthropic_EmptyOutput(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_empty", + Model: "gpt-5.2", + Status: "completed", + Output: []ResponsesOutput{}, + } + + anth := ResponsesToAnthropic(resp, "claude-opus-4-6") + require.Len(t, anth.Content, 1) + assert.Equal(t, "text", anth.Content[0].Type) + assert.Equal(t, "", anth.Content[0].Text) +} + +// --------------------------------------------------------------------------- +// Streaming: ResponsesEventToAnthropicEvents tests +// --------------------------------------------------------------------------- + +func TestStreamingTextOnly(t *testing.T) { + state := NewResponsesEventToAnthropicState() + + // 1. response.created + events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.created", + Response: &ResponsesResponse{ + ID: "resp_1", + Model: "gpt-5.2", + }, + }, state) + require.Len(t, events, 1) + assert.Equal(t, "message_start", events[0].Type) + + // 2. output_item.added (message) + events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.output_item.added", + OutputIndex: 0, + Item: &ResponsesOutput{Type: "message"}, + }, state) + assert.Len(t, events, 0) // message item doesn't emit events + + // 3. text delta + events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.output_text.delta", + Delta: "Hello", + }, state) + require.Len(t, events, 2) // content_block_start + content_block_delta + assert.Equal(t, "content_block_start", events[0].Type) + assert.Equal(t, "text", events[0].ContentBlock.Type) + assert.Equal(t, "content_block_delta", events[1].Type) + assert.Equal(t, "text_delta", events[1].Delta.Type) + assert.Equal(t, "Hello", events[1].Delta.Text) + + // 4. 
more text + events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.output_text.delta", + Delta: " world", + }, state) + require.Len(t, events, 1) // only delta, no new block start + assert.Equal(t, "content_block_delta", events[0].Type) + + // 5. text done + events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.output_text.done", + }, state) + require.Len(t, events, 1) + assert.Equal(t, "content_block_stop", events[0].Type) + + // 6. completed + events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.completed", + Response: &ResponsesResponse{ + Status: "completed", + Usage: &ResponsesUsage{InputTokens: 10, OutputTokens: 5}, + }, + }, state) + require.Len(t, events, 2) // message_delta + message_stop + assert.Equal(t, "message_delta", events[0].Type) + assert.Equal(t, "end_turn", events[0].Delta.StopReason) + assert.Equal(t, 10, events[0].Usage.InputTokens) + assert.Equal(t, 5, events[0].Usage.OutputTokens) + assert.Equal(t, "message_stop", events[1].Type) +} + +func TestStreamingToolCall(t *testing.T) { + state := NewResponsesEventToAnthropicState() + + // 1. response.created + ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.created", + Response: &ResponsesResponse{ID: "resp_2", Model: "gpt-5.2"}, + }, state) + + // 2. function_call added + events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.output_item.added", + OutputIndex: 0, + Item: &ResponsesOutput{Type: "function_call", CallID: "call_1", Name: "get_weather"}, + }, state) + require.Len(t, events, 1) + assert.Equal(t, "content_block_start", events[0].Type) + assert.Equal(t, "tool_use", events[0].ContentBlock.Type) + assert.Equal(t, "call_1", events[0].ContentBlock.ID) + + // 3. 
arguments delta + events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.function_call_arguments.delta", + OutputIndex: 0, + Delta: `{"city":`, + }, state) + require.Len(t, events, 1) + assert.Equal(t, "content_block_delta", events[0].Type) + assert.Equal(t, "input_json_delta", events[0].Delta.Type) + assert.Equal(t, `{"city":`, events[0].Delta.PartialJSON) + + // 4. arguments done + events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.function_call_arguments.done", + }, state) + require.Len(t, events, 1) + assert.Equal(t, "content_block_stop", events[0].Type) + + // 5. completed with tool_calls + events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.completed", + Response: &ResponsesResponse{ + Status: "completed", + Usage: &ResponsesUsage{InputTokens: 20, OutputTokens: 10}, + }, + }, state) + require.Len(t, events, 2) + assert.Equal(t, "tool_use", events[0].Delta.StopReason) +} + +func TestStreamingReasoning(t *testing.T) { + state := NewResponsesEventToAnthropicState() + + ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.created", + Response: &ResponsesResponse{ID: "resp_3", Model: "gpt-5.2"}, + }, state) + + // reasoning item added + events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.output_item.added", + OutputIndex: 0, + Item: &ResponsesOutput{Type: "reasoning"}, + }, state) + require.Len(t, events, 1) + assert.Equal(t, "content_block_start", events[0].Type) + assert.Equal(t, "thinking", events[0].ContentBlock.Type) + + // reasoning text delta + events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.reasoning_summary_text.delta", + OutputIndex: 0, + Delta: "Let me think...", + }, state) + require.Len(t, events, 1) + assert.Equal(t, "content_block_delta", events[0].Type) + assert.Equal(t, "thinking_delta", events[0].Delta.Type) + assert.Equal(t, "Let me think...", events[0].Delta.Thinking) + + // 
reasoning done + events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.reasoning_summary_text.done", + }, state) + require.Len(t, events, 1) + assert.Equal(t, "content_block_stop", events[0].Type) +} + +func TestStreamingIncomplete(t *testing.T) { + state := NewResponsesEventToAnthropicState() + + ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.created", + Response: &ResponsesResponse{ID: "resp_4", Model: "gpt-5.2"}, + }, state) + + ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.output_text.delta", + Delta: "Partial output...", + }, state) + + events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.incomplete", + Response: &ResponsesResponse{ + Status: "incomplete", + IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"}, + Usage: &ResponsesUsage{InputTokens: 100, OutputTokens: 4096}, + }, + }, state) + + // Should close the text block + message_delta + message_stop + require.Len(t, events, 3) + assert.Equal(t, "content_block_stop", events[0].Type) + assert.Equal(t, "message_delta", events[1].Type) + assert.Equal(t, "max_tokens", events[1].Delta.StopReason) + assert.Equal(t, "message_stop", events[2].Type) +} + +func TestFinalizeStream_NeverStarted(t *testing.T) { + state := NewResponsesEventToAnthropicState() + events := FinalizeResponsesAnthropicStream(state) + assert.Nil(t, events) +} + +func TestFinalizeStream_AlreadyCompleted(t *testing.T) { + state := NewResponsesEventToAnthropicState() + state.MessageStartSent = true + state.MessageStopSent = true + events := FinalizeResponsesAnthropicStream(state) + assert.Nil(t, events) +} + +func TestFinalizeStream_AbnormalTermination(t *testing.T) { + state := NewResponsesEventToAnthropicState() + + // Simulate a stream that started but never completed + ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.created", + Response: &ResponsesResponse{ID: "resp_5", Model: 
"gpt-5.2"}, + }, state) + + ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.output_text.delta", + Delta: "Interrupted...", + }, state) + + // Stream ends without response.completed + events := FinalizeResponsesAnthropicStream(state) + require.Len(t, events, 3) // content_block_stop + message_delta + message_stop + assert.Equal(t, "content_block_stop", events[0].Type) + assert.Equal(t, "message_delta", events[1].Type) + assert.Equal(t, "end_turn", events[1].Delta.StopReason) + assert.Equal(t, "message_stop", events[2].Type) +} + +func TestStreamingEmptyResponse(t *testing.T) { + state := NewResponsesEventToAnthropicState() + + ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.created", + Response: &ResponsesResponse{ID: "resp_6", Model: "gpt-5.2"}, + }, state) + + events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.completed", + Response: &ResponsesResponse{ + Status: "completed", + Usage: &ResponsesUsage{InputTokens: 5, OutputTokens: 0}, + }, + }, state) + + require.Len(t, events, 2) // message_delta + message_stop + assert.Equal(t, "message_delta", events[0].Type) + assert.Equal(t, "end_turn", events[0].Delta.StopReason) +} + +func TestResponsesAnthropicEventToSSE(t *testing.T) { + evt := AnthropicStreamEvent{ + Type: "message_start", + Message: &AnthropicResponse{ + ID: "resp_1", + Type: "message", + Role: "assistant", + }, + } + sse, err := ResponsesAnthropicEventToSSE(evt) + require.NoError(t, err) + assert.Contains(t, sse, "event: message_start\n") + assert.Contains(t, sse, "data: ") + assert.Contains(t, sse, `"resp_1"`) +} + +// --------------------------------------------------------------------------- +// response.failed tests +// --------------------------------------------------------------------------- + +func TestStreamingFailed(t *testing.T) { + state := NewResponsesEventToAnthropicState() + + // 1. 
response.created + ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.created", + Response: &ResponsesResponse{ID: "resp_fail_1", Model: "gpt-5.2"}, + }, state) + + // 2. Some text output before failure + ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.output_text.delta", + Delta: "Partial output before failure", + }, state) + + // 3. response.failed + events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.failed", + Response: &ResponsesResponse{ + Status: "failed", + Error: &ResponsesError{Code: "server_error", Message: "Internal error"}, + Usage: &ResponsesUsage{InputTokens: 50, OutputTokens: 10}, + }, + }, state) + + // Should close text block + message_delta + message_stop + require.Len(t, events, 3) + assert.Equal(t, "content_block_stop", events[0].Type) + assert.Equal(t, "message_delta", events[1].Type) + assert.Equal(t, "end_turn", events[1].Delta.StopReason) + assert.Equal(t, 50, events[1].Usage.InputTokens) + assert.Equal(t, 10, events[1].Usage.OutputTokens) + assert.Equal(t, "message_stop", events[2].Type) +} + +func TestStreamingFailedNoOutput(t *testing.T) { + state := NewResponsesEventToAnthropicState() + + // 1. response.created + ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.created", + Response: &ResponsesResponse{ID: "resp_fail_2", Model: "gpt-5.2"}, + }, state) + + // 2. 
response.failed with no prior output + events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.failed", + Response: &ResponsesResponse{ + Status: "failed", + Error: &ResponsesError{Code: "rate_limit_error", Message: "Too many requests"}, + Usage: &ResponsesUsage{InputTokens: 20, OutputTokens: 0}, + }, + }, state) + + // Should emit message_delta + message_stop (no block to close) + require.Len(t, events, 2) + assert.Equal(t, "message_delta", events[0].Type) + assert.Equal(t, "end_turn", events[0].Delta.StopReason) + assert.Equal(t, "message_stop", events[1].Type) +} + +func TestResponsesToAnthropic_Failed(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_fail_3", + Model: "gpt-5.2", + Status: "failed", + Error: &ResponsesError{Code: "server_error", Message: "Something went wrong"}, + Output: []ResponsesOutput{}, + Usage: &ResponsesUsage{InputTokens: 30, OutputTokens: 0}, + } + + anth := ResponsesToAnthropic(resp, "claude-opus-4-6") + // Failed status defaults to "end_turn" stop reason + assert.Equal(t, "end_turn", anth.StopReason) + // Should have at least an empty text block + require.Len(t, anth.Content, 1) + assert.Equal(t, "text", anth.Content[0].Type) +} + +// --------------------------------------------------------------------------- +// thinking → reasoning conversion tests +// --------------------------------------------------------------------------- + +func TestAnthropicToResponses_ThinkingEnabled(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}}, + Thinking: &AnthropicThinking{Type: "enabled", BudgetTokens: 10000}, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + require.NotNil(t, resp.Reasoning) + // thinking.type is ignored for effort; default xhigh applies. 
+ assert.Equal(t, "xhigh", resp.Reasoning.Effort) + assert.Equal(t, "auto", resp.Reasoning.Summary) + assert.Contains(t, resp.Include, "reasoning.encrypted_content") + assert.NotContains(t, resp.Include, "reasoning.summary") +} + +func TestAnthropicToResponses_ThinkingAdaptive(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}}, + Thinking: &AnthropicThinking{Type: "adaptive", BudgetTokens: 5000}, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + require.NotNil(t, resp.Reasoning) + // thinking.type is ignored for effort; default xhigh applies. + assert.Equal(t, "xhigh", resp.Reasoning.Effort) + assert.Equal(t, "auto", resp.Reasoning.Summary) + assert.NotContains(t, resp.Include, "reasoning.summary") +} + +func TestAnthropicToResponses_ThinkingDisabled(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}}, + Thinking: &AnthropicThinking{Type: "disabled"}, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + // Default effort applies (high → xhigh) even when thinking is disabled. + require.NotNil(t, resp.Reasoning) + assert.Equal(t, "xhigh", resp.Reasoning.Effort) +} + +func TestAnthropicToResponses_NoThinking(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}}, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + // Default effort applies (high → xhigh) when no thinking/output_config is set. 
+ require.NotNil(t, resp.Reasoning) + assert.Equal(t, "xhigh", resp.Reasoning.Effort) +} + +// --------------------------------------------------------------------------- +// output_config.effort override tests +// --------------------------------------------------------------------------- + +func TestAnthropicToResponses_OutputConfigOverridesDefault(t *testing.T) { + // Default is xhigh, but output_config.effort="low" overrides. low→low after mapping. + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}}, + Thinking: &AnthropicThinking{Type: "enabled", BudgetTokens: 10000}, + OutputConfig: &AnthropicOutputConfig{Effort: "low"}, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + require.NotNil(t, resp.Reasoning) + assert.Equal(t, "low", resp.Reasoning.Effort) + assert.Equal(t, "auto", resp.Reasoning.Summary) +} + +func TestAnthropicToResponses_OutputConfigWithoutThinking(t *testing.T) { + // No thinking field, but output_config.effort="medium" → creates reasoning. + // medium→high after mapping. + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}}, + OutputConfig: &AnthropicOutputConfig{Effort: "medium"}, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + require.NotNil(t, resp.Reasoning) + assert.Equal(t, "high", resp.Reasoning.Effort) + assert.Equal(t, "auto", resp.Reasoning.Summary) +} + +func TestAnthropicToResponses_OutputConfigHigh(t *testing.T) { + // output_config.effort="high" → mapped to "xhigh". 
+ req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}}, + OutputConfig: &AnthropicOutputConfig{Effort: "high"}, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + require.NotNil(t, resp.Reasoning) + assert.Equal(t, "xhigh", resp.Reasoning.Effort) + assert.Equal(t, "auto", resp.Reasoning.Summary) +} + +func TestAnthropicToResponses_NoOutputConfig(t *testing.T) { + // No output_config → default xhigh regardless of thinking.type. + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}}, + Thinking: &AnthropicThinking{Type: "enabled", BudgetTokens: 10000}, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + require.NotNil(t, resp.Reasoning) + assert.Equal(t, "xhigh", resp.Reasoning.Effort) +} + +func TestAnthropicToResponses_OutputConfigWithoutEffort(t *testing.T) { + // output_config present but effort empty (e.g. only format set) → default xhigh. 
+ req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}}, + OutputConfig: &AnthropicOutputConfig{}, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + require.NotNil(t, resp.Reasoning) + assert.Equal(t, "xhigh", resp.Reasoning.Effort) +} + +// --------------------------------------------------------------------------- +// tool_choice conversion tests +// --------------------------------------------------------------------------- + +func TestAnthropicToResponses_ToolChoiceAuto(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}}, + ToolChoice: json.RawMessage(`{"type":"auto"}`), + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + + var tc string + require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc)) + assert.Equal(t, "auto", tc) +} + +func TestAnthropicToResponses_ToolChoiceAny(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}}, + ToolChoice: json.RawMessage(`{"type":"any"}`), + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + + var tc string + require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc)) + assert.Equal(t, "required", tc) +} + +func TestAnthropicToResponses_ToolChoiceSpecific(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}}, + ToolChoice: json.RawMessage(`{"type":"tool","name":"get_weather"}`), + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + + var tc map[string]any + require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc)) + assert.Equal(t, "function", tc["type"]) + fn, ok := tc["function"].(map[string]any) + 
require.True(t, ok) + assert.Equal(t, "get_weather", fn["name"]) +} + +// --------------------------------------------------------------------------- +// Image content block conversion tests +// --------------------------------------------------------------------------- + +func TestAnthropicToResponses_UserImageBlock(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{ + {Role: "user", Content: json.RawMessage(`[ + {"type":"text","text":"What is in this image?"}, + {"type":"image","source":{"type":"base64","media_type":"image/png","data":"iVBOR"}} + ]`)}, + }, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 1) + assert.Equal(t, "user", items[0].Role) + + var parts []ResponsesContentPart + require.NoError(t, json.Unmarshal(items[0].Content, &parts)) + require.Len(t, parts, 2) + assert.Equal(t, "input_text", parts[0].Type) + assert.Equal(t, "What is in this image?", parts[0].Text) + assert.Equal(t, "input_image", parts[1].Type) + assert.Equal(t, "data:image/png;base64,iVBOR", parts[1].ImageURL) +} + +func TestAnthropicToResponses_ImageOnlyUserMessage(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{ + {Role: "user", Content: json.RawMessage(`[ + {"type":"image","source":{"type":"base64","media_type":"image/jpeg","data":"/9j/4AAQ"}} + ]`)}, + }, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 1) + + var parts []ResponsesContentPart + require.NoError(t, json.Unmarshal(items[0].Content, &parts)) + require.Len(t, parts, 1) + assert.Equal(t, "input_image", parts[0].Type) + assert.Equal(t, "data:image/jpeg;base64,/9j/4AAQ", parts[0].ImageURL) +} + +func 
TestAnthropicToResponses_ToolResultWithImage(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{ + {Role: "user", Content: json.RawMessage(`"Read the screenshot"`)}, + {Role: "assistant", Content: json.RawMessage(`[{"type":"tool_use","id":"toolu_1","name":"Read","input":{"file_path":"/tmp/screen.png"}}]`)}, + {Role: "user", Content: json.RawMessage(`[ + {"type":"tool_result","tool_use_id":"toolu_1","content":[ + {"type":"image","source":{"type":"base64","media_type":"image/png","data":"iVBOR"}} + ]} + ]`)}, + }, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + // user + function_call + function_call_output + user(image) = 4 + require.Len(t, items, 4) + + // function_call_output should have text-only output (no image). + assert.Equal(t, "function_call_output", items[2].Type) + assert.Equal(t, "fc_toolu_1", items[2].CallID) + assert.Equal(t, "(empty)", items[2].Output) + + // Image should be in a separate user message. 
+ assert.Equal(t, "user", items[3].Role) + var parts []ResponsesContentPart + require.NoError(t, json.Unmarshal(items[3].Content, &parts)) + require.Len(t, parts, 1) + assert.Equal(t, "input_image", parts[0].Type) + assert.Equal(t, "data:image/png;base64,iVBOR", parts[0].ImageURL) +} + +func TestAnthropicToResponses_ToolResultMixed(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{ + {Role: "user", Content: json.RawMessage(`"Describe the file"`)}, + {Role: "assistant", Content: json.RawMessage(`[{"type":"tool_use","id":"toolu_2","name":"Read","input":{"file_path":"/tmp/photo.png"}}]`)}, + {Role: "user", Content: json.RawMessage(`[ + {"type":"tool_result","tool_use_id":"toolu_2","content":[ + {"type":"text","text":"File metadata: 800x600 PNG"}, + {"type":"image","source":{"type":"base64","media_type":"image/png","data":"AAAA"}} + ]} + ]`)}, + }, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + // user + function_call + function_call_output + user(image) = 4 + require.Len(t, items, 4) + + // function_call_output should have text-only output. + assert.Equal(t, "function_call_output", items[2].Type) + assert.Equal(t, "File metadata: 800x600 PNG", items[2].Output) + + // Image should be in a separate user message. 
+ assert.Equal(t, "user", items[3].Role) + var parts []ResponsesContentPart + require.NoError(t, json.Unmarshal(items[3].Content, &parts)) + require.Len(t, parts, 1) + assert.Equal(t, "input_image", parts[0].Type) + assert.Equal(t, "data:image/png;base64,AAAA", parts[0].ImageURL) +} + +func TestAnthropicToResponses_TextOnlyToolResultBackwardCompat(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{ + {Role: "user", Content: json.RawMessage(`"Check weather"`)}, + {Role: "assistant", Content: json.RawMessage(`[{"type":"tool_use","id":"call_1","name":"get_weather","input":{"city":"NYC"}}]`)}, + {Role: "user", Content: json.RawMessage(`[ + {"type":"tool_result","tool_use_id":"call_1","content":[ + {"type":"text","text":"Sunny, 72°F"} + ]} + ]`)}, + }, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + // user + function_call + function_call_output = 3 + require.Len(t, items, 3) + + // Text-only tool_result should produce a plain string. + assert.Equal(t, "Sunny, 72°F", items[2].Output) +} + +func TestAnthropicToResponses_ImageEmptyMediaType(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{ + {Role: "user", Content: json.RawMessage(`[ + {"type":"image","source":{"type":"base64","media_type":"","data":"iVBOR"}} + ]`)}, + }, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 1) + + var parts []ResponsesContentPart + require.NoError(t, json.Unmarshal(items[0].Content, &parts)) + require.Len(t, parts, 1) + assert.Equal(t, "input_image", parts[0].Type) + // Should default to image/png when media_type is empty. 
+ assert.Equal(t, "data:image/png;base64,iVBOR", parts[0].ImageURL) +} diff --git a/backend/internal/pkg/apicompat/anthropic_to_responses.go b/backend/internal/pkg/apicompat/anthropic_to_responses.go new file mode 100644 index 00000000..592bec39 --- /dev/null +++ b/backend/internal/pkg/apicompat/anthropic_to_responses.go @@ -0,0 +1,417 @@ +package apicompat + +import ( + "encoding/json" + "fmt" + "strings" +) + +// AnthropicToResponses converts an Anthropic Messages request directly into +// a Responses API request. This preserves fields that would be lost in a +// Chat Completions intermediary round-trip (e.g. thinking, cache_control, +// structured system prompts). +func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) { + input, err := convertAnthropicToResponsesInput(req.System, req.Messages) + if err != nil { + return nil, err + } + + inputJSON, err := json.Marshal(input) + if err != nil { + return nil, err + } + + out := &ResponsesRequest{ + Model: req.Model, + Input: inputJSON, + Temperature: req.Temperature, + TopP: req.TopP, + Stream: req.Stream, + Include: []string{"reasoning.encrypted_content"}, + } + + storeFalse := false + out.Store = &storeFalse + + if req.MaxTokens > 0 { + v := req.MaxTokens + if v < minMaxOutputTokens { + v = minMaxOutputTokens + } + out.MaxOutputTokens = &v + } + + if len(req.Tools) > 0 { + out.Tools = convertAnthropicToolsToResponses(req.Tools) + } + + // Determine reasoning effort: only output_config.effort controls the + // level; thinking.type is ignored. Default is xhigh when unset. + // Anthropic levels map to OpenAI: low→low, medium→high, high→xhigh. 
+ effort := "high" // default → maps to xhigh + if req.OutputConfig != nil && req.OutputConfig.Effort != "" { + effort = req.OutputConfig.Effort + } + out.Reasoning = &ResponsesReasoning{ + Effort: mapAnthropicEffortToResponses(effort), + Summary: "auto", + } + + // Convert tool_choice + if len(req.ToolChoice) > 0 { + tc, err := convertAnthropicToolChoiceToResponses(req.ToolChoice) + if err != nil { + return nil, fmt.Errorf("convert tool_choice: %w", err) + } + out.ToolChoice = tc + } + + return out, nil +} + +// convertAnthropicToolChoiceToResponses maps Anthropic tool_choice to Responses format. +// +// {"type":"auto"} → "auto" +// {"type":"any"} → "required" +// {"type":"none"} → "none" +// {"type":"tool","name":"X"} → {"type":"function","function":{"name":"X"}} +func convertAnthropicToolChoiceToResponses(raw json.RawMessage) (json.RawMessage, error) { + var tc struct { + Type string `json:"type"` + Name string `json:"name"` + } + if err := json.Unmarshal(raw, &tc); err != nil { + return nil, err + } + + switch tc.Type { + case "auto": + return json.Marshal("auto") + case "any": + return json.Marshal("required") + case "none": + return json.Marshal("none") + case "tool": + return json.Marshal(map[string]any{ + "type": "function", + "function": map[string]string{"name": tc.Name}, + }) + default: + // Pass through unknown types as-is + return raw, nil + } +} + +// convertAnthropicToResponsesInput builds the Responses API input items array +// from the Anthropic system field and message list. +func convertAnthropicToResponsesInput(system json.RawMessage, msgs []AnthropicMessage) ([]ResponsesInputItem, error) { + var out []ResponsesInputItem + + // System prompt → system role input item. 
+ if len(system) > 0 { + sysText, err := parseAnthropicSystemPrompt(system) + if err != nil { + return nil, err + } + if sysText != "" { + content, _ := json.Marshal(sysText) + out = append(out, ResponsesInputItem{ + Role: "system", + Content: content, + }) + } + } + + for _, m := range msgs { + items, err := anthropicMsgToResponsesItems(m) + if err != nil { + return nil, err + } + out = append(out, items...) + } + return out, nil +} + +// parseAnthropicSystemPrompt handles the Anthropic system field which can be +// a plain string or an array of text blocks. +func parseAnthropicSystemPrompt(raw json.RawMessage) (string, error) { + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return s, nil + } + var blocks []AnthropicContentBlock + if err := json.Unmarshal(raw, &blocks); err != nil { + return "", err + } + var parts []string + for _, b := range blocks { + if b.Type == "text" && b.Text != "" { + parts = append(parts, b.Text) + } + } + return strings.Join(parts, "\n\n"), nil +} + +// anthropicMsgToResponsesItems converts a single Anthropic message into one +// or more Responses API input items. +func anthropicMsgToResponsesItems(m AnthropicMessage) ([]ResponsesInputItem, error) { + switch m.Role { + case "user": + return anthropicUserToResponses(m.Content) + case "assistant": + return anthropicAssistantToResponses(m.Content) + default: + return anthropicUserToResponses(m.Content) + } +} + +// anthropicUserToResponses handles an Anthropic user message. Content can be a +// plain string or an array of blocks. tool_result blocks are extracted into +// function_call_output items. Image blocks are converted to input_image parts. +func anthropicUserToResponses(raw json.RawMessage) ([]ResponsesInputItem, error) { + // Try plain string. 
+ var s string + if err := json.Unmarshal(raw, &s); err == nil { + content, _ := json.Marshal(s) + return []ResponsesInputItem{{Role: "user", Content: content}}, nil + } + + var blocks []AnthropicContentBlock + if err := json.Unmarshal(raw, &blocks); err != nil { + return nil, err + } + + var out []ResponsesInputItem + var toolResultImageParts []ResponsesContentPart + + // Extract tool_result blocks → function_call_output items. + // Images inside tool_results are extracted separately because the + // Responses API function_call_output.output only accepts strings. + for _, b := range blocks { + if b.Type != "tool_result" { + continue + } + outputText, imageParts := convertToolResultOutput(b) + out = append(out, ResponsesInputItem{ + Type: "function_call_output", + CallID: toResponsesCallID(b.ToolUseID), + Output: outputText, + }) + toolResultImageParts = append(toolResultImageParts, imageParts...) + } + + // Remaining text + image blocks → user message with content parts. + // Also include images extracted from tool_results so the model can see them. + var parts []ResponsesContentPart + for _, b := range blocks { + switch b.Type { + case "text": + if b.Text != "" { + parts = append(parts, ResponsesContentPart{Type: "input_text", Text: b.Text}) + } + case "image": + if uri := anthropicImageToDataURI(b.Source); uri != "" { + parts = append(parts, ResponsesContentPart{Type: "input_image", ImageURL: uri}) + } + } + } + parts = append(parts, toolResultImageParts...) + + if len(parts) > 0 { + content, err := json.Marshal(parts) + if err != nil { + return nil, err + } + out = append(out, ResponsesInputItem{Role: "user", Content: content}) + } + + return out, nil +} + +// anthropicAssistantToResponses handles an Anthropic assistant message. +// Text content → assistant message with output_text parts. +// tool_use blocks → function_call items. +// thinking blocks → ignored (OpenAI doesn't accept them as input). 
+func anthropicAssistantToResponses(raw json.RawMessage) ([]ResponsesInputItem, error) { + // Try plain string. + var s string + if err := json.Unmarshal(raw, &s); err == nil { + parts := []ResponsesContentPart{{Type: "output_text", Text: s}} + partsJSON, err := json.Marshal(parts) + if err != nil { + return nil, err + } + return []ResponsesInputItem{{Role: "assistant", Content: partsJSON}}, nil + } + + var blocks []AnthropicContentBlock + if err := json.Unmarshal(raw, &blocks); err != nil { + return nil, err + } + + var items []ResponsesInputItem + + // Text content → assistant message with output_text content parts. + text := extractAnthropicTextFromBlocks(blocks) + if text != "" { + parts := []ResponsesContentPart{{Type: "output_text", Text: text}} + partsJSON, err := json.Marshal(parts) + if err != nil { + return nil, err + } + items = append(items, ResponsesInputItem{Role: "assistant", Content: partsJSON}) + } + + // tool_use → function_call items. + for _, b := range blocks { + if b.Type != "tool_use" { + continue + } + args := "{}" + if len(b.Input) > 0 { + args = string(b.Input) + } + fcID := toResponsesCallID(b.ID) + items = append(items, ResponsesInputItem{ + Type: "function_call", + CallID: fcID, + Name: b.Name, + Arguments: args, + ID: fcID, + }) + } + + return items, nil +} + +// toResponsesCallID converts an Anthropic tool ID (toolu_xxx / call_xxx) to a +// Responses API function_call ID that starts with "fc_". +func toResponsesCallID(id string) string { + if strings.HasPrefix(id, "fc_") { + return id + } + return "fc_" + id +} + +// fromResponsesCallID reverses toResponsesCallID, stripping the "fc_" prefix +// that was added during request conversion. +func fromResponsesCallID(id string) string { + if after, ok := strings.CutPrefix(id, "fc_"); ok { + // Only strip if the remainder doesn't look like it was already "fc_" prefixed. + // E.g. 
"fc_toolu_xxx" → "toolu_xxx", "fc_call_xxx" → "call_xxx" + if strings.HasPrefix(after, "toolu_") || strings.HasPrefix(after, "call_") { + return after + } + } + return id +} + +// anthropicImageToDataURI converts an AnthropicImageSource to a data URI string. +// Returns "" if the source is nil or has no data. +func anthropicImageToDataURI(src *AnthropicImageSource) string { + if src == nil || src.Data == "" { + return "" + } + mediaType := src.MediaType + if mediaType == "" { + mediaType = "image/png" + } + return "data:" + mediaType + ";base64," + src.Data +} + +// convertToolResultOutput extracts text and image content from a tool_result +// block. Returns the text as a string for the function_call_output Output +// field, plus any image parts that must be sent in a separate user message +// (the Responses API output field only accepts strings). +func convertToolResultOutput(b AnthropicContentBlock) (string, []ResponsesContentPart) { + if len(b.Content) == 0 { + return "(empty)", nil + } + + // Try plain string content. + var s string + if err := json.Unmarshal(b.Content, &s); err == nil { + if s == "" { + s = "(empty)" + } + return s, nil + } + + // Array of content blocks — may contain text and/or images. + var inner []AnthropicContentBlock + if err := json.Unmarshal(b.Content, &inner); err != nil { + return "(empty)", nil + } + + // Separate text (for function_call_output) from images (for user message). 
+ var textParts []string + var imageParts []ResponsesContentPart + for _, ib := range inner { + switch ib.Type { + case "text": + if ib.Text != "" { + textParts = append(textParts, ib.Text) + } + case "image": + if uri := anthropicImageToDataURI(ib.Source); uri != "" { + imageParts = append(imageParts, ResponsesContentPart{Type: "input_image", ImageURL: uri}) + } + } + } + + text := strings.Join(textParts, "\n\n") + if text == "" { + text = "(empty)" + } + return text, imageParts +} + +// extractAnthropicTextFromBlocks joins all text blocks, ignoring thinking/ +// tool_use/tool_result blocks. +func extractAnthropicTextFromBlocks(blocks []AnthropicContentBlock) string { + var parts []string + for _, b := range blocks { + if b.Type == "text" && b.Text != "" { + parts = append(parts, b.Text) + } + } + return strings.Join(parts, "\n\n") +} + +// mapAnthropicEffortToResponses converts Anthropic reasoning effort levels to +// OpenAI Responses API effort levels. +// +// low → low +// medium → high +// high → xhigh +func mapAnthropicEffortToResponses(effort string) string { + switch effort { + case "medium": + return "high" + case "high": + return "xhigh" + default: + return effort // "low" and any unknown values pass through unchanged + } +} + +// convertAnthropicToolsToResponses maps Anthropic tool definitions to +// Responses API tools. Server-side tools like web_search are mapped to their +// OpenAI equivalents; regular tools become function tools. 
+func convertAnthropicToolsToResponses(tools []AnthropicTool) []ResponsesTool { + var out []ResponsesTool + for _, t := range tools { + // Anthropic server tools like "web_search_20250305" → OpenAI {"type":"web_search"} + if strings.HasPrefix(t.Type, "web_search") { + out = append(out, ResponsesTool{Type: "web_search"}) + continue + } + out = append(out, ResponsesTool{ + Type: "function", + Name: t.Name, + Description: t.Description, + Parameters: t.InputSchema, + }) + } + return out +} diff --git a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go new file mode 100644 index 00000000..71b7a6f5 --- /dev/null +++ b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go @@ -0,0 +1,733 @@ +package apicompat + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// ChatCompletionsToResponses tests +// --------------------------------------------------------------------------- + +func TestChatCompletionsToResponses_BasicText(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: json.RawMessage(`"Hello"`)}, + }, + } + + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + assert.Equal(t, "gpt-4o", resp.Model) + assert.True(t, resp.Stream) // always forced true + assert.False(t, *resp.Store) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 1) + assert.Equal(t, "user", items[0].Role) +} + +func TestChatCompletionsToResponses_SystemMessage(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "system", Content: json.RawMessage(`"You are helpful."`)}, + {Role: "user", Content: json.RawMessage(`"Hi"`)}, + }, + } + + resp, err := 
ChatCompletionsToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 2) + assert.Equal(t, "system", items[0].Role) + assert.Equal(t, "user", items[1].Role) +} + +func TestChatCompletionsToResponses_ToolCalls(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: json.RawMessage(`"Call the function"`)}, + { + Role: "assistant", + ToolCalls: []ChatToolCall{ + { + ID: "call_1", + Type: "function", + Function: ChatFunctionCall{ + Name: "ping", + Arguments: `{"host":"example.com"}`, + }, + }, + }, + }, + { + Role: "tool", + ToolCallID: "call_1", + Content: json.RawMessage(`"pong"`), + }, + }, + Tools: []ChatTool{ + { + Type: "function", + Function: &ChatFunction{ + Name: "ping", + Description: "Ping a host", + Parameters: json.RawMessage(`{"type":"object"}`), + }, + }, + }, + } + + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + // user + function_call + function_call_output = 3 + // (assistant message with empty content + tool_calls → only function_call items emitted) + require.Len(t, items, 3) + + // Check function_call item + assert.Equal(t, "function_call", items[1].Type) + assert.Equal(t, "call_1", items[1].CallID) + assert.Equal(t, "ping", items[1].Name) + + // Check function_call_output item + assert.Equal(t, "function_call_output", items[2].Type) + assert.Equal(t, "call_1", items[2].CallID) + assert.Equal(t, "pong", items[2].Output) + + // Check tools + require.Len(t, resp.Tools, 1) + assert.Equal(t, "function", resp.Tools[0].Type) + assert.Equal(t, "ping", resp.Tools[0].Name) +} + +func TestChatCompletionsToResponses_MaxTokens(t *testing.T) { + t.Run("max_tokens", func(t *testing.T) { + maxTokens := 100 + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + MaxTokens: 
&maxTokens, + Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}}, + } + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + require.NotNil(t, resp.MaxOutputTokens) + // Below minMaxOutputTokens (128), should be clamped + assert.Equal(t, minMaxOutputTokens, *resp.MaxOutputTokens) + }) + + t.Run("max_completion_tokens_preferred", func(t *testing.T) { + maxTokens := 100 + maxCompletion := 500 + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + MaxTokens: &maxTokens, + MaxCompletionTokens: &maxCompletion, + Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}}, + } + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + require.NotNil(t, resp.MaxOutputTokens) + assert.Equal(t, 500, *resp.MaxOutputTokens) + }) +} + +func TestChatCompletionsToResponses_ReasoningEffort(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + ReasoningEffort: "high", + Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}}, + } + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + require.NotNil(t, resp.Reasoning) + assert.Equal(t, "high", resp.Reasoning.Effort) + assert.Equal(t, "auto", resp.Reasoning.Summary) +} + +func TestChatCompletionsToResponses_ImageURL(t *testing.T) { + content := `[{"type":"text","text":"Describe this"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc123"}}]` + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: json.RawMessage(content)}, + }, + } + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 1) + + var parts []ResponsesContentPart + require.NoError(t, json.Unmarshal(items[0].Content, &parts)) + require.Len(t, parts, 2) + assert.Equal(t, "input_text", parts[0].Type) + assert.Equal(t, "Describe this", parts[0].Text) + 
assert.Equal(t, "input_image", parts[1].Type) + assert.Equal(t, "data:image/png;base64,abc123", parts[1].ImageURL) +} + +func TestChatCompletionsToResponses_LegacyFunctions(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: json.RawMessage(`"Hi"`)}, + }, + Functions: []ChatFunction{ + { + Name: "get_weather", + Description: "Get weather", + Parameters: json.RawMessage(`{"type":"object"}`), + }, + }, + FunctionCall: json.RawMessage(`{"name":"get_weather"}`), + } + + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + require.Len(t, resp.Tools, 1) + assert.Equal(t, "function", resp.Tools[0].Type) + assert.Equal(t, "get_weather", resp.Tools[0].Name) + + // tool_choice should be converted + require.NotNil(t, resp.ToolChoice) + var tc map[string]any + require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc)) + assert.Equal(t, "function", tc["type"]) +} + +func TestChatCompletionsToResponses_ServiceTier(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + ServiceTier: "flex", + Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}}, + } + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + assert.Equal(t, "flex", resp.ServiceTier) +} + +func TestChatCompletionsToResponses_AssistantWithTextAndToolCalls(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: json.RawMessage(`"Do something"`)}, + { + Role: "assistant", + Content: json.RawMessage(`"Let me call a function."`), + ToolCalls: []ChatToolCall{ + { + ID: "call_abc", + Type: "function", + Function: ChatFunctionCall{ + Name: "do_thing", + Arguments: `{}`, + }, + }, + }, + }, + }, + } + + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + // user + assistant message (with text) + function_call + 
require.Len(t, items, 3) + assert.Equal(t, "user", items[0].Role) + assert.Equal(t, "assistant", items[1].Role) + assert.Equal(t, "function_call", items[2].Type) +} + +// --------------------------------------------------------------------------- +// ResponsesToChatCompletions tests +// --------------------------------------------------------------------------- + +func TestResponsesToChatCompletions_BasicText(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_123", + Status: "completed", + Output: []ResponsesOutput{ + { + Type: "message", + Content: []ResponsesContentPart{ + {Type: "output_text", Text: "Hello, world!"}, + }, + }, + }, + Usage: &ResponsesUsage{ + InputTokens: 10, + OutputTokens: 5, + TotalTokens: 15, + }, + } + + chat := ResponsesToChatCompletions(resp, "gpt-4o") + assert.Equal(t, "chat.completion", chat.Object) + assert.Equal(t, "gpt-4o", chat.Model) + require.Len(t, chat.Choices, 1) + assert.Equal(t, "stop", chat.Choices[0].FinishReason) + + var content string + require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content)) + assert.Equal(t, "Hello, world!", content) + + require.NotNil(t, chat.Usage) + assert.Equal(t, 10, chat.Usage.PromptTokens) + assert.Equal(t, 5, chat.Usage.CompletionTokens) + assert.Equal(t, 15, chat.Usage.TotalTokens) +} + +func TestResponsesToChatCompletions_ToolCalls(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_456", + Status: "completed", + Output: []ResponsesOutput{ + { + Type: "function_call", + CallID: "call_xyz", + Name: "get_weather", + Arguments: `{"city":"NYC"}`, + }, + }, + } + + chat := ResponsesToChatCompletions(resp, "gpt-4o") + require.Len(t, chat.Choices, 1) + assert.Equal(t, "tool_calls", chat.Choices[0].FinishReason) + + msg := chat.Choices[0].Message + require.Len(t, msg.ToolCalls, 1) + assert.Equal(t, "call_xyz", msg.ToolCalls[0].ID) + assert.Equal(t, "function", msg.ToolCalls[0].Type) + assert.Equal(t, "get_weather", msg.ToolCalls[0].Function.Name) + assert.Equal(t, 
`{"city":"NYC"}`, msg.ToolCalls[0].Function.Arguments) +} + +func TestResponsesToChatCompletions_Reasoning(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_789", + Status: "completed", + Output: []ResponsesOutput{ + { + Type: "reasoning", + Summary: []ResponsesSummary{ + {Type: "summary_text", Text: "I thought about it."}, + }, + }, + { + Type: "message", + Content: []ResponsesContentPart{ + {Type: "output_text", Text: "The answer is 42."}, + }, + }, + }, + } + + chat := ResponsesToChatCompletions(resp, "gpt-4o") + require.Len(t, chat.Choices, 1) + + var content string + require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content)) + // Reasoning summary is prepended to text + assert.Equal(t, "I thought about it.The answer is 42.", content) +} + +func TestResponsesToChatCompletions_Incomplete(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_inc", + Status: "incomplete", + IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"}, + Output: []ResponsesOutput{ + { + Type: "message", + Content: []ResponsesContentPart{ + {Type: "output_text", Text: "partial..."}, + }, + }, + }, + } + + chat := ResponsesToChatCompletions(resp, "gpt-4o") + require.Len(t, chat.Choices, 1) + assert.Equal(t, "length", chat.Choices[0].FinishReason) +} + +func TestResponsesToChatCompletions_CachedTokens(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_cache", + Status: "completed", + Output: []ResponsesOutput{ + { + Type: "message", + Content: []ResponsesContentPart{{Type: "output_text", Text: "cached"}}, + }, + }, + Usage: &ResponsesUsage{ + InputTokens: 100, + OutputTokens: 10, + TotalTokens: 110, + InputTokensDetails: &ResponsesInputTokensDetails{ + CachedTokens: 80, + }, + }, + } + + chat := ResponsesToChatCompletions(resp, "gpt-4o") + require.NotNil(t, chat.Usage) + require.NotNil(t, chat.Usage.PromptTokensDetails) + assert.Equal(t, 80, chat.Usage.PromptTokensDetails.CachedTokens) +} + +func 
TestResponsesToChatCompletions_WebSearch(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_ws", + Status: "completed", + Output: []ResponsesOutput{ + { + Type: "web_search_call", + Action: &WebSearchAction{Type: "search", Query: "test"}, + }, + { + Type: "message", + Content: []ResponsesContentPart{{Type: "output_text", Text: "search results"}}, + }, + }, + } + + chat := ResponsesToChatCompletions(resp, "gpt-4o") + require.Len(t, chat.Choices, 1) + assert.Equal(t, "stop", chat.Choices[0].FinishReason) + + var content string + require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content)) + assert.Equal(t, "search results", content) +} + +// --------------------------------------------------------------------------- +// Streaming: ResponsesEventToChatChunks tests +// --------------------------------------------------------------------------- + +func TestResponsesEventToChatChunks_TextDelta(t *testing.T) { + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + + // response.created → role chunk + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.created", + Response: &ResponsesResponse{ + ID: "resp_stream", + }, + }, state) + require.Len(t, chunks, 1) + assert.Equal(t, "assistant", chunks[0].Choices[0].Delta.Role) + assert.True(t, state.SentRole) + + // response.output_text.delta → content chunk + chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.output_text.delta", + Delta: "Hello", + }, state) + require.Len(t, chunks, 1) + require.NotNil(t, chunks[0].Choices[0].Delta.Content) + assert.Equal(t, "Hello", *chunks[0].Choices[0].Delta.Content) +} + +func TestResponsesEventToChatChunks_ToolCallDelta(t *testing.T) { + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.SentRole = true + + // response.output_item.added (function_call) — output_index=1 (e.g. 
after a message item at 0) + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.output_item.added", + OutputIndex: 1, + Item: &ResponsesOutput{ + Type: "function_call", + CallID: "call_1", + Name: "get_weather", + }, + }, state) + require.Len(t, chunks, 1) + require.Len(t, chunks[0].Choices[0].Delta.ToolCalls, 1) + tc := chunks[0].Choices[0].Delta.ToolCalls[0] + assert.Equal(t, "call_1", tc.ID) + assert.Equal(t, "get_weather", tc.Function.Name) + require.NotNil(t, tc.Index) + assert.Equal(t, 0, *tc.Index) + + // response.function_call_arguments.delta — uses output_index (NOT call_id) to find tool + chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.function_call_arguments.delta", + OutputIndex: 1, // matches the output_index from output_item.added above + Delta: `{"city":`, + }, state) + require.Len(t, chunks, 1) + tc = chunks[0].Choices[0].Delta.ToolCalls[0] + require.NotNil(t, tc.Index) + assert.Equal(t, 0, *tc.Index, "argument delta must use same index as the tool call") + assert.Equal(t, `{"city":`, tc.Function.Arguments) + + // Add a second function call at output_index=2 + chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.output_item.added", + OutputIndex: 2, + Item: &ResponsesOutput{ + Type: "function_call", + CallID: "call_2", + Name: "get_time", + }, + }, state) + require.Len(t, chunks, 1) + tc = chunks[0].Choices[0].Delta.ToolCalls[0] + require.NotNil(t, tc.Index) + assert.Equal(t, 1, *tc.Index, "second tool call should get index 1") + + // Argument delta for second tool call + chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.function_call_arguments.delta", + OutputIndex: 2, + Delta: `{"tz":"UTC"}`, + }, state) + require.Len(t, chunks, 1) + tc = chunks[0].Choices[0].Delta.ToolCalls[0] + require.NotNil(t, tc.Index) + assert.Equal(t, 1, *tc.Index, "second tool arg delta must use index 1") + + // Argument delta for first tool call (interleaved) + 
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.function_call_arguments.delta", + OutputIndex: 1, + Delta: `"Tokyo"}`, + }, state) + require.Len(t, chunks, 1) + tc = chunks[0].Choices[0].Delta.ToolCalls[0] + require.NotNil(t, tc.Index) + assert.Equal(t, 0, *tc.Index, "first tool arg delta must still use index 0") +} + +func TestResponsesEventToChatChunks_Completed(t *testing.T) { + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.IncludeUsage = true + + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.completed", + Response: &ResponsesResponse{ + Status: "completed", + Usage: &ResponsesUsage{ + InputTokens: 50, + OutputTokens: 20, + TotalTokens: 70, + InputTokensDetails: &ResponsesInputTokensDetails{ + CachedTokens: 30, + }, + }, + }, + }, state) + // finish chunk + usage chunk + require.Len(t, chunks, 2) + + // First chunk: finish_reason + require.NotNil(t, chunks[0].Choices[0].FinishReason) + assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason) + + // Second chunk: usage + require.NotNil(t, chunks[1].Usage) + assert.Equal(t, 50, chunks[1].Usage.PromptTokens) + assert.Equal(t, 20, chunks[1].Usage.CompletionTokens) + assert.Equal(t, 70, chunks[1].Usage.TotalTokens) + require.NotNil(t, chunks[1].Usage.PromptTokensDetails) + assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens) +} + +func TestResponsesEventToChatChunks_CompletedWithToolCalls(t *testing.T) { + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.SawToolCall = true + + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.completed", + Response: &ResponsesResponse{ + Status: "completed", + }, + }, state) + require.Len(t, chunks, 1) + require.NotNil(t, chunks[0].Choices[0].FinishReason) + assert.Equal(t, "tool_calls", *chunks[0].Choices[0].FinishReason) +} + +func TestResponsesEventToChatChunks_ReasoningDelta(t *testing.T) { + state := 
NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.SentRole = true + + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.reasoning_summary_text.delta", + Delta: "Thinking...", + }, state) + require.Len(t, chunks, 1) + require.NotNil(t, chunks[0].Choices[0].Delta.Content) + assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.Content) +} + +func TestFinalizeResponsesChatStream(t *testing.T) { + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.IncludeUsage = true + state.Usage = &ChatUsage{ + PromptTokens: 100, + CompletionTokens: 50, + TotalTokens: 150, + } + + chunks := FinalizeResponsesChatStream(state) + require.Len(t, chunks, 2) + + // Finish chunk + require.NotNil(t, chunks[0].Choices[0].FinishReason) + assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason) + + // Usage chunk + require.NotNil(t, chunks[1].Usage) + assert.Equal(t, 100, chunks[1].Usage.PromptTokens) + + // Idempotent: second call returns nil + assert.Nil(t, FinalizeResponsesChatStream(state)) +} + +func TestFinalizeResponsesChatStream_AfterCompleted(t *testing.T) { + // If response.completed already emitted the finish chunk, FinalizeResponsesChatStream + // must be a no-op (prevents double finish_reason being sent to the client). + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.IncludeUsage = true + + // Simulate response.completed + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.completed", + Response: &ResponsesResponse{ + Status: "completed", + Usage: &ResponsesUsage{ + InputTokens: 10, + OutputTokens: 5, + TotalTokens: 15, + }, + }, + }, state) + require.NotEmpty(t, chunks) // finish + usage chunks + + // Now FinalizeResponsesChatStream should return nil — already finalized. 
+ assert.Nil(t, FinalizeResponsesChatStream(state)) +} + +func TestChatChunkToSSE(t *testing.T) { + chunk := ChatCompletionsChunk{ + ID: "chatcmpl-test", + Object: "chat.completion.chunk", + Created: 1700000000, + Model: "gpt-4o", + Choices: []ChatChunkChoice{ + { + Index: 0, + Delta: ChatDelta{Role: "assistant"}, + FinishReason: nil, + }, + }, + } + + sse, err := ChatChunkToSSE(chunk) + require.NoError(t, err) + assert.Contains(t, sse, "data: ") + assert.Contains(t, sse, "chatcmpl-test") + assert.Contains(t, sse, "assistant") + assert.True(t, len(sse) > 10) +} + +// --------------------------------------------------------------------------- +// Stream round-trip test +// --------------------------------------------------------------------------- + +func TestChatCompletionsStreamRoundTrip(t *testing.T) { + // Simulate: client sends chat completions request, upstream returns Responses SSE events. + // Verify that the streaming state machine produces correct chat completions chunks. + + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.IncludeUsage = true + + var allChunks []ChatCompletionsChunk + + // 1. response.created + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.created", + Response: &ResponsesResponse{ID: "resp_rt"}, + }, state) + allChunks = append(allChunks, chunks...) + + // 2. text deltas + for _, text := range []string{"Hello", ", ", "world", "!"} { + chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.output_text.delta", + Delta: text, + }, state) + allChunks = append(allChunks, chunks...) + } + + // 3. response.completed + chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.completed", + Response: &ResponsesResponse{ + Status: "completed", + Usage: &ResponsesUsage{ + InputTokens: 10, + OutputTokens: 4, + TotalTokens: 14, + }, + }, + }, state) + allChunks = append(allChunks, chunks...) 
+ + // Verify: role chunk + 4 text chunks + finish chunk + usage chunk = 7 + require.Len(t, allChunks, 7) + + // First chunk has role + assert.Equal(t, "assistant", allChunks[0].Choices[0].Delta.Role) + + // Text chunks + var fullText string + for i := 1; i <= 4; i++ { + require.NotNil(t, allChunks[i].Choices[0].Delta.Content) + fullText += *allChunks[i].Choices[0].Delta.Content + } + assert.Equal(t, "Hello, world!", fullText) + + // Finish chunk + require.NotNil(t, allChunks[5].Choices[0].FinishReason) + assert.Equal(t, "stop", *allChunks[5].Choices[0].FinishReason) + + // Usage chunk + require.NotNil(t, allChunks[6].Usage) + assert.Equal(t, 10, allChunks[6].Usage.PromptTokens) + assert.Equal(t, 4, allChunks[6].Usage.CompletionTokens) + + // All chunks share the same ID + for _, c := range allChunks { + assert.Equal(t, "resp_rt", c.ID) + } +} diff --git a/backend/internal/pkg/apicompat/chatcompletions_to_responses.go b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go new file mode 100644 index 00000000..37285b09 --- /dev/null +++ b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go @@ -0,0 +1,312 @@ +package apicompat + +import ( + "encoding/json" + "fmt" +) + +// ChatCompletionsToResponses converts a Chat Completions request into a +// Responses API request. The upstream always streams, so Stream is forced to +// true. store is always false and reasoning.encrypted_content is always +// included so that the response translator has full context. 
+func ChatCompletionsToResponses(req *ChatCompletionsRequest) (*ResponsesRequest, error) { + input, err := convertChatMessagesToResponsesInput(req.Messages) + if err != nil { + return nil, err + } + + inputJSON, err := json.Marshal(input) + if err != nil { + return nil, err + } + + out := &ResponsesRequest{ + Model: req.Model, + Input: inputJSON, + Temperature: req.Temperature, + TopP: req.TopP, + Stream: true, // upstream always streams + Include: []string{"reasoning.encrypted_content"}, + ServiceTier: req.ServiceTier, + } + + storeFalse := false + out.Store = &storeFalse + + // max_tokens / max_completion_tokens → max_output_tokens, prefer max_completion_tokens + maxTokens := 0 + if req.MaxTokens != nil { + maxTokens = *req.MaxTokens + } + if req.MaxCompletionTokens != nil { + maxTokens = *req.MaxCompletionTokens + } + if maxTokens > 0 { + v := maxTokens + if v < minMaxOutputTokens { + v = minMaxOutputTokens + } + out.MaxOutputTokens = &v + } + + // reasoning_effort → reasoning.effort + reasoning.summary="auto" + if req.ReasoningEffort != "" { + out.Reasoning = &ResponsesReasoning{ + Effort: req.ReasoningEffort, + Summary: "auto", + } + } + + // tools[] and legacy functions[] → ResponsesTool[] + if len(req.Tools) > 0 || len(req.Functions) > 0 { + out.Tools = convertChatToolsToResponses(req.Tools, req.Functions) + } + + // tool_choice: already compatible format — pass through directly. + // Legacy function_call needs mapping. + if len(req.ToolChoice) > 0 { + out.ToolChoice = req.ToolChoice + } else if len(req.FunctionCall) > 0 { + tc, err := convertChatFunctionCallToToolChoice(req.FunctionCall) + if err != nil { + return nil, fmt.Errorf("convert function_call: %w", err) + } + out.ToolChoice = tc + } + + return out, nil +} + +// convertChatMessagesToResponsesInput converts the Chat Completions messages +// array into a Responses API input items array. 
+func convertChatMessagesToResponsesInput(msgs []ChatMessage) ([]ResponsesInputItem, error) { + var out []ResponsesInputItem + for _, m := range msgs { + items, err := chatMessageToResponsesItems(m) + if err != nil { + return nil, err + } + out = append(out, items...) + } + return out, nil +} + +// chatMessageToResponsesItems converts a single ChatMessage into one or more +// ResponsesInputItem values. +func chatMessageToResponsesItems(m ChatMessage) ([]ResponsesInputItem, error) { + switch m.Role { + case "system": + return chatSystemToResponses(m) + case "user": + return chatUserToResponses(m) + case "assistant": + return chatAssistantToResponses(m) + case "tool": + return chatToolToResponses(m) + case "function": + return chatFunctionToResponses(m) + default: + return chatUserToResponses(m) + } +} + +// chatSystemToResponses converts a system message. +func chatSystemToResponses(m ChatMessage) ([]ResponsesInputItem, error) { + text, err := parseChatContent(m.Content) + if err != nil { + return nil, err + } + content, err := json.Marshal(text) + if err != nil { + return nil, err + } + return []ResponsesInputItem{{Role: "system", Content: content}}, nil +} + +// chatUserToResponses converts a user message, handling both plain strings and +// multi-modal content arrays. +func chatUserToResponses(m ChatMessage) ([]ResponsesInputItem, error) { + // Try plain string first. 
+ var s string + if err := json.Unmarshal(m.Content, &s); err == nil { + content, _ := json.Marshal(s) + return []ResponsesInputItem{{Role: "user", Content: content}}, nil + } + + var parts []ChatContentPart + if err := json.Unmarshal(m.Content, &parts); err != nil { + return nil, fmt.Errorf("parse user content: %w", err) + } + + var responseParts []ResponsesContentPart + for _, p := range parts { + switch p.Type { + case "text": + if p.Text != "" { + responseParts = append(responseParts, ResponsesContentPart{ + Type: "input_text", + Text: p.Text, + }) + } + case "image_url": + if p.ImageURL != nil && p.ImageURL.URL != "" { + responseParts = append(responseParts, ResponsesContentPart{ + Type: "input_image", + ImageURL: p.ImageURL.URL, + }) + } + } + } + + content, err := json.Marshal(responseParts) + if err != nil { + return nil, err + } + return []ResponsesInputItem{{Role: "user", Content: content}}, nil +} + +// chatAssistantToResponses converts an assistant message. If there is both +// text content and tool_calls, the text is emitted as an assistant message +// first, then each tool_call becomes a function_call item. If the content is +// empty/nil and there are tool_calls, only function_call items are emitted. +func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) { + var items []ResponsesInputItem + + // Emit assistant message with output_text if content is non-empty. + if len(m.Content) > 0 { + var s string + if err := json.Unmarshal(m.Content, &s); err == nil && s != "" { + parts := []ResponsesContentPart{{Type: "output_text", Text: s}} + partsJSON, err := json.Marshal(parts) + if err != nil { + return nil, err + } + items = append(items, ResponsesInputItem{Role: "assistant", Content: partsJSON}) + } + } + + // Emit one function_call item per tool_call. 
+ for _, tc := range m.ToolCalls { + args := tc.Function.Arguments + if args == "" { + args = "{}" + } + items = append(items, ResponsesInputItem{ + Type: "function_call", + CallID: tc.ID, + Name: tc.Function.Name, + Arguments: args, + ID: tc.ID, + }) + } + + return items, nil +} + +// chatToolToResponses converts a tool result message (role=tool) into a +// function_call_output item. +func chatToolToResponses(m ChatMessage) ([]ResponsesInputItem, error) { + output, err := parseChatContent(m.Content) + if err != nil { + return nil, err + } + if output == "" { + output = "(empty)" + } + return []ResponsesInputItem{{ + Type: "function_call_output", + CallID: m.ToolCallID, + Output: output, + }}, nil +} + +// chatFunctionToResponses converts a legacy function result message +// (role=function) into a function_call_output item. The Name field is used as +// call_id since legacy function calls do not carry a separate call_id. +func chatFunctionToResponses(m ChatMessage) ([]ResponsesInputItem, error) { + output, err := parseChatContent(m.Content) + if err != nil { + return nil, err + } + if output == "" { + output = "(empty)" + } + return []ResponsesInputItem{{ + Type: "function_call_output", + CallID: m.Name, + Output: output, + }}, nil +} + +// parseChatContent returns the string value of a ChatMessage Content field. +// Content must be a JSON string. Returns "" if content is null or empty. +func parseChatContent(raw json.RawMessage) (string, error) { + if len(raw) == 0 { + return "", nil + } + var s string + if err := json.Unmarshal(raw, &s); err != nil { + return "", fmt.Errorf("parse content as string: %w", err) + } + return s, nil +} + +// convertChatToolsToResponses maps Chat Completions tool definitions and legacy +// function definitions to Responses API tool definitions. 
+func convertChatToolsToResponses(tools []ChatTool, functions []ChatFunction) []ResponsesTool { + var out []ResponsesTool + + for _, t := range tools { + if t.Type != "function" || t.Function == nil { + continue + } + rt := ResponsesTool{ + Type: "function", + Name: t.Function.Name, + Description: t.Function.Description, + Parameters: t.Function.Parameters, + Strict: t.Function.Strict, + } + out = append(out, rt) + } + + // Legacy functions[] are treated as function-type tools. + for _, f := range functions { + rt := ResponsesTool{ + Type: "function", + Name: f.Name, + Description: f.Description, + Parameters: f.Parameters, + Strict: f.Strict, + } + out = append(out, rt) + } + + return out +} + +// convertChatFunctionCallToToolChoice maps the legacy function_call field to a +// Responses API tool_choice value. +// +// "auto" → "auto" +// "none" → "none" +// {"name":"X"} → {"type":"function","function":{"name":"X"}} +func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage, error) { + // Try string first ("auto", "none", etc.) — pass through as-is. 
+ var s string + if err := json.Unmarshal(raw, &s); err == nil { + return json.Marshal(s) + } + + // Object form: {"name":"X"} + var obj struct { + Name string `json:"name"` + } + if err := json.Unmarshal(raw, &obj); err != nil { + return nil, err + } + return json.Marshal(map[string]any{ + "type": "function", + "function": map[string]string{"name": obj.Name}, + }) +} diff --git a/backend/internal/pkg/apicompat/responses_to_anthropic.go b/backend/internal/pkg/apicompat/responses_to_anthropic.go new file mode 100644 index 00000000..5409a0f4 --- /dev/null +++ b/backend/internal/pkg/apicompat/responses_to_anthropic.go @@ -0,0 +1,516 @@ +package apicompat + +import ( + "encoding/json" + "fmt" + "time" +) + +// --------------------------------------------------------------------------- +// Non-streaming: ResponsesResponse → AnthropicResponse +// --------------------------------------------------------------------------- + +// ResponsesToAnthropic converts a Responses API response directly into an +// Anthropic Messages response. Reasoning output items are mapped to thinking +// blocks; function_call items become tool_use blocks. 
+func ResponsesToAnthropic(resp *ResponsesResponse, model string) *AnthropicResponse { + out := &AnthropicResponse{ + ID: resp.ID, + Type: "message", + Role: "assistant", + Model: model, + } + + var blocks []AnthropicContentBlock + + for _, item := range resp.Output { + switch item.Type { + case "reasoning": + summaryText := "" + for _, s := range item.Summary { + if s.Type == "summary_text" && s.Text != "" { + summaryText += s.Text + } + } + if summaryText != "" { + blocks = append(blocks, AnthropicContentBlock{ + Type: "thinking", + Thinking: summaryText, + }) + } + case "message": + for _, part := range item.Content { + if part.Type == "output_text" && part.Text != "" { + blocks = append(blocks, AnthropicContentBlock{ + Type: "text", + Text: part.Text, + }) + } + } + case "function_call": + blocks = append(blocks, AnthropicContentBlock{ + Type: "tool_use", + ID: fromResponsesCallID(item.CallID), + Name: item.Name, + Input: json.RawMessage(item.Arguments), + }) + case "web_search_call": + toolUseID := "srvtoolu_" + item.ID + query := "" + if item.Action != nil { + query = item.Action.Query + } + inputJSON, _ := json.Marshal(map[string]string{"query": query}) + blocks = append(blocks, AnthropicContentBlock{ + Type: "server_tool_use", + ID: toolUseID, + Name: "web_search", + Input: inputJSON, + }) + emptyResults, _ := json.Marshal([]struct{}{}) + blocks = append(blocks, AnthropicContentBlock{ + Type: "web_search_tool_result", + ToolUseID: toolUseID, + Content: emptyResults, + }) + } + } + + if len(blocks) == 0 { + blocks = append(blocks, AnthropicContentBlock{Type: "text", Text: ""}) + } + out.Content = blocks + + out.StopReason = responsesStatusToAnthropicStopReason(resp.Status, resp.IncompleteDetails, blocks) + + if resp.Usage != nil { + out.Usage = AnthropicUsage{ + InputTokens: resp.Usage.InputTokens, + OutputTokens: resp.Usage.OutputTokens, + } + if resp.Usage.InputTokensDetails != nil { + out.Usage.CacheReadInputTokens = 
resp.Usage.InputTokensDetails.CachedTokens + } + } + + return out +} + +func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncompleteDetails, blocks []AnthropicContentBlock) string { + switch status { + case "incomplete": + if details != nil && details.Reason == "max_output_tokens" { + return "max_tokens" + } + return "end_turn" + case "completed": + if len(blocks) > 0 && blocks[len(blocks)-1].Type == "tool_use" { + return "tool_use" + } + return "end_turn" + default: + return "end_turn" + } +} + +// --------------------------------------------------------------------------- +// Streaming: ResponsesStreamEvent → []AnthropicStreamEvent (stateful converter) +// --------------------------------------------------------------------------- + +// ResponsesEventToAnthropicState tracks state for converting a sequence of +// Responses SSE events directly into Anthropic SSE events. +type ResponsesEventToAnthropicState struct { + MessageStartSent bool + MessageStopSent bool + + ContentBlockIndex int + ContentBlockOpen bool + CurrentBlockType string // "text" | "thinking" | "tool_use" + + // OutputIndexToBlockIdx maps Responses output_index → Anthropic content block index. + OutputIndexToBlockIdx map[int]int + + InputTokens int + OutputTokens int + CacheReadInputTokens int + + ResponseID string + Model string + Created int64 +} + +// NewResponsesEventToAnthropicState returns an initialised stream state. +func NewResponsesEventToAnthropicState() *ResponsesEventToAnthropicState { + return &ResponsesEventToAnthropicState{ + OutputIndexToBlockIdx: make(map[int]int), + Created: time.Now().Unix(), + } +} + +// ResponsesEventToAnthropicEvents converts a single Responses SSE event into +// zero or more Anthropic SSE events, updating state as it goes. 
+func ResponsesEventToAnthropicEvents( + evt *ResponsesStreamEvent, + state *ResponsesEventToAnthropicState, +) []AnthropicStreamEvent { + switch evt.Type { + case "response.created": + return resToAnthHandleCreated(evt, state) + case "response.output_item.added": + return resToAnthHandleOutputItemAdded(evt, state) + case "response.output_text.delta": + return resToAnthHandleTextDelta(evt, state) + case "response.output_text.done": + return resToAnthHandleBlockDone(state) + case "response.function_call_arguments.delta": + return resToAnthHandleFuncArgsDelta(evt, state) + case "response.function_call_arguments.done": + return resToAnthHandleBlockDone(state) + case "response.output_item.done": + return resToAnthHandleOutputItemDone(evt, state) + case "response.reasoning_summary_text.delta": + return resToAnthHandleReasoningDelta(evt, state) + case "response.reasoning_summary_text.done": + return resToAnthHandleBlockDone(state) + case "response.completed", "response.incomplete", "response.failed": + return resToAnthHandleCompleted(evt, state) + default: + return nil + } +} + +// FinalizeResponsesAnthropicStream emits synthetic termination events if the +// stream ended without a proper completion event. +func FinalizeResponsesAnthropicStream(state *ResponsesEventToAnthropicState) []AnthropicStreamEvent { + if !state.MessageStartSent || state.MessageStopSent { + return nil + } + + var events []AnthropicStreamEvent + events = append(events, closeCurrentBlock(state)...) + + events = append(events, + AnthropicStreamEvent{ + Type: "message_delta", + Delta: &AnthropicDelta{ + StopReason: "end_turn", + }, + Usage: &AnthropicUsage{ + InputTokens: state.InputTokens, + OutputTokens: state.OutputTokens, + CacheReadInputTokens: state.CacheReadInputTokens, + }, + }, + AnthropicStreamEvent{Type: "message_stop"}, + ) + state.MessageStopSent = true + return events +} + +// ResponsesAnthropicEventToSSE formats an AnthropicStreamEvent as an SSE line pair. 
+func ResponsesAnthropicEventToSSE(evt AnthropicStreamEvent) (string, error) { + data, err := json.Marshal(evt) + if err != nil { + return "", err + } + return fmt.Sprintf("event: %s\ndata: %s\n\n", evt.Type, data), nil +} + +// --- internal handlers --- + +func resToAnthHandleCreated(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent { + if evt.Response != nil { + state.ResponseID = evt.Response.ID + // Only use upstream model if no override was set (e.g. originalModel) + if state.Model == "" { + state.Model = evt.Response.Model + } + } + + if state.MessageStartSent { + return nil + } + state.MessageStartSent = true + + return []AnthropicStreamEvent{{ + Type: "message_start", + Message: &AnthropicResponse{ + ID: state.ResponseID, + Type: "message", + Role: "assistant", + Content: []AnthropicContentBlock{}, + Model: state.Model, + Usage: AnthropicUsage{ + InputTokens: 0, + OutputTokens: 0, + }, + }, + }} +} + +func resToAnthHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent { + if evt.Item == nil { + return nil + } + + switch evt.Item.Type { + case "function_call": + var events []AnthropicStreamEvent + events = append(events, closeCurrentBlock(state)...) + + idx := state.ContentBlockIndex + state.OutputIndexToBlockIdx[evt.OutputIndex] = idx + state.ContentBlockOpen = true + state.CurrentBlockType = "tool_use" + + events = append(events, AnthropicStreamEvent{ + Type: "content_block_start", + Index: &idx, + ContentBlock: &AnthropicContentBlock{ + Type: "tool_use", + ID: fromResponsesCallID(evt.Item.CallID), + Name: evt.Item.Name, + Input: json.RawMessage("{}"), + }, + }) + return events + + case "reasoning": + var events []AnthropicStreamEvent + events = append(events, closeCurrentBlock(state)...) 
+ + idx := state.ContentBlockIndex + state.OutputIndexToBlockIdx[evt.OutputIndex] = idx + state.ContentBlockOpen = true + state.CurrentBlockType = "thinking" + + events = append(events, AnthropicStreamEvent{ + Type: "content_block_start", + Index: &idx, + ContentBlock: &AnthropicContentBlock{ + Type: "thinking", + Thinking: "", + }, + }) + return events + + case "message": + return nil + } + + return nil +} + +func resToAnthHandleTextDelta(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent { + if evt.Delta == "" { + return nil + } + + var events []AnthropicStreamEvent + + if !state.ContentBlockOpen || state.CurrentBlockType != "text" { + events = append(events, closeCurrentBlock(state)...) + + idx := state.ContentBlockIndex + state.ContentBlockOpen = true + state.CurrentBlockType = "text" + + events = append(events, AnthropicStreamEvent{ + Type: "content_block_start", + Index: &idx, + ContentBlock: &AnthropicContentBlock{ + Type: "text", + Text: "", + }, + }) + } + + idx := state.ContentBlockIndex + events = append(events, AnthropicStreamEvent{ + Type: "content_block_delta", + Index: &idx, + Delta: &AnthropicDelta{ + Type: "text_delta", + Text: evt.Delta, + }, + }) + return events +} + +func resToAnthHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent { + if evt.Delta == "" { + return nil + } + + blockIdx, ok := state.OutputIndexToBlockIdx[evt.OutputIndex] + if !ok { + return nil + } + + return []AnthropicStreamEvent{{ + Type: "content_block_delta", + Index: &blockIdx, + Delta: &AnthropicDelta{ + Type: "input_json_delta", + PartialJSON: evt.Delta, + }, + }} +} + +func resToAnthHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent { + if evt.Delta == "" { + return nil + } + + blockIdx, ok := state.OutputIndexToBlockIdx[evt.OutputIndex] + if !ok { + return nil + } + + return []AnthropicStreamEvent{{ + Type: 
"content_block_delta", + Index: &blockIdx, + Delta: &AnthropicDelta{ + Type: "thinking_delta", + Thinking: evt.Delta, + }, + }} +} + +func resToAnthHandleBlockDone(state *ResponsesEventToAnthropicState) []AnthropicStreamEvent { + if !state.ContentBlockOpen { + return nil + } + return closeCurrentBlock(state) +} + +func resToAnthHandleOutputItemDone(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent { + if evt.Item == nil { + return nil + } + + // Handle web_search_call → synthesize server_tool_use + web_search_tool_result blocks. + if evt.Item.Type == "web_search_call" && evt.Item.Status == "completed" { + return resToAnthHandleWebSearchDone(evt, state) + } + + if state.ContentBlockOpen { + return closeCurrentBlock(state) + } + return nil +} + +// resToAnthHandleWebSearchDone converts an OpenAI web_search_call output item +// into Anthropic server_tool_use + web_search_tool_result content block pairs. +// This allows Claude Code to count the searches performed. +func resToAnthHandleWebSearchDone(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent { + var events []AnthropicStreamEvent + events = append(events, closeCurrentBlock(state)...) + + toolUseID := "srvtoolu_" + evt.Item.ID + query := "" + if evt.Item.Action != nil { + query = evt.Item.Action.Query + } + inputJSON, _ := json.Marshal(map[string]string{"query": query}) + + // Emit server_tool_use block (start + stop). + idx1 := state.ContentBlockIndex + events = append(events, AnthropicStreamEvent{ + Type: "content_block_start", + Index: &idx1, + ContentBlock: &AnthropicContentBlock{ + Type: "server_tool_use", + ID: toolUseID, + Name: "web_search", + Input: inputJSON, + }, + }) + events = append(events, AnthropicStreamEvent{ + Type: "content_block_stop", + Index: &idx1, + }) + state.ContentBlockIndex++ + + // Emit web_search_tool_result block (start + stop). 
+ // Content is empty because OpenAI does not expose individual search results; + // the model consumes them internally and produces text output. + emptyResults, _ := json.Marshal([]struct{}{}) + idx2 := state.ContentBlockIndex + events = append(events, AnthropicStreamEvent{ + Type: "content_block_start", + Index: &idx2, + ContentBlock: &AnthropicContentBlock{ + Type: "web_search_tool_result", + ToolUseID: toolUseID, + Content: emptyResults, + }, + }) + events = append(events, AnthropicStreamEvent{ + Type: "content_block_stop", + Index: &idx2, + }) + state.ContentBlockIndex++ + + return events +} + +func resToAnthHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent { + if state.MessageStopSent { + return nil + } + + var events []AnthropicStreamEvent + events = append(events, closeCurrentBlock(state)...) + + stopReason := "end_turn" + if evt.Response != nil { + if evt.Response.Usage != nil { + state.InputTokens = evt.Response.Usage.InputTokens + state.OutputTokens = evt.Response.Usage.OutputTokens + if evt.Response.Usage.InputTokensDetails != nil { + state.CacheReadInputTokens = evt.Response.Usage.InputTokensDetails.CachedTokens + } + } + switch evt.Response.Status { + case "incomplete": + if evt.Response.IncompleteDetails != nil && evt.Response.IncompleteDetails.Reason == "max_output_tokens" { + stopReason = "max_tokens" + } + case "completed": + if state.ContentBlockIndex > 0 && state.CurrentBlockType == "tool_use" { + stopReason = "tool_use" + } + } + } + + events = append(events, + AnthropicStreamEvent{ + Type: "message_delta", + Delta: &AnthropicDelta{ + StopReason: stopReason, + }, + Usage: &AnthropicUsage{ + InputTokens: state.InputTokens, + OutputTokens: state.OutputTokens, + CacheReadInputTokens: state.CacheReadInputTokens, + }, + }, + AnthropicStreamEvent{Type: "message_stop"}, + ) + state.MessageStopSent = true + return events +} + +func closeCurrentBlock(state *ResponsesEventToAnthropicState) 
[]AnthropicStreamEvent { + if !state.ContentBlockOpen { + return nil + } + idx := state.ContentBlockIndex + state.ContentBlockOpen = false + state.ContentBlockIndex++ + return []AnthropicStreamEvent{{ + Type: "content_block_stop", + Index: &idx, + }} +} diff --git a/backend/internal/pkg/apicompat/responses_to_chatcompletions.go b/backend/internal/pkg/apicompat/responses_to_chatcompletions.go new file mode 100644 index 00000000..8f83bce4 --- /dev/null +++ b/backend/internal/pkg/apicompat/responses_to_chatcompletions.go @@ -0,0 +1,368 @@ +package apicompat + +import ( + "crypto/rand" + "encoding/hex" + "encoding/json" + "fmt" + "time" +) + +// --------------------------------------------------------------------------- +// Non-streaming: ResponsesResponse → ChatCompletionsResponse +// --------------------------------------------------------------------------- + +// ResponsesToChatCompletions converts a Responses API response into a Chat +// Completions response. Text output items are concatenated into +// choices[0].message.content; function_call items become tool_calls. 
+func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatCompletionsResponse { + id := resp.ID + if id == "" { + id = generateChatCmplID() + } + + out := &ChatCompletionsResponse{ + ID: id, + Object: "chat.completion", + Created: time.Now().Unix(), + Model: model, + } + + var contentText string + var toolCalls []ChatToolCall + + for _, item := range resp.Output { + switch item.Type { + case "message": + for _, part := range item.Content { + if part.Type == "output_text" && part.Text != "" { + contentText += part.Text + } + } + case "function_call": + toolCalls = append(toolCalls, ChatToolCall{ + ID: item.CallID, + Type: "function", + Function: ChatFunctionCall{ + Name: item.Name, + Arguments: item.Arguments, + }, + }) + case "reasoning": + for _, s := range item.Summary { + if s.Type == "summary_text" && s.Text != "" { + contentText += s.Text + } + } + case "web_search_call": + // silently consumed — results already incorporated into text output + } + } + + msg := ChatMessage{Role: "assistant"} + if len(toolCalls) > 0 { + msg.ToolCalls = toolCalls + } + if contentText != "" { + raw, _ := json.Marshal(contentText) + msg.Content = raw + } + + finishReason := responsesStatusToChatFinishReason(resp.Status, resp.IncompleteDetails, toolCalls) + + out.Choices = []ChatChoice{{ + Index: 0, + Message: msg, + FinishReason: finishReason, + }} + + if resp.Usage != nil { + usage := &ChatUsage{ + PromptTokens: resp.Usage.InputTokens, + CompletionTokens: resp.Usage.OutputTokens, + TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens, + } + if resp.Usage.InputTokensDetails != nil && resp.Usage.InputTokensDetails.CachedTokens > 0 { + usage.PromptTokensDetails = &ChatTokenDetails{ + CachedTokens: resp.Usage.InputTokensDetails.CachedTokens, + } + } + out.Usage = usage + } + + return out +} + +func responsesStatusToChatFinishReason(status string, details *ResponsesIncompleteDetails, toolCalls []ChatToolCall) string { + switch status { + case "incomplete": + 
if details != nil && details.Reason == "max_output_tokens" { + return "length" + } + return "stop" + case "completed": + if len(toolCalls) > 0 { + return "tool_calls" + } + return "stop" + default: + return "stop" + } +} + +// --------------------------------------------------------------------------- +// Streaming: ResponsesStreamEvent → []ChatCompletionsChunk (stateful converter) +// --------------------------------------------------------------------------- + +// ResponsesEventToChatState tracks state for converting a sequence of Responses +// SSE events into Chat Completions SSE chunks. +type ResponsesEventToChatState struct { + ID string + Model string + Created int64 + SentRole bool + SawToolCall bool + SawText bool + Finalized bool // true after finish chunk has been emitted + NextToolCallIndex int // next sequential tool_call index to assign + OutputIndexToToolIndex map[int]int // Responses output_index → Chat tool_calls index + IncludeUsage bool + Usage *ChatUsage +} + +// NewResponsesEventToChatState returns an initialised stream state. +func NewResponsesEventToChatState() *ResponsesEventToChatState { + return &ResponsesEventToChatState{ + ID: generateChatCmplID(), + Created: time.Now().Unix(), + OutputIndexToToolIndex: make(map[int]int), + } +} + +// ResponsesEventToChatChunks converts a single Responses SSE event into zero +// or more Chat Completions chunks, updating state as it goes. 
+func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk { + switch evt.Type { + case "response.created": + return resToChatHandleCreated(evt, state) + case "response.output_text.delta": + return resToChatHandleTextDelta(evt, state) + case "response.output_item.added": + return resToChatHandleOutputItemAdded(evt, state) + case "response.function_call_arguments.delta": + return resToChatHandleFuncArgsDelta(evt, state) + case "response.reasoning_summary_text.delta": + return resToChatHandleReasoningDelta(evt, state) + case "response.completed", "response.incomplete", "response.failed": + return resToChatHandleCompleted(evt, state) + default: + return nil + } +} + +// FinalizeResponsesChatStream emits a final chunk with finish_reason if the +// stream ended without a proper completion event (e.g. upstream disconnect). +// It is idempotent: if a completion event already emitted the finish chunk, +// this returns nil. +func FinalizeResponsesChatStream(state *ResponsesEventToChatState) []ChatCompletionsChunk { + if state.Finalized { + return nil + } + state.Finalized = true + + finishReason := "stop" + if state.SawToolCall { + finishReason = "tool_calls" + } + + chunks := []ChatCompletionsChunk{makeChatFinishChunk(state, finishReason)} + + if state.IncludeUsage && state.Usage != nil { + chunks = append(chunks, ChatCompletionsChunk{ + ID: state.ID, + Object: "chat.completion.chunk", + Created: state.Created, + Model: state.Model, + Choices: []ChatChunkChoice{}, + Usage: state.Usage, + }) + } + + return chunks +} + +// ChatChunkToSSE formats a ChatCompletionsChunk as an SSE data line. 
+func ChatChunkToSSE(chunk ChatCompletionsChunk) (string, error) { + data, err := json.Marshal(chunk) + if err != nil { + return "", err + } + return fmt.Sprintf("data: %s\n\n", data), nil +} + +// --- internal handlers --- + +func resToChatHandleCreated(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk { + if evt.Response != nil { + if evt.Response.ID != "" { + state.ID = evt.Response.ID + } + if state.Model == "" && evt.Response.Model != "" { + state.Model = evt.Response.Model + } + } + // Emit the role chunk. + if state.SentRole { + return nil + } + state.SentRole = true + + role := "assistant" + return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Role: role})} +} + +func resToChatHandleTextDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk { + if evt.Delta == "" { + return nil + } + state.SawText = true + content := evt.Delta + return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})} +} + +func resToChatHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk { + if evt.Item == nil || evt.Item.Type != "function_call" { + return nil + } + + state.SawToolCall = true + idx := state.NextToolCallIndex + state.OutputIndexToToolIndex[evt.OutputIndex] = idx + state.NextToolCallIndex++ + + return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{ + ToolCalls: []ChatToolCall{{ + Index: &idx, + ID: evt.Item.CallID, + Type: "function", + Function: ChatFunctionCall{ + Name: evt.Item.Name, + }, + }}, + })} +} + +func resToChatHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk { + if evt.Delta == "" { + return nil + } + + idx, ok := state.OutputIndexToToolIndex[evt.OutputIndex] + if !ok { + return nil + } + + return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{ + ToolCalls: []ChatToolCall{{ + Index: &idx, + Function: ChatFunctionCall{ + 
Arguments: evt.Delta, + }, + }}, + })} +} + +func resToChatHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk { + if evt.Delta == "" { + return nil + } + content := evt.Delta + return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})} +} + +func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk { + state.Finalized = true + finishReason := "stop" + + if evt.Response != nil { + if evt.Response.Usage != nil { + u := evt.Response.Usage + usage := &ChatUsage{ + PromptTokens: u.InputTokens, + CompletionTokens: u.OutputTokens, + TotalTokens: u.InputTokens + u.OutputTokens, + } + if u.InputTokensDetails != nil && u.InputTokensDetails.CachedTokens > 0 { + usage.PromptTokensDetails = &ChatTokenDetails{ + CachedTokens: u.InputTokensDetails.CachedTokens, + } + } + state.Usage = usage + } + + switch evt.Response.Status { + case "incomplete": + if evt.Response.IncompleteDetails != nil && evt.Response.IncompleteDetails.Reason == "max_output_tokens" { + finishReason = "length" + } + case "completed": + if state.SawToolCall { + finishReason = "tool_calls" + } + } + } else if state.SawToolCall { + finishReason = "tool_calls" + } + + var chunks []ChatCompletionsChunk + chunks = append(chunks, makeChatFinishChunk(state, finishReason)) + + if state.IncludeUsage && state.Usage != nil { + chunks = append(chunks, ChatCompletionsChunk{ + ID: state.ID, + Object: "chat.completion.chunk", + Created: state.Created, + Model: state.Model, + Choices: []ChatChunkChoice{}, + Usage: state.Usage, + }) + } + + return chunks +} + +func makeChatDeltaChunk(state *ResponsesEventToChatState, delta ChatDelta) ChatCompletionsChunk { + return ChatCompletionsChunk{ + ID: state.ID, + Object: "chat.completion.chunk", + Created: state.Created, + Model: state.Model, + Choices: []ChatChunkChoice{{ + Index: 0, + Delta: delta, + FinishReason: nil, + }}, + } +} + +func 
makeChatFinishChunk(state *ResponsesEventToChatState, finishReason string) ChatCompletionsChunk { + empty := "" + return ChatCompletionsChunk{ + ID: state.ID, + Object: "chat.completion.chunk", + Created: state.Created, + Model: state.Model, + Choices: []ChatChunkChoice{{ + Index: 0, + Delta: ChatDelta{Content: &empty}, + FinishReason: &finishReason, + }}, + } +} + +// generateChatCmplID returns a "chatcmpl-" prefixed random hex ID. +func generateChatCmplID() string { + b := make([]byte, 12) + _, _ = rand.Read(b) + return "chatcmpl-" + hex.EncodeToString(b) +} diff --git a/backend/internal/pkg/apicompat/types.go b/backend/internal/pkg/apicompat/types.go new file mode 100644 index 00000000..eb77d89f --- /dev/null +++ b/backend/internal/pkg/apicompat/types.go @@ -0,0 +1,480 @@ +// Package apicompat provides type definitions and conversion utilities for +// translating between Anthropic Messages and OpenAI Responses API formats. +// It enables multi-protocol support so that clients using different API +// formats can be served through a unified gateway. +package apicompat + +import "encoding/json" + +// --------------------------------------------------------------------------- +// Anthropic Messages API types +// --------------------------------------------------------------------------- + +// AnthropicRequest is the request body for POST /v1/messages. 
+type AnthropicRequest struct { + Model string `json:"model"` + MaxTokens int `json:"max_tokens"` + System json.RawMessage `json:"system,omitempty"` // string or []AnthropicContentBlock + Messages []AnthropicMessage `json:"messages"` + Tools []AnthropicTool `json:"tools,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + StopSeqs []string `json:"stop_sequences,omitempty"` + Thinking *AnthropicThinking `json:"thinking,omitempty"` + ToolChoice json.RawMessage `json:"tool_choice,omitempty"` + OutputConfig *AnthropicOutputConfig `json:"output_config,omitempty"` +} + +// AnthropicOutputConfig controls output generation parameters. +type AnthropicOutputConfig struct { + Effort string `json:"effort,omitempty"` // "low" | "medium" | "high" +} + +// AnthropicThinking configures extended thinking in the Anthropic API. +type AnthropicThinking struct { + Type string `json:"type"` // "enabled" | "adaptive" | "disabled" + BudgetTokens int `json:"budget_tokens,omitempty"` // max thinking tokens +} + +// AnthropicMessage is a single message in the Anthropic conversation. +type AnthropicMessage struct { + Role string `json:"role"` // "user" | "assistant" + Content json.RawMessage `json:"content"` +} + +// AnthropicContentBlock is one block inside a message's content array. 
+type AnthropicContentBlock struct { + Type string `json:"type"` + + // type=text + Text string `json:"text,omitempty"` + + // type=thinking + Thinking string `json:"thinking,omitempty"` + + // type=image + Source *AnthropicImageSource `json:"source,omitempty"` + + // type=tool_use + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Input json.RawMessage `json:"input,omitempty"` + + // type=tool_result + ToolUseID string `json:"tool_use_id,omitempty"` + Content json.RawMessage `json:"content,omitempty"` // string or []AnthropicContentBlock + IsError bool `json:"is_error,omitempty"` +} + +// AnthropicImageSource describes the source data for an image content block. +type AnthropicImageSource struct { + Type string `json:"type"` // "base64" + MediaType string `json:"media_type"` + Data string `json:"data"` +} + +// AnthropicTool describes a tool available to the model. +type AnthropicTool struct { + Type string `json:"type,omitempty"` // e.g. "web_search_20250305" for server tools + Name string `json:"name"` + Description string `json:"description,omitempty"` + InputSchema json.RawMessage `json:"input_schema"` // JSON Schema object +} + +// AnthropicResponse is the non-streaming response from POST /v1/messages. +type AnthropicResponse struct { + ID string `json:"id"` + Type string `json:"type"` // "message" + Role string `json:"role"` // "assistant" + Content []AnthropicContentBlock `json:"content"` + Model string `json:"model"` + StopReason string `json:"stop_reason"` + StopSequence *string `json:"stop_sequence,omitempty"` + Usage AnthropicUsage `json:"usage"` +} + +// AnthropicUsage holds token counts in Anthropic format. 
+type AnthropicUsage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + CacheCreationInputTokens int `json:"cache_creation_input_tokens"` + CacheReadInputTokens int `json:"cache_read_input_tokens"` +} + +// --------------------------------------------------------------------------- +// Anthropic SSE event types +// --------------------------------------------------------------------------- + +// AnthropicStreamEvent is a single SSE event in the Anthropic streaming protocol. +type AnthropicStreamEvent struct { + Type string `json:"type"` + + // message_start + Message *AnthropicResponse `json:"message,omitempty"` + + // content_block_start + Index *int `json:"index,omitempty"` + ContentBlock *AnthropicContentBlock `json:"content_block,omitempty"` + + // content_block_delta + Delta *AnthropicDelta `json:"delta,omitempty"` + + // message_delta + Usage *AnthropicUsage `json:"usage,omitempty"` +} + +// AnthropicDelta carries incremental content in streaming events. +type AnthropicDelta struct { + Type string `json:"type,omitempty"` // "text_delta" | "input_json_delta" | "thinking_delta" | "signature_delta" + + // text_delta + Text string `json:"text,omitempty"` + + // input_json_delta + PartialJSON string `json:"partial_json,omitempty"` + + // thinking_delta + Thinking string `json:"thinking,omitempty"` + + // signature_delta + Signature string `json:"signature,omitempty"` + + // message_delta fields + StopReason string `json:"stop_reason,omitempty"` + StopSequence *string `json:"stop_sequence,omitempty"` +} + +// --------------------------------------------------------------------------- +// OpenAI Responses API types +// --------------------------------------------------------------------------- + +// ResponsesRequest is the request body for POST /v1/responses. 
+type ResponsesRequest struct { + Model string `json:"model"` + Input json.RawMessage `json:"input"` // string or []ResponsesInputItem + MaxOutputTokens *int `json:"max_output_tokens,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + Stream bool `json:"stream,omitempty"` + Tools []ResponsesTool `json:"tools,omitempty"` + Include []string `json:"include,omitempty"` + Store *bool `json:"store,omitempty"` + Reasoning *ResponsesReasoning `json:"reasoning,omitempty"` + ToolChoice json.RawMessage `json:"tool_choice,omitempty"` + ServiceTier string `json:"service_tier,omitempty"` +} + +// ResponsesReasoning configures reasoning effort in the Responses API. +type ResponsesReasoning struct { + Effort string `json:"effort"` // "low" | "medium" | "high" + Summary string `json:"summary,omitempty"` // "auto" | "concise" | "detailed" +} + +// ResponsesInputItem is one item in the Responses API input array. +// The Type field determines which other fields are populated. +type ResponsesInputItem struct { + // Common + Type string `json:"type,omitempty"` // "" for role-based messages + + // Role-based messages (system/user/assistant) + Role string `json:"role,omitempty"` + Content json.RawMessage `json:"content,omitempty"` // string or []ResponsesContentPart + + // type=function_call + CallID string `json:"call_id,omitempty"` + Name string `json:"name,omitempty"` + Arguments string `json:"arguments,omitempty"` + ID string `json:"id,omitempty"` + + // type=function_call_output + Output string `json:"output,omitempty"` +} + +// ResponsesContentPart is a typed content part in a Responses message. +type ResponsesContentPart struct { + Type string `json:"type"` // "input_text" | "output_text" | "input_image" + Text string `json:"text,omitempty"` + ImageURL string `json:"image_url,omitempty"` // data URI for input_image +} + +// ResponsesTool describes a tool in the Responses API. 
+type ResponsesTool struct { + Type string `json:"type"` // "function" | "web_search" | "local_shell" etc. + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + Parameters json.RawMessage `json:"parameters,omitempty"` + Strict *bool `json:"strict,omitempty"` +} + +// ResponsesResponse is the non-streaming response from POST /v1/responses. +type ResponsesResponse struct { + ID string `json:"id"` + Object string `json:"object"` // "response" + Model string `json:"model"` + Status string `json:"status"` // "completed" | "incomplete" | "failed" + Output []ResponsesOutput `json:"output"` + Usage *ResponsesUsage `json:"usage,omitempty"` + + // incomplete_details is present when status="incomplete" + IncompleteDetails *ResponsesIncompleteDetails `json:"incomplete_details,omitempty"` + + // Error is present when status="failed" + Error *ResponsesError `json:"error,omitempty"` +} + +// ResponsesError describes an error in a failed response. +type ResponsesError struct { + Code string `json:"code"` + Message string `json:"message"` +} + +// ResponsesIncompleteDetails explains why a response is incomplete. +type ResponsesIncompleteDetails struct { + Reason string `json:"reason"` // "max_output_tokens" | "content_filter" +} + +// ResponsesOutput is one output item in a Responses API response. 
+type ResponsesOutput struct { + Type string `json:"type"` // "message" | "reasoning" | "function_call" | "web_search_call" + + // type=message + ID string `json:"id,omitempty"` + Role string `json:"role,omitempty"` + Content []ResponsesContentPart `json:"content,omitempty"` + Status string `json:"status,omitempty"` + + // type=reasoning + EncryptedContent string `json:"encrypted_content,omitempty"` + Summary []ResponsesSummary `json:"summary,omitempty"` + + // type=function_call + CallID string `json:"call_id,omitempty"` + Name string `json:"name,omitempty"` + Arguments string `json:"arguments,omitempty"` + + // type=web_search_call + Action *WebSearchAction `json:"action,omitempty"` +} + +// WebSearchAction describes the search action in a web_search_call output item. +type WebSearchAction struct { + Type string `json:"type,omitempty"` // "search" + Query string `json:"query,omitempty"` // primary search query +} + +// ResponsesSummary is a summary text block inside a reasoning output. +type ResponsesSummary struct { + Type string `json:"type"` // "summary_text" + Text string `json:"text"` +} + +// ResponsesUsage holds token counts in Responses API format. +type ResponsesUsage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + TotalTokens int `json:"total_tokens"` + + // Optional detailed breakdown + InputTokensDetails *ResponsesInputTokensDetails `json:"input_tokens_details,omitempty"` + OutputTokensDetails *ResponsesOutputTokensDetails `json:"output_tokens_details,omitempty"` +} + +// ResponsesInputTokensDetails breaks down input token usage. +type ResponsesInputTokensDetails struct { + CachedTokens int `json:"cached_tokens,omitempty"` +} + +// ResponsesOutputTokensDetails breaks down output token usage. 
+type ResponsesOutputTokensDetails struct { + ReasoningTokens int `json:"reasoning_tokens,omitempty"` +} + +// --------------------------------------------------------------------------- +// Responses SSE event types +// --------------------------------------------------------------------------- + +// ResponsesStreamEvent is a single SSE event in the Responses streaming protocol. +// The Type field corresponds to the "type" in the JSON payload. +type ResponsesStreamEvent struct { + Type string `json:"type"` + + // response.created / response.completed / response.failed / response.incomplete + Response *ResponsesResponse `json:"response,omitempty"` + + // response.output_item.added / response.output_item.done + Item *ResponsesOutput `json:"item,omitempty"` + + // response.output_text.delta / response.output_text.done + OutputIndex int `json:"output_index,omitempty"` + ContentIndex int `json:"content_index,omitempty"` + Delta string `json:"delta,omitempty"` + Text string `json:"text,omitempty"` + ItemID string `json:"item_id,omitempty"` + + // response.function_call_arguments.delta / done + CallID string `json:"call_id,omitempty"` + Name string `json:"name,omitempty"` + Arguments string `json:"arguments,omitempty"` + + // response.reasoning_summary_text.delta / done + // Reuses Text/Delta fields above, SummaryIndex identifies which summary part + SummaryIndex int `json:"summary_index,omitempty"` + + // error event fields + Code string `json:"code,omitempty"` + Param string `json:"param,omitempty"` + + // Sequence number for ordering events + SequenceNumber int `json:"sequence_number,omitempty"` +} + +// --------------------------------------------------------------------------- +// OpenAI Chat Completions API types +// --------------------------------------------------------------------------- + +// ChatCompletionsRequest is the request body for POST /v1/chat/completions. 
+type ChatCompletionsRequest struct { + Model string `json:"model"` + Messages []ChatMessage `json:"messages"` + MaxTokens *int `json:"max_tokens,omitempty"` + MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + Stream bool `json:"stream,omitempty"` + StreamOptions *ChatStreamOptions `json:"stream_options,omitempty"` + Tools []ChatTool `json:"tools,omitempty"` + ToolChoice json.RawMessage `json:"tool_choice,omitempty"` + ReasoningEffort string `json:"reasoning_effort,omitempty"` // "low" | "medium" | "high" + ServiceTier string `json:"service_tier,omitempty"` + Stop json.RawMessage `json:"stop,omitempty"` // string or []string + + // Legacy function calling (deprecated but still supported) + Functions []ChatFunction `json:"functions,omitempty"` + FunctionCall json.RawMessage `json:"function_call,omitempty"` +} + +// ChatStreamOptions configures streaming behavior. +type ChatStreamOptions struct { + IncludeUsage bool `json:"include_usage,omitempty"` +} + +// ChatMessage is a single message in the Chat Completions conversation. +type ChatMessage struct { + Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function" + Content json.RawMessage `json:"content,omitempty"` + Name string `json:"name,omitempty"` + ToolCalls []ChatToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + + // Legacy function calling + FunctionCall *ChatFunctionCall `json:"function_call,omitempty"` +} + +// ChatContentPart is a typed content part in a multi-modal message. +type ChatContentPart struct { + Type string `json:"type"` // "text" | "image_url" + Text string `json:"text,omitempty"` + ImageURL *ChatImageURL `json:"image_url,omitempty"` +} + +// ChatImageURL contains the URL for an image content part. 
+type ChatImageURL struct { + URL string `json:"url"` + Detail string `json:"detail,omitempty"` // "auto" | "low" | "high" +} + +// ChatTool describes a tool available to the model. +type ChatTool struct { + Type string `json:"type"` // "function" + Function *ChatFunction `json:"function,omitempty"` +} + +// ChatFunction describes a function tool definition. +type ChatFunction struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Parameters json.RawMessage `json:"parameters,omitempty"` + Strict *bool `json:"strict,omitempty"` +} + +// ChatToolCall represents a tool call made by the assistant. +// Index is only populated in streaming chunks (omitted in non-streaming responses). +type ChatToolCall struct { + Index *int `json:"index,omitempty"` + ID string `json:"id,omitempty"` + Type string `json:"type,omitempty"` // "function" + Function ChatFunctionCall `json:"function"` +} + +// ChatFunctionCall contains the function name and arguments. +type ChatFunctionCall struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +// ChatCompletionsResponse is the non-streaming response from POST /v1/chat/completions. +type ChatCompletionsResponse struct { + ID string `json:"id"` + Object string `json:"object"` // "chat.completion" + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatChoice `json:"choices"` + Usage *ChatUsage `json:"usage,omitempty"` + SystemFingerprint string `json:"system_fingerprint,omitempty"` + ServiceTier string `json:"service_tier,omitempty"` +} + +// ChatChoice is a single completion choice. +type ChatChoice struct { + Index int `json:"index"` + Message ChatMessage `json:"message"` + FinishReason string `json:"finish_reason"` // "stop" | "length" | "tool_calls" | "content_filter" +} + +// ChatUsage holds token counts in Chat Completions format. 
+type ChatUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + PromptTokensDetails *ChatTokenDetails `json:"prompt_tokens_details,omitempty"` +} + +// ChatTokenDetails provides a breakdown of token usage. +type ChatTokenDetails struct { + CachedTokens int `json:"cached_tokens,omitempty"` +} + +// ChatCompletionsChunk is a single streaming chunk from POST /v1/chat/completions. +type ChatCompletionsChunk struct { + ID string `json:"id"` + Object string `json:"object"` // "chat.completion.chunk" + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatChunkChoice `json:"choices"` + Usage *ChatUsage `json:"usage,omitempty"` + SystemFingerprint string `json:"system_fingerprint,omitempty"` + ServiceTier string `json:"service_tier,omitempty"` +} + +// ChatChunkChoice is a single choice in a streaming chunk. +type ChatChunkChoice struct { + Index int `json:"index"` + Delta ChatDelta `json:"delta"` + FinishReason *string `json:"finish_reason"` // pointer: null when not final +} + +// ChatDelta carries incremental content in a streaming chunk. +type ChatDelta struct { + Role string `json:"role,omitempty"` + Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters + ToolCalls []ChatToolCall `json:"tool_calls,omitempty"` +} + +// --------------------------------------------------------------------------- +// Shared constants +// --------------------------------------------------------------------------- + +// minMaxOutputTokens is the floor for max_output_tokens in a Responses request. +// Very small values may cause upstream API errors, so we enforce a minimum. 
+const minMaxOutputTokens = 128 diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go index 22405382..dfca252f 100644 --- a/backend/internal/pkg/claude/constants.go +++ b/backend/internal/pkg/claude/constants.go @@ -16,7 +16,7 @@ const ( // DroppedBetas 是转发时需要从 anthropic-beta header 中移除的 beta token 列表。 // 这些 token 是客户端特有的,不应透传给上游 API。 -var DroppedBetas = []string{BetaContext1M, BetaFastMode} +var DroppedBetas = []string{} // DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming diff --git a/backend/internal/pkg/gemini/models.go b/backend/internal/pkg/gemini/models.go index c300b17d..882d2ebd 100644 --- a/backend/internal/pkg/gemini/models.go +++ b/backend/internal/pkg/gemini/models.go @@ -18,10 +18,12 @@ func DefaultModels() []Model { return []Model{ {Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods}, {Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods}, + {Name: "models/gemini-2.5-flash-image", SupportedGenerationMethods: methods}, {Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods}, {Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods}, {Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods}, {Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods}, + {Name: "models/gemini-3.1-flash-image", SupportedGenerationMethods: methods}, } } diff --git a/backend/internal/pkg/gemini/models_test.go b/backend/internal/pkg/gemini/models_test.go new file mode 100644 index 00000000..b80047fb --- /dev/null +++ b/backend/internal/pkg/gemini/models_test.go @@ -0,0 +1,28 @@ +package gemini + +import "testing" + +func TestDefaultModels_ContainsImageModels(t *testing.T) { + t.Parallel() + + models := DefaultModels() + byName := make(map[string]Model, len(models)) + for _, model := range models { + 
byName[model.Name] = model + } + + required := []string{ + "models/gemini-2.5-flash-image", + "models/gemini-3.1-flash-image", + } + + for _, name := range required { + model, ok := byName[name] + if !ok { + t.Fatalf("expected fallback model %q to exist", name) + } + if len(model.SupportedGenerationMethods) == 0 { + t.Fatalf("expected fallback model %q to advertise generation methods", name) + } + } +} diff --git a/backend/internal/pkg/geminicli/models.go b/backend/internal/pkg/geminicli/models.go index 1fc4d983..195fb06f 100644 --- a/backend/internal/pkg/geminicli/models.go +++ b/backend/internal/pkg/geminicli/models.go @@ -13,10 +13,12 @@ type Model struct { var DefaultModels = []Model{ {ID: "gemini-2.0-flash", Type: "model", DisplayName: "Gemini 2.0 Flash", CreatedAt: ""}, {ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""}, + {ID: "gemini-2.5-flash-image", Type: "model", DisplayName: "Gemini 2.5 Flash Image", CreatedAt: ""}, {ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""}, {ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""}, {ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""}, {ID: "gemini-3.1-pro-preview", Type: "model", DisplayName: "Gemini 3.1 Pro Preview", CreatedAt: ""}, + {ID: "gemini-3.1-flash-image", Type: "model", DisplayName: "Gemini 3.1 Flash Image", CreatedAt: ""}, } // DefaultTestModel is the default model to preselect in test flows. 
diff --git a/backend/internal/pkg/geminicli/models_test.go b/backend/internal/pkg/geminicli/models_test.go new file mode 100644 index 00000000..c1884e2e --- /dev/null +++ b/backend/internal/pkg/geminicli/models_test.go @@ -0,0 +1,23 @@ +package geminicli + +import "testing" + +func TestDefaultModels_ContainsImageModels(t *testing.T) { + t.Parallel() + + byID := make(map[string]Model, len(DefaultModels)) + for _, model := range DefaultModels { + byID[model.ID] = model + } + + required := []string{ + "gemini-2.5-flash-image", + "gemini-3.1-flash-image", + } + + for _, id := range required { + if _, ok := byID[id]; !ok { + t.Fatalf("expected curated Gemini model %q to exist", id) + } + } +} diff --git a/backend/internal/pkg/openai/constants.go b/backend/internal/pkg/openai/constants.go index 4bbc68e7..b0a31a5f 100644 --- a/backend/internal/pkg/openai/constants.go +++ b/backend/internal/pkg/openai/constants.go @@ -15,6 +15,7 @@ type Model struct { // DefaultModels OpenAI models list var DefaultModels = []Model{ + {ID: "gpt-5.4", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4"}, {ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"}, {ID: "gpt-5.3-codex-spark", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex Spark"}, {ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"}, diff --git a/backend/internal/pkg/openai/oauth.go b/backend/internal/pkg/openai/oauth.go index 8bdcbe16..a35a5ea6 100644 --- a/backend/internal/pkg/openai/oauth.go +++ b/backend/internal/pkg/openai/oauth.go @@ -268,6 +268,7 @@ type IDTokenClaims struct { type OpenAIAuthClaims struct { ChatGPTAccountID string `json:"chatgpt_account_id"` ChatGPTUserID string `json:"chatgpt_user_id"` + ChatGPTPlanType string `json:"chatgpt_plan_type"` UserID string `json:"user_id"` Organizations 
[]OrganizationClaim `json:"organizations"` } @@ -325,12 +326,9 @@ func (r *RefreshTokenRequest) ToFormData() string { return params.Encode() } -// ParseIDToken parses the ID Token JWT and extracts claims. -// 注意:当前仅解码 payload 并校验 exp,未验证 JWT 签名。 -// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名: -// -// https://auth.openai.com/.well-known/jwks.json -func ParseIDToken(idToken string) (*IDTokenClaims, error) { +// DecodeIDToken decodes the ID Token JWT payload without validating expiration. +// Use this for best-effort extraction (e.g., during data import) where the token may be expired. +func DecodeIDToken(idToken string) (*IDTokenClaims, error) { parts := strings.Split(idToken, ".") if len(parts) != 3 { return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts)) @@ -360,6 +358,20 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) { return nil, fmt.Errorf("failed to parse JWT claims: %w", err) } + return &claims, nil +} + +// ParseIDToken parses the ID Token JWT and extracts claims. +// 注意:当前仅解码 payload 并校验 exp,未验证 JWT 签名。 +// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名: +// +// https://auth.openai.com/.well-known/jwks.json +func ParseIDToken(idToken string) (*IDTokenClaims, error) { + claims, err := DecodeIDToken(idToken) + if err != nil { + return nil, err + } + // 校验 ID Token 是否已过期(允许 2 分钟时钟偏差,防止因服务器时钟略有差异误判刚颁发的令牌) const clockSkewTolerance = 120 // 秒 now := time.Now().Unix() @@ -367,7 +379,7 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) { return nil, fmt.Errorf("id_token has expired (exp: %d, now: %d, skew_tolerance: %ds)", claims.Exp, now, clockSkewTolerance) } - return &claims, nil + return claims, nil } // UserInfo represents user information extracted from ID Token claims. 
@@ -375,6 +387,7 @@ type UserInfo struct { Email string ChatGPTAccountID string ChatGPTUserID string + PlanType string UserID string OrganizationID string Organizations []OrganizationClaim @@ -389,6 +402,7 @@ func (c *IDTokenClaims) GetUserInfo() *UserInfo { if c.OpenAIAuth != nil { info.ChatGPTAccountID = c.OpenAIAuth.ChatGPTAccountID info.ChatGPTUserID = c.OpenAIAuth.ChatGPTUserID + info.PlanType = c.OpenAIAuth.ChatGPTPlanType info.UserID = c.OpenAIAuth.UserID info.Organizations = c.OpenAIAuth.Organizations diff --git a/backend/internal/pkg/openai/request.go b/backend/internal/pkg/openai/request.go index c24d1273..dd8fe566 100644 --- a/backend/internal/pkg/openai/request.go +++ b/backend/internal/pkg/openai/request.go @@ -58,6 +58,12 @@ func IsCodexOfficialClientOriginator(originator string) bool { return matchCodexClientHeaderPrefixes(v, CodexOfficialClientOriginatorPrefixes) } +// IsCodexOfficialClientByHeaders checks whether the request headers indicate an +// official Codex client family request. 
+func IsCodexOfficialClientByHeaders(userAgent, originator string) bool { + return IsCodexOfficialClientRequest(userAgent) || IsCodexOfficialClientOriginator(originator) +} + func normalizeCodexClientHeader(value string) string { return strings.ToLower(strings.TrimSpace(value)) } diff --git a/backend/internal/pkg/openai/request_test.go b/backend/internal/pkg/openai/request_test.go index 508bf561..b4562a07 100644 --- a/backend/internal/pkg/openai/request_test.go +++ b/backend/internal/pkg/openai/request_test.go @@ -85,3 +85,26 @@ func TestIsCodexOfficialClientOriginator(t *testing.T) { }) } } + +func TestIsCodexOfficialClientByHeaders(t *testing.T) { + tests := []struct { + name string + ua string + originator string + want bool + }{ + {name: "仅 originator 命中 desktop", originator: "Codex Desktop", want: true}, + {name: "仅 originator 命中 vscode", originator: "codex_vscode", want: true}, + {name: "仅 ua 命中 desktop", ua: "Codex Desktop/1.2.3", want: true}, + {name: "ua 与 originator 都未命中", ua: "curl/8.0.1", originator: "my_client", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsCodexOfficialClientByHeaders(tt.ua, tt.originator) + if got != tt.want { + t.Fatalf("IsCodexOfficialClientByHeaders(%q, %q) = %v, want %v", tt.ua, tt.originator, got, tt.want) + } + }) + } +} diff --git a/backend/internal/pkg/usagestats/usage_log_types.go b/backend/internal/pkg/usagestats/usage_log_types.go index 314a6d3c..8826c048 100644 --- a/backend/internal/pkg/usagestats/usage_log_types.go +++ b/backend/internal/pkg/usagestats/usage_log_types.go @@ -57,25 +57,28 @@ type DashboardStats struct { // TrendDataPoint represents a single point in trend data type TrendDataPoint struct { - Date string `json:"date"` - Requests int64 `json:"requests"` - InputTokens int64 `json:"input_tokens"` - OutputTokens int64 `json:"output_tokens"` - CacheTokens int64 `json:"cache_tokens"` - TotalTokens int64 `json:"total_tokens"` - Cost float64 `json:"cost"` // 标准计费 
- ActualCost float64 `json:"actual_cost"` // 实际扣除 + Date string `json:"date"` + Requests int64 `json:"requests"` + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` + CacheCreationTokens int64 `json:"cache_creation_tokens"` + CacheReadTokens int64 `json:"cache_read_tokens"` + TotalTokens int64 `json:"total_tokens"` + Cost float64 `json:"cost"` // 标准计费 + ActualCost float64 `json:"actual_cost"` // 实际扣除 } // ModelStat represents usage statistics for a single model type ModelStat struct { - Model string `json:"model"` - Requests int64 `json:"requests"` - InputTokens int64 `json:"input_tokens"` - OutputTokens int64 `json:"output_tokens"` - TotalTokens int64 `json:"total_tokens"` - Cost float64 `json:"cost"` // 标准计费 - ActualCost float64 `json:"actual_cost"` // 实际扣除 + Model string `json:"model"` + Requests int64 `json:"requests"` + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` + CacheCreationTokens int64 `json:"cache_creation_tokens"` + CacheReadTokens int64 `json:"cache_read_tokens"` + TotalTokens int64 `json:"total_tokens"` + Cost float64 `json:"cost"` // 标准计费 + ActualCost float64 `json:"actual_cost"` // 实际扣除 } // GroupStat represents usage statistics for a single group @@ -154,6 +157,8 @@ type UsageLogFilters struct { BillingType *int8 StartTime *time.Time EndTime *time.Time + // ExactTotal requests exact COUNT(*) for pagination. Default false for fast large-table paging. 
+ ExactTotal bool } // UsageStats represents usage statistics diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 0669cbbd..a9cb2cba 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -16,6 +16,7 @@ import ( "encoding/json" "errors" "strconv" + "strings" "time" dbent "github.com/Wei-Shaw/sub2api/ent" @@ -50,6 +51,18 @@ type accountRepository struct { schedulerCache service.SchedulerCache } +var schedulerNeutralExtraKeyPrefixes = []string{ + "codex_primary_", + "codex_secondary_", + "codex_5h_", + "codex_7d_", +} + +var schedulerNeutralExtraKeys = map[string]struct{}{ + "codex_usage_updated_at": {}, + "session_window_utilization": {}, +} + // NewAccountRepository 创建账户仓储实例。 // 这是对外暴露的构造函数,返回接口类型以便于依赖注入。 func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository { @@ -84,6 +97,9 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account if account.RateMultiplier != nil { builder.SetRateMultiplier(*account.RateMultiplier) } + if account.LoadFactor != nil { + builder.SetLoadFactor(*account.LoadFactor) + } if account.ProxyID != nil { builder.SetProxyID(*account.ProxyID) @@ -318,6 +334,11 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account if account.RateMultiplier != nil { builder.SetRateMultiplier(*account.RateMultiplier) } + if account.LoadFactor != nil { + builder.SetLoadFactor(*account.LoadFactor) + } else { + builder.ClearLoadFactor() + } if account.ProxyID != nil { builder.SetProxyID(*account.ProxyID) @@ -437,6 +458,14 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati switch status { case "rate_limited": q = q.Where(dbaccount.RateLimitResetAtGT(time.Now())) + case "temp_unschedulable": + q = q.Where(dbpredicate.Account(func(s *entsql.Selector) { + col := s.C("temp_unschedulable_until") + 
s.Where(entsql.And( + entsql.Not(entsql.IsNull(col)), + entsql.GT(col, entsql.Expr("NOW()")), + )) + })) default: q = q.Where(dbaccount.StatusEQ(status)) } @@ -640,7 +669,14 @@ func (r *accountRepository) ClearError(ctx context.Context, id int64) error { SetStatus(service.StatusActive). SetErrorMessage(""). Save(ctx) - return err + if err != nil { + return err + } + if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear error failed: account=%d err=%v", id, err) + } + r.syncSchedulerAccountSnapshot(ctx, id) + return nil } func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error { @@ -899,6 +935,7 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err) } + r.syncSchedulerAccountSnapshot(ctx, id) return nil } @@ -1014,6 +1051,7 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err) } + r.syncSchedulerAccountSnapshot(ctx, id) return nil } @@ -1160,12 +1198,48 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m if affected == 0 { return service.ErrAccountNotFound } - if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err) + if 
shouldEnqueueSchedulerOutboxForExtraUpdates(updates) { + if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err) + } + } else { + // 观测型 extra 字段不需要触发 bucket 重建,但仍同步单账号快照, + // 让 sticky session / GetAccount 命中缓存时也能读到最新数据, + // 同时避免缓存局部 patch 覆盖掉并发写入的其它账号字段。 + r.syncSchedulerAccountSnapshot(ctx, id) } return nil } +func shouldEnqueueSchedulerOutboxForExtraUpdates(updates map[string]any) bool { + if len(updates) == 0 { + return false + } + for key := range updates { + if isSchedulerNeutralExtraKey(key) { + continue + } + return true + } + return false +} + +func isSchedulerNeutralExtraKey(key string) bool { + key = strings.TrimSpace(key) + if key == "" { + return false + } + if _, ok := schedulerNeutralExtraKeys[key]; ok { + return true + } + for _, prefix := range schedulerNeutralExtraKeyPrefixes { + if strings.HasPrefix(key, prefix) { + return true + } + } + return false +} + func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) { if len(ids) == 0 { return 0, nil @@ -1205,6 +1279,15 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates args = append(args, *updates.RateMultiplier) idx++ } + if updates.LoadFactor != nil { + if *updates.LoadFactor <= 0 { + setClauses = append(setClauses, "load_factor = NULL") + } else { + setClauses = append(setClauses, "load_factor = $"+itoa(idx)) + args = append(args, *updates.LoadFactor) + idx++ + } + } if updates.Status != nil { setClauses = append(setClauses, "status = $"+itoa(idx)) args = append(args, *updates.Status) @@ -1527,6 +1610,7 @@ func accountEntityToService(m *dbent.Account) *service.Account { Concurrency: m.Concurrency, Priority: m.Priority, RateMultiplier: &rateMultiplier, + LoadFactor: m.LoadFactor, Status: m.Status, ErrorMessage: 
derefString(m.ErrorMessage), LastUsedAt: m.LastUsedAt, @@ -1639,3 +1723,93 @@ func (r *accountRepository) FindByExtraField(ctx context.Context, key string, va return r.accountsToService(ctx, accounts) } + +// nowUTC is a SQL expression to generate a UTC RFC3339 timestamp string. +const nowUTC = `to_char(NOW() AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.US"Z"')` + +// IncrementQuotaUsed 原子递增账号的配额用量(总/日/周三个维度) +// 日/周额度在周期过期时自动重置为 0 再递增。 +func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error { + rows, err := r.sql.QueryContext(ctx, + `UPDATE accounts SET extra = ( + COALESCE(extra, '{}'::jsonb) + -- 总额度:始终递增 + || jsonb_build_object('quota_used', COALESCE((extra->>'quota_used')::numeric, 0) + $1) + -- 日额度:仅在 quota_daily_limit > 0 时处理 + || CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN + jsonb_build_object( + 'quota_daily_used', + CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz) + + '24 hours'::interval <= NOW() + THEN $1 + ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END, + 'quota_daily_start', + CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz) + + '24 hours'::interval <= NOW() + THEN `+nowUTC+` + ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END + ) + ELSE '{}'::jsonb END + -- 周额度:仅在 quota_weekly_limit > 0 时处理 + || CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN + jsonb_build_object( + 'quota_weekly_used', + CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz) + + '168 hours'::interval <= NOW() + THEN $1 + ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END, + 'quota_weekly_start', + CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz) + + '168 hours'::interval <= NOW() + THEN `+nowUTC+` + ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END + ) + ELSE '{}'::jsonb END + ), 
updated_at = NOW() + WHERE id = $2 AND deleted_at IS NULL + RETURNING + COALESCE((extra->>'quota_used')::numeric, 0), + COALESCE((extra->>'quota_limit')::numeric, 0)`, + amount, id) + if err != nil { + return err + } + defer func() { _ = rows.Close() }() + + var newUsed, limit float64 + if rows.Next() { + if err := rows.Scan(&newUsed, &limit); err != nil { + return err + } + } + if err := rows.Err(); err != nil { + return err + } + + // 任一维度配额刚超限时触发调度快照刷新 + if limit > 0 && newUsed >= limit && (newUsed-amount) < limit { + if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", id, err) + } + } + return nil +} + +// ResetQuotaUsed 重置账号所有维度的配额用量为 0 +func (r *accountRepository) ResetQuotaUsed(ctx context.Context, id int64) error { + _, err := r.sql.ExecContext(ctx, + `UPDATE accounts SET extra = ( + COALESCE(extra, '{}'::jsonb) + || '{"quota_used": 0, "quota_daily_used": 0, "quota_weekly_used": 0}'::jsonb + ) - 'quota_daily_start' - 'quota_weekly_start', updated_at = NOW() + WHERE id = $1 AND deleted_at IS NULL`, + id) + if err != nil { + return err + } + // 重置配额后触发调度快照刷新,使账号重新参与调度 + if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue quota reset failed: account=%d err=%v", id, err) + } + return nil +} diff --git a/backend/internal/repository/account_repo_integration_test.go b/backend/internal/repository/account_repo_integration_test.go index fd48a5d4..29b699e6 100644 --- a/backend/internal/repository/account_repo_integration_test.go +++ b/backend/internal/repository/account_repo_integration_test.go @@ -23,6 +23,7 @@ type AccountRepoSuite struct { type schedulerCacheRecorder struct { setAccounts []*service.Account + accounts map[int64]*service.Account } func (s 
*schedulerCacheRecorder) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) { @@ -34,11 +35,20 @@ func (s *schedulerCacheRecorder) SetSnapshot(ctx context.Context, bucket service } func (s *schedulerCacheRecorder) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) { - return nil, nil + if s.accounts == nil { + return nil, nil + } + return s.accounts[accountID], nil } func (s *schedulerCacheRecorder) SetAccount(ctx context.Context, account *service.Account) error { s.setAccounts = append(s.setAccounts, account) + if s.accounts == nil { + s.accounts = make(map[int64]*service.Account) + } + if account != nil { + s.accounts[account.ID] = account + } return nil } @@ -558,6 +568,26 @@ func (s *AccountRepoSuite) TestSetError() { s.Require().Equal("something went wrong", got.ErrorMessage) } +func (s *AccountRepoSuite) TestClearError_SyncSchedulerSnapshotOnRecovery() { + account := mustCreateAccount(s.T(), s.client, &service.Account{ + Name: "acc-clear-err", + Status: service.StatusError, + ErrorMessage: "temporary error", + }) + cacheRecorder := &schedulerCacheRecorder{} + s.repo.schedulerCache = cacheRecorder + + s.Require().NoError(s.repo.ClearError(s.ctx, account.ID)) + + got, err := s.repo.GetByID(s.ctx, account.ID) + s.Require().NoError(err) + s.Require().Equal(service.StatusActive, got.Status) + s.Require().Empty(got.ErrorMessage) + s.Require().Len(cacheRecorder.setAccounts, 1) + s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID) + s.Require().Equal(service.StatusActive, cacheRecorder.setAccounts[0].Status) +} + // --- UpdateSessionWindow --- func (s *AccountRepoSuite) TestUpdateSessionWindow() { @@ -603,6 +633,96 @@ func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() { s.Require().Equal("val", got.Extra["key"]) } +func (s *AccountRepoSuite) TestUpdateExtra_SchedulerNeutralSkipsOutboxAndSyncsFreshSnapshot() { + account := mustCreateAccount(s.T(), s.client, &service.Account{ + Name: 
"acc-extra-neutral", + Platform: service.PlatformOpenAI, + Extra: map[string]any{"codex_usage_updated_at": "old"}, + }) + cacheRecorder := &schedulerCacheRecorder{ + accounts: map[int64]*service.Account{ + account.ID: { + ID: account.ID, + Platform: account.Platform, + Status: service.StatusDisabled, + Extra: map[string]any{ + "codex_usage_updated_at": "old", + }, + }, + }, + } + s.repo.schedulerCache = cacheRecorder + + updates := map[string]any{ + "codex_usage_updated_at": "2026-03-11T10:00:00Z", + "codex_5h_used_percent": 88.5, + "session_window_utilization": 0.42, + } + s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, updates)) + + got, err := s.repo.GetByID(s.ctx, account.ID) + s.Require().NoError(err) + s.Require().Equal("2026-03-11T10:00:00Z", got.Extra["codex_usage_updated_at"]) + s.Require().Equal(88.5, got.Extra["codex_5h_used_percent"]) + s.Require().Equal(0.42, got.Extra["session_window_utilization"]) + + var outboxCount int + s.Require().NoError(scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &outboxCount)) + s.Require().Zero(outboxCount) + s.Require().Len(cacheRecorder.setAccounts, 1) + s.Require().NotNil(cacheRecorder.accounts[account.ID]) + s.Require().Equal(service.StatusActive, cacheRecorder.accounts[account.ID].Status) + s.Require().Equal("2026-03-11T10:00:00Z", cacheRecorder.accounts[account.ID].Extra["codex_usage_updated_at"]) +} + +func (s *AccountRepoSuite) TestUpdateExtra_ExhaustedCodexSnapshotSyncsSchedulerCache() { + account := mustCreateAccount(s.T(), s.client, &service.Account{ + Name: "acc-extra-codex-exhausted", + Platform: service.PlatformOpenAI, + Type: service.AccountTypeOAuth, + Extra: map[string]any{}, + }) + cacheRecorder := &schedulerCacheRecorder{} + s.repo.schedulerCache = cacheRecorder + _, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox") + s.Require().NoError(err) + + s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{ + "codex_7d_used_percent": 
100.0, + "codex_7d_reset_at": "2026-03-12T13:00:00Z", + "codex_7d_reset_after_seconds": 86400, + })) + + var count int + err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &count) + s.Require().NoError(err) + s.Require().Equal(0, count) + s.Require().Len(cacheRecorder.setAccounts, 1) + s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID) + s.Require().Equal(service.StatusActive, cacheRecorder.setAccounts[0].Status) + s.Require().Equal(100.0, cacheRecorder.setAccounts[0].Extra["codex_7d_used_percent"]) +} + +func (s *AccountRepoSuite) TestUpdateExtra_SchedulerRelevantStillEnqueuesOutbox() { + account := mustCreateAccount(s.T(), s.client, &service.Account{ + Name: "acc-extra-mixed", + Platform: service.PlatformAntigravity, + Extra: map[string]any{}, + }) + _, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox") + s.Require().NoError(err) + + s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{ + "mixed_scheduling": true, + "codex_usage_updated_at": "2026-03-11T10:00:00Z", + })) + + var count int + err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &count) + s.Require().NoError(err) + s.Require().Equal(1, count) +} + // --- GetByCRSAccountID --- func (s *AccountRepoSuite) TestGetByCRSAccountID() { diff --git a/backend/internal/repository/announcement_repo.go b/backend/internal/repository/announcement_repo.go index 52029e4e..53dc335f 100644 --- a/backend/internal/repository/announcement_repo.go +++ b/backend/internal/repository/announcement_repo.go @@ -24,6 +24,7 @@ func (r *announcementRepository) Create(ctx context.Context, a *service.Announce SetTitle(a.Title). SetContent(a.Content). SetStatus(a.Status). + SetNotifyMode(a.NotifyMode). SetTargeting(a.Targeting) if a.StartsAt != nil { @@ -64,6 +65,7 @@ func (r *announcementRepository) Update(ctx context.Context, a *service.Announce SetTitle(a.Title). SetContent(a.Content). SetStatus(a.Status). 
+ SetNotifyMode(a.NotifyMode). SetTargeting(a.Targeting) if a.StartsAt != nil { @@ -169,17 +171,18 @@ func announcementEntityToService(m *dbent.Announcement) *service.Announcement { return nil } return &service.Announcement{ - ID: m.ID, - Title: m.Title, - Content: m.Content, - Status: m.Status, - Targeting: m.Targeting, - StartsAt: m.StartsAt, - EndsAt: m.EndsAt, - CreatedBy: m.CreatedBy, - UpdatedBy: m.UpdatedBy, - CreatedAt: m.CreatedAt, - UpdatedAt: m.UpdatedAt, + ID: m.ID, + Title: m.Title, + Content: m.Content, + Status: m.Status, + NotifyMode: m.NotifyMode, + Targeting: m.Targeting, + StartsAt: m.StartsAt, + EndsAt: m.EndsAt, + CreatedBy: m.CreatedBy, + UpdatedBy: m.UpdatedBy, + CreatedAt: m.CreatedAt, + UpdatedAt: m.UpdatedAt, } } diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index c761e8c9..a45a83a3 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -164,7 +164,10 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se group.FieldModelRoutingEnabled, group.FieldModelRouting, group.FieldMcpXMLInject, + group.FieldSimulateClaudeMaxEnabled, group.FieldSupportedModelScopes, + group.FieldAllowMessagesDispatch, + group.FieldDefaultMappedModel, ) }). Only(ctx) @@ -450,6 +453,32 @@ func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amo return updated.QuotaUsed, nil } +// IncrementQuotaUsedAndGetState atomically increments quota_used, conditionally marks the key +// as quota_exhausted, and returns the latest quota state in one round trip. 
+func (r *apiKeyRepository) IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*service.APIKeyQuotaUsageState, error) { + query := ` + UPDATE api_keys + SET + quota_used = quota_used + $1, + status = CASE + WHEN quota > 0 AND quota_used + $1 >= quota THEN $2 + ELSE status + END, + updated_at = NOW() + WHERE id = $3 AND deleted_at IS NULL + RETURNING quota_used, quota, key, status + ` + + state := &service.APIKeyQuotaUsageState{} + if err := scanSingleRow(ctx, r.sql, query, []any{amount, service.StatusAPIKeyQuotaExhausted, id}, &state.QuotaUsed, &state.Quota, &state.Key, &state.Status); err != nil { + if err == sql.ErrNoRows { + return nil, service.ErrAPIKeyNotFound + } + return nil, err + } + return state, nil +} + func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error { affected, err := r.client.APIKey.Update(). Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()). @@ -470,12 +499,12 @@ func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt func (r *apiKeyRepository) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error { _, err := r.sql.ExecContext(ctx, ` UPDATE api_keys SET - usage_5h = usage_5h + $1, - usage_1d = usage_1d + $1, - usage_7d = usage_7d + $1, - window_5h_start = COALESCE(window_5h_start, NOW()), - window_1d_start = COALESCE(window_1d_start, NOW()), - window_7d_start = COALESCE(window_7d_start, NOW()), + usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN $1 ELSE usage_5h + $1 END, + usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE usage_1d + $1 END, + usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END, + window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END, + 
window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END, + window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END, updated_at = NOW() WHERE id = $2 AND deleted_at IS NULL`, cost, id) @@ -489,9 +518,9 @@ func (r *apiKeyRepository) ResetRateLimitWindows(ctx context.Context, id int64) usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN 0 ELSE usage_5h END, window_5h_start = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END, usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN 0 ELSE usage_1d END, - window_1d_start = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN NOW() ELSE window_1d_start END, + window_1d_start = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END, usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN 0 ELSE usage_7d END, - window_7d_start = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN NOW() ELSE window_7d_start END, + window_7d_start = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END, updated_at = NOW() WHERE id = $1 AND deleted_at IS NULL`, id) @@ -617,8 +646,11 @@ func groupEntityToService(g *dbent.Group) *service.Group { ModelRouting: g.ModelRouting, ModelRoutingEnabled: g.ModelRoutingEnabled, MCPXMLInject: g.McpXMLInject, + SimulateClaudeMaxEnabled: g.SimulateClaudeMaxEnabled, SupportedModelScopes: g.SupportedModelScopes, SortOrder: g.SortOrder, + 
AllowMessagesDispatch: g.AllowMessagesDispatch, + DefaultMappedModel: g.DefaultMappedModel, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, } diff --git a/backend/internal/repository/api_key_repo_integration_test.go b/backend/internal/repository/api_key_repo_integration_test.go index 80714614..a8989ff2 100644 --- a/backend/internal/repository/api_key_repo_integration_test.go +++ b/backend/internal/repository/api_key_repo_integration_test.go @@ -417,6 +417,27 @@ func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() { s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound") } +func (s *APIKeyRepoSuite) TestIncrementQuotaUsedAndGetState() { + user := s.mustCreateUser("quota-state@test.com") + key := s.mustCreateApiKey(user.ID, "sk-quota-state", "QuotaState", nil) + key.Quota = 3 + key.QuotaUsed = 1 + s.Require().NoError(s.repo.Update(s.ctx, key), "Update quota") + + state, err := s.repo.IncrementQuotaUsedAndGetState(s.ctx, key.ID, 2.5) + s.Require().NoError(err, "IncrementQuotaUsedAndGetState") + s.Require().NotNil(state) + s.Require().Equal(3.5, state.QuotaUsed) + s.Require().Equal(3.0, state.Quota) + s.Require().Equal(service.StatusAPIKeyQuotaExhausted, state.Status) + s.Require().Equal(key.Key, state.Key) + + got, err := s.repo.GetByID(s.ctx, key.ID) + s.Require().NoError(err, "GetByID") + s.Require().Equal(3.5, got.QuotaUsed) + s.Require().Equal(service.StatusAPIKeyQuotaExhausted, got.Status) +} + // TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。 // 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。 func TestIncrementQuotaUsed_Concurrent(t *testing.T) { diff --git a/backend/internal/repository/claude_usage_service.go b/backend/internal/repository/claude_usage_service.go index f6054828..1264f6bb 100644 --- a/backend/internal/repository/claude_usage_service.go +++ b/backend/internal/repository/claude_usage_service.go @@ -8,6 +8,7 @@ import ( "net/http" "time" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" 
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/service" ) @@ -95,7 +96,8 @@ func (s *claudeUsageService) FetchUsageWithOptions(ctx context.Context, opts *se if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body)) + msg := fmt.Sprintf("API returned status %d: %s", resp.StatusCode, string(body)) + return nil, infraerrors.New(http.StatusInternalServerError, "UPSTREAM_ERROR", msg) } var usageResp service.ClaudeUsageResponse diff --git a/backend/internal/repository/concurrency_cache.go b/backend/internal/repository/concurrency_cache.go index a2552715..8732b2ce 100644 --- a/backend/internal/repository/concurrency_cache.go +++ b/backend/internal/repository/concurrency_cache.go @@ -147,17 +147,47 @@ var ( return 1 `) - // cleanupExpiredSlotsScript - remove expired slots - // KEYS[1] = concurrency:account:{accountID} - // ARGV[1] = TTL (seconds) + // cleanupExpiredSlotsScript 清理单个账号/用户有序集合中过期槽位 + // KEYS[1] = 有序集合键 + // ARGV[1] = TTL(秒) cleanupExpiredSlotsScript = redis.NewScript(` - local key = KEYS[1] - local ttl = tonumber(ARGV[1]) - local timeResult = redis.call('TIME') - local now = tonumber(timeResult[1]) - local expireBefore = now - ttl - return redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore) - `) + local key = KEYS[1] + local ttl = tonumber(ARGV[1]) + local timeResult = redis.call('TIME') + local now = tonumber(timeResult[1]) + local expireBefore = now - ttl + redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore) + if redis.call('ZCARD', key) == 0 then + redis.call('DEL', key) + else + redis.call('EXPIRE', key, ttl) + end + return 1 + `) + + // startupCleanupScript 清理非当前进程前缀的槽位成员。 + // KEYS 是有序集合键列表,ARGV[1] 是当前进程前缀,ARGV[2] 是槽位 TTL。 + // 遍历每个 KEYS[i],移除前缀不匹配的成员,清空后删 key,否则刷新 EXPIRE。 + startupCleanupScript = redis.NewScript(` + local activePrefix = ARGV[1] + local slotTTL = tonumber(ARGV[2]) + local removed = 0 
+ for i = 1, #KEYS do + local key = KEYS[i] + local members = redis.call('ZRANGE', key, 0, -1) + for _, member in ipairs(members) do + if string.sub(member, 1, string.len(activePrefix)) ~= activePrefix then + removed = removed + redis.call('ZREM', key, member) + end + end + if redis.call('ZCARD', key) == 0 then + redis.call('DEL', key) + else + redis.call('EXPIRE', key, slotTTL) + end + end + return removed + `) ) type concurrencyCache struct { @@ -463,3 +493,72 @@ func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accou _, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result() return err } + +func (c *concurrencyCache) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error { + if activeRequestPrefix == "" { + return nil + } + + // 1. 清理有序集合中非当前进程前缀的成员 + slotPatterns := []string{accountSlotKeyPrefix + "*", userSlotKeyPrefix + "*"} + for _, pattern := range slotPatterns { + if err := c.cleanupSlotsByPattern(ctx, pattern, activeRequestPrefix); err != nil { + return err + } + } + + // 2. 
删除所有等待队列计数器(重启后计数器失效) + waitPatterns := []string{accountWaitKeyPrefix + "*", waitQueueKeyPrefix + "*"} + for _, pattern := range waitPatterns { + if err := c.deleteKeysByPattern(ctx, pattern); err != nil { + return err + } + } + + return nil +} + +// cleanupSlotsByPattern 扫描匹配 pattern 的有序集合键,批量调用 Lua 脚本清理非当前进程成员。 +func (c *concurrencyCache) cleanupSlotsByPattern(ctx context.Context, pattern, activePrefix string) error { + const scanCount = 200 + var cursor uint64 + for { + keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, scanCount).Result() + if err != nil { + return fmt.Errorf("scan %s: %w", pattern, err) + } + if len(keys) > 0 { + _, err := startupCleanupScript.Run(ctx, c.rdb, keys, activePrefix, c.slotTTLSeconds).Result() + if err != nil { + return fmt.Errorf("cleanup slots %s: %w", pattern, err) + } + } + cursor = nextCursor + if cursor == 0 { + break + } + } + return nil +} + +// deleteKeysByPattern 扫描匹配 pattern 的键并删除。 +func (c *concurrencyCache) deleteKeysByPattern(ctx context.Context, pattern string) error { + const scanCount = 200 + var cursor uint64 + for { + keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, scanCount).Result() + if err != nil { + return fmt.Errorf("scan %s: %w", pattern, err) + } + if len(keys) > 0 { + if err := c.rdb.Del(ctx, keys...).Err(); err != nil { + return fmt.Errorf("del %s: %w", pattern, err) + } + } + cursor = nextCursor + if cursor == 0 { + break + } + } + return nil +} diff --git a/backend/internal/repository/concurrency_cache_integration_test.go b/backend/internal/repository/concurrency_cache_integration_test.go index 5983c832..5da94fc2 100644 --- a/backend/internal/repository/concurrency_cache_integration_test.go +++ b/backend/internal/repository/concurrency_cache_integration_test.go @@ -25,6 +25,10 @@ type ConcurrencyCacheSuite struct { cache service.ConcurrencyCache } +func TestConcurrencyCacheSuite(t *testing.T) { + suite.Run(t, new(ConcurrencyCacheSuite)) +} + func (s *ConcurrencyCacheSuite) SetupTest() 
{ s.IntegrationRedisSuite.SetupTest() s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds())) @@ -247,17 +251,41 @@ func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() { require.Equal(s.T(), 1, val, "expected account wait count 1") } -func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_DecrementNoNegative() { - accountID := int64(301) - waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID) +func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots() { + accountID := int64(901) + userID := int64(902) + accountKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID) + userKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID) + userWaitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID) + accountWaitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID) - require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount on non-existent key") + now := time.Now().Unix() + require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountKey, + redis.Z{Score: float64(now), Member: "oldproc-1"}, + redis.Z{Score: float64(now), Member: "keep-1"}, + ).Err()) + require.NoError(s.T(), s.rdb.ZAdd(s.ctx, userKey, + redis.Z{Score: float64(now), Member: "oldproc-2"}, + redis.Z{Score: float64(now), Member: "keep-2"}, + ).Err()) + require.NoError(s.T(), s.rdb.Set(s.ctx, userWaitKey, 3, time.Minute).Err()) + require.NoError(s.T(), s.rdb.Set(s.ctx, accountWaitKey, 2, time.Minute).Err()) - val, err := s.rdb.Get(s.ctx, waitKey).Int() - if !errors.Is(err, redis.Nil) { - require.NoError(s.T(), err, "Get waitKey") - } - require.GreaterOrEqual(s.T(), val, 0, "expected non-negative account wait count after decrement on empty") + require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "keep-")) + + accountMembers, err := s.rdb.ZRange(s.ctx, accountKey, 0, -1).Result() + require.NoError(s.T(), err) + require.Equal(s.T(), []string{"keep-1"}, accountMembers) + + userMembers, err := 
s.rdb.ZRange(s.ctx, userKey, 0, -1).Result() + require.NoError(s.T(), err) + require.Equal(s.T(), []string{"keep-2"}, userMembers) + + _, err = s.rdb.Get(s.ctx, userWaitKey).Result() + require.True(s.T(), errors.Is(err, redis.Nil)) + + _, err = s.rdb.Get(s.ctx, accountWaitKey).Result() + require.True(s.T(), errors.Is(err, redis.Nil)) } func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() { @@ -407,6 +435,53 @@ func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots_NoExpired() { require.Equal(s.T(), 2, cur) } -func TestConcurrencyCacheSuite(t *testing.T) { - suite.Run(t, new(ConcurrencyCacheSuite)) +func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots_RemovesOldPrefixesAndWaitCounters() { + accountID := int64(901) + userID := int64(902) + accountSlotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID) + userSlotKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID) + userWaitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID) + accountWaitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID) + + now := float64(time.Now().Unix()) + require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountSlotKey, + redis.Z{Score: now, Member: "oldproc-1"}, + redis.Z{Score: now, Member: "activeproc-1"}, + ).Err()) + require.NoError(s.T(), s.rdb.Expire(s.ctx, accountSlotKey, testSlotTTL).Err()) + require.NoError(s.T(), s.rdb.ZAdd(s.ctx, userSlotKey, + redis.Z{Score: now, Member: "oldproc-2"}, + redis.Z{Score: now, Member: "activeproc-2"}, + ).Err()) + require.NoError(s.T(), s.rdb.Expire(s.ctx, userSlotKey, testSlotTTL).Err()) + require.NoError(s.T(), s.rdb.Set(s.ctx, userWaitKey, 3, testSlotTTL).Err()) + require.NoError(s.T(), s.rdb.Set(s.ctx, accountWaitKey, 2, testSlotTTL).Err()) + + require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "activeproc-")) + + accountMembers, err := s.rdb.ZRange(s.ctx, accountSlotKey, 0, -1).Result() + require.NoError(s.T(), err) + require.Equal(s.T(), []string{"activeproc-1"}, accountMembers) + + 
userMembers, err := s.rdb.ZRange(s.ctx, userSlotKey, 0, -1).Result() + require.NoError(s.T(), err) + require.Equal(s.T(), []string{"activeproc-2"}, userMembers) + + _, err = s.rdb.Get(s.ctx, userWaitKey).Result() + require.ErrorIs(s.T(), err, redis.Nil) + _, err = s.rdb.Get(s.ctx, accountWaitKey).Result() + require.ErrorIs(s.T(), err, redis.Nil) +} + +func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots_DeletesEmptySlotKeys() { + accountID := int64(903) + accountSlotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID) + require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountSlotKey, redis.Z{Score: float64(time.Now().Unix()), Member: "oldproc-1"}).Err()) + require.NoError(s.T(), s.rdb.Expire(s.ctx, accountSlotKey, testSlotTTL).Err()) + + require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "activeproc-")) + + exists, err := s.rdb.Exists(s.ctx, accountSlotKey).Result() + require.NoError(s.T(), err) + require.EqualValues(s.T(), 0, exists) } diff --git a/backend/internal/repository/ent.go b/backend/internal/repository/ent.go index 5f3f5a84..64d32192 100644 --- a/backend/internal/repository/ent.go +++ b/backend/internal/repository/ent.go @@ -89,6 +89,10 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) { _ = client.Close() return nil, nil, err } + if err := ensureSimpleModeAdminConcurrency(seedCtx, client); err != nil { + _ = client.Close() + return nil, nil, err + } } return client, drv.DB(), nil diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 4edc8534..27d68354 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -59,7 +59,10 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest). SetModelRoutingEnabled(groupIn.ModelRoutingEnabled). SetMcpXMLInject(groupIn.MCPXMLInject). 
- SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes) + SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes). + SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch). + SetDefaultMappedModel(groupIn.DefaultMappedModel). + SetSimulateClaudeMaxEnabled(groupIn.SimulateClaudeMaxEnabled) // 设置模型路由配置 if groupIn.ModelRouting != nil { @@ -125,7 +128,10 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetModelRoutingEnabled(groupIn.ModelRoutingEnabled). SetMcpXMLInject(groupIn.MCPXMLInject). - SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes) + SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes). + SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch). + SetDefaultMappedModel(groupIn.DefaultMappedModel). + SetSimulateClaudeMaxEnabled(groupIn.SimulateClaudeMaxEnabled) // 显式处理可空字段:nil 需要 clear,非 nil 需要 set。 if groupIn.DailyLimitUSD != nil { diff --git a/backend/internal/repository/ops_repo.go b/backend/internal/repository/ops_repo.go index 989573f2..02ca1a3b 100644 --- a/backend/internal/repository/ops_repo.go +++ b/backend/internal/repository/ops_repo.go @@ -16,19 +16,7 @@ type opsRepository struct { db *sql.DB } -func NewOpsRepository(db *sql.DB) service.OpsRepository { - return &opsRepository{db: db} -} - -func (r *opsRepository) InsertErrorLog(ctx context.Context, input *service.OpsInsertErrorLogInput) (int64, error) { - if r == nil || r.db == nil { - return 0, fmt.Errorf("nil ops repository") - } - if input == nil { - return 0, fmt.Errorf("nil input") - } - - q := ` +const insertOpsErrorLogSQL = ` INSERT INTO ops_error_logs ( request_id, client_request_id, @@ -70,12 +58,77 @@ INSERT INTO ops_error_logs ( created_at ) VALUES ( $1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38 -) RETURNING id` +)` + +func NewOpsRepository(db *sql.DB) service.OpsRepository { + return &opsRepository{db: 
db} +} + +func (r *opsRepository) InsertErrorLog(ctx context.Context, input *service.OpsInsertErrorLogInput) (int64, error) { + if r == nil || r.db == nil { + return 0, fmt.Errorf("nil ops repository") + } + if input == nil { + return 0, fmt.Errorf("nil input") + } var id int64 err := r.db.QueryRowContext( ctx, - q, + insertOpsErrorLogSQL+" RETURNING id", + opsInsertErrorLogArgs(input)..., + ).Scan(&id) + if err != nil { + return 0, err + } + return id, nil +} + +func (r *opsRepository) BatchInsertErrorLogs(ctx context.Context, inputs []*service.OpsInsertErrorLogInput) (int64, error) { + if r == nil || r.db == nil { + return 0, fmt.Errorf("nil ops repository") + } + if len(inputs) == 0 { + return 0, nil + } + + tx, err := r.db.BeginTx(ctx, nil) + if err != nil { + return 0, err + } + defer func() { + if err != nil { + _ = tx.Rollback() + } + }() + + stmt, err := tx.PrepareContext(ctx, insertOpsErrorLogSQL) + if err != nil { + return 0, err + } + defer func() { + _ = stmt.Close() + }() + + var inserted int64 + for _, input := range inputs { + if input == nil { + continue + } + if _, err = stmt.ExecContext(ctx, opsInsertErrorLogArgs(input)...); err != nil { + return inserted, err + } + inserted++ + } + + if err = tx.Commit(); err != nil { + return inserted, err + } + return inserted, nil +} + +func opsInsertErrorLogArgs(input *service.OpsInsertErrorLogInput) []any { + return []any{ opsNullString(input.RequestID), opsNullString(input.ClientRequestID), opsNullInt64(input.UserID), @@ -114,11 +167,7 @@ INSERT INTO ops_error_logs ( input.IsRetryable, input.RetryCount, input.CreatedAt, - ).Scan(&id) - if err != nil { - return 0, err } - return id, nil } func (r *opsRepository) ListErrorLogs(ctx context.Context, filter *service.OpsErrorLogFilter) (*service.OpsErrorLogList, error) { diff --git a/backend/internal/repository/ops_write_pressure_integration_test.go b/backend/internal/repository/ops_write_pressure_integration_test.go new file mode 100644 index 00000000..ebb7a842 
--- /dev/null +++ b/backend/internal/repository/ops_write_pressure_integration_test.go @@ -0,0 +1,79 @@ +//go:build integration + +package repository + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestOpsRepositoryBatchInsertErrorLogs(t *testing.T) { + ctx := context.Background() + _, _ = integrationDB.ExecContext(ctx, "TRUNCATE ops_error_logs RESTART IDENTITY") + + repo := NewOpsRepository(integrationDB).(*opsRepository) + now := time.Now().UTC() + inserted, err := repo.BatchInsertErrorLogs(ctx, []*service.OpsInsertErrorLogInput{ + { + RequestID: "batch-ops-1", + ErrorPhase: "upstream", + ErrorType: "upstream_error", + Severity: "error", + StatusCode: 429, + ErrorMessage: "rate limited", + CreatedAt: now, + }, + { + RequestID: "batch-ops-2", + ErrorPhase: "internal", + ErrorType: "api_error", + Severity: "error", + StatusCode: 500, + ErrorMessage: "internal error", + CreatedAt: now.Add(time.Millisecond), + }, + }) + require.NoError(t, err) + require.EqualValues(t, 2, inserted) + + var count int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM ops_error_logs WHERE request_id IN ('batch-ops-1', 'batch-ops-2')").Scan(&count)) + require.Equal(t, 2, count) +} + +func TestEnqueueSchedulerOutbox_DeduplicatesIdempotentEvents(t *testing.T) { + ctx := context.Background() + _, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox RESTART IDENTITY") + + accountID := int64(12345) + require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil)) + require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil)) + + var count int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountChanged).Scan(&count)) + 
require.Equal(t, 1, count) + + time.Sleep(schedulerOutboxDedupWindow + 150*time.Millisecond) + require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil)) + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountChanged).Scan(&count)) + require.Equal(t, 2, count) +} + +func TestEnqueueSchedulerOutbox_DoesNotDeduplicateLastUsed(t *testing.T) { + ctx := context.Background() + _, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox RESTART IDENTITY") + + accountID := int64(67890) + payload1 := map[string]any{"last_used": map[string]int64{"67890": 100}} + payload2 := map[string]any{"last_used": map[string]int64{"67890": 200}} + require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountLastUsed, &accountID, nil, payload1)) + require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountLastUsed, &accountID, nil, payload2)) + + var count int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountLastUsed).Scan(&count)) + require.Equal(t, 2, count) +} diff --git a/backend/internal/repository/scheduled_test_repo.go b/backend/internal/repository/scheduled_test_repo.go new file mode 100644 index 00000000..c03d1df9 --- /dev/null +++ b/backend/internal/repository/scheduled_test_repo.go @@ -0,0 +1,183 @@ +package repository + +import ( + "context" + "database/sql" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// --- Plan Repository --- + +type scheduledTestPlanRepository struct { + db *sql.DB +} + +func NewScheduledTestPlanRepository(db *sql.DB) service.ScheduledTestPlanRepository { + return &scheduledTestPlanRepository{db: db} +} + +func (r *scheduledTestPlanRepository) Create(ctx context.Context, plan 
*service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) { + row := r.db.QueryRowContext(ctx, ` + INSERT INTO scheduled_test_plans (account_id, model_id, cron_expression, enabled, max_results, auto_recover, next_run_at, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, NOW(), NOW()) + RETURNING id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at + `, plan.AccountID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.AutoRecover, plan.NextRunAt) + return scanPlan(row) +} + +func (r *scheduledTestPlanRepository) GetByID(ctx context.Context, id int64) (*service.ScheduledTestPlan, error) { + row := r.db.QueryRowContext(ctx, ` + SELECT id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at + FROM scheduled_test_plans WHERE id = $1 + `, id) + return scanPlan(row) +} + +func (r *scheduledTestPlanRepository) ListByAccountID(ctx context.Context, accountID int64) ([]*service.ScheduledTestPlan, error) { + rows, err := r.db.QueryContext(ctx, ` + SELECT id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at + FROM scheduled_test_plans WHERE account_id = $1 + ORDER BY created_at DESC + `, accountID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + return scanPlans(rows) +} + +func (r *scheduledTestPlanRepository) ListDue(ctx context.Context, now time.Time) ([]*service.ScheduledTestPlan, error) { + rows, err := r.db.QueryContext(ctx, ` + SELECT id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at + FROM scheduled_test_plans + WHERE enabled = true AND next_run_at <= $1 + ORDER BY next_run_at ASC + `, now) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + return scanPlans(rows) +} + +func (r 
*scheduledTestPlanRepository) Update(ctx context.Context, plan *service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) { + row := r.db.QueryRowContext(ctx, ` + UPDATE scheduled_test_plans + SET model_id = $2, cron_expression = $3, enabled = $4, max_results = $5, auto_recover = $6, next_run_at = $7, updated_at = NOW() + WHERE id = $1 + RETURNING id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at + `, plan.ID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.AutoRecover, plan.NextRunAt) + return scanPlan(row) +} + +func (r *scheduledTestPlanRepository) Delete(ctx context.Context, id int64) error { + _, err := r.db.ExecContext(ctx, `DELETE FROM scheduled_test_plans WHERE id = $1`, id) + return err +} + +func (r *scheduledTestPlanRepository) UpdateAfterRun(ctx context.Context, id int64, lastRunAt time.Time, nextRunAt time.Time) error { + _, err := r.db.ExecContext(ctx, ` + UPDATE scheduled_test_plans SET last_run_at = $2, next_run_at = $3, updated_at = NOW() WHERE id = $1 + `, id, lastRunAt, nextRunAt) + return err +} + +// --- Result Repository --- + +type scheduledTestResultRepository struct { + db *sql.DB +} + +func NewScheduledTestResultRepository(db *sql.DB) service.ScheduledTestResultRepository { + return &scheduledTestResultRepository{db: db} +} + +func (r *scheduledTestResultRepository) Create(ctx context.Context, result *service.ScheduledTestResult) (*service.ScheduledTestResult, error) { + row := r.db.QueryRowContext(ctx, ` + INSERT INTO scheduled_test_results (plan_id, status, response_text, error_message, latency_ms, started_at, finished_at, created_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, NOW()) + RETURNING id, plan_id, status, response_text, error_message, latency_ms, started_at, finished_at, created_at + `, result.PlanID, result.Status, result.ResponseText, result.ErrorMessage, result.LatencyMs, result.StartedAt, result.FinishedAt) + + out := 
&service.ScheduledTestResult{} + if err := row.Scan( + &out.ID, &out.PlanID, &out.Status, &out.ResponseText, &out.ErrorMessage, + &out.LatencyMs, &out.StartedAt, &out.FinishedAt, &out.CreatedAt, + ); err != nil { + return nil, err + } + return out, nil +} + +func (r *scheduledTestResultRepository) ListByPlanID(ctx context.Context, planID int64, limit int) ([]*service.ScheduledTestResult, error) { + rows, err := r.db.QueryContext(ctx, ` + SELECT id, plan_id, status, response_text, error_message, latency_ms, started_at, finished_at, created_at + FROM scheduled_test_results + WHERE plan_id = $1 + ORDER BY created_at DESC + LIMIT $2 + `, planID, limit) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + var results []*service.ScheduledTestResult + for rows.Next() { + r := &service.ScheduledTestResult{} + if err := rows.Scan( + &r.ID, &r.PlanID, &r.Status, &r.ResponseText, &r.ErrorMessage, + &r.LatencyMs, &r.StartedAt, &r.FinishedAt, &r.CreatedAt, + ); err != nil { + return nil, err + } + results = append(results, r) + } + return results, rows.Err() +} + +func (r *scheduledTestResultRepository) PruneOldResults(ctx context.Context, planID int64, keepCount int) error { + _, err := r.db.ExecContext(ctx, ` + DELETE FROM scheduled_test_results + WHERE id IN ( + SELECT id FROM ( + SELECT id, ROW_NUMBER() OVER (PARTITION BY plan_id ORDER BY created_at DESC) AS rn + FROM scheduled_test_results + WHERE plan_id = $1 + ) ranked + WHERE rn > $2 + ) + `, planID, keepCount) + return err +} + +// --- scan helpers --- + +type scannable interface { + Scan(dest ...any) error +} + +func scanPlan(row scannable) (*service.ScheduledTestPlan, error) { + p := &service.ScheduledTestPlan{} + if err := row.Scan( + &p.ID, &p.AccountID, &p.ModelID, &p.CronExpression, &p.Enabled, &p.MaxResults, &p.AutoRecover, + &p.LastRunAt, &p.NextRunAt, &p.CreatedAt, &p.UpdatedAt, + ); err != nil { + return nil, err + } + return p, nil +} + +func scanPlans(rows *sql.Rows) 
([]*service.ScheduledTestPlan, error) { + var plans []*service.ScheduledTestPlan + for rows.Next() { + p, err := scanPlan(rows) + if err != nil { + return nil, err + } + plans = append(plans, p) + } + return plans, rows.Err() +} diff --git a/backend/internal/repository/scheduler_outbox_repo.go b/backend/internal/repository/scheduler_outbox_repo.go index d7bc97da..4b9a9f58 100644 --- a/backend/internal/repository/scheduler_outbox_repo.go +++ b/backend/internal/repository/scheduler_outbox_repo.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "encoding/json" + "time" "github.com/Wei-Shaw/sub2api/internal/service" ) @@ -12,6 +13,8 @@ type schedulerOutboxRepository struct { db *sql.DB } +const schedulerOutboxDedupWindow = time.Second + func NewSchedulerOutboxRepository(db *sql.DB) service.SchedulerOutboxRepository { return &schedulerOutboxRepository{db: db} } @@ -88,9 +91,37 @@ func enqueueSchedulerOutbox(ctx context.Context, exec sqlExecutor, eventType str } payloadArg = encoded } - _, err := exec.ExecContext(ctx, ` + query := ` INSERT INTO scheduler_outbox (event_type, account_id, group_id, payload) VALUES ($1, $2, $3, $4) - `, eventType, accountID, groupID, payloadArg) + ` + args := []any{eventType, accountID, groupID, payloadArg} + if schedulerOutboxEventSupportsDedup(eventType) { + query = ` + INSERT INTO scheduler_outbox (event_type, account_id, group_id, payload) + SELECT $1, $2, $3, $4 + WHERE NOT EXISTS ( + SELECT 1 + FROM scheduler_outbox + WHERE event_type = $1 + AND account_id IS NOT DISTINCT FROM $2 + AND group_id IS NOT DISTINCT FROM $3 + AND created_at >= NOW() - make_interval(secs => $5) + ) + ` + args = append(args, schedulerOutboxDedupWindow.Seconds()) + } + _, err := exec.ExecContext(ctx, query, args...) 
return err } + +func schedulerOutboxEventSupportsDedup(eventType string) bool { + switch eventType { + case service.SchedulerOutboxEventAccountChanged, + service.SchedulerOutboxEventGroupChanged, + service.SchedulerOutboxEventFullRebuild: + return true + default: + return false + } +} diff --git a/backend/internal/repository/simple_mode_admin_concurrency.go b/backend/internal/repository/simple_mode_admin_concurrency.go new file mode 100644 index 00000000..4d1db150 --- /dev/null +++ b/backend/internal/repository/simple_mode_admin_concurrency.go @@ -0,0 +1,55 @@ +package repository + +import ( + "context" + "fmt" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/setting" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +const ( + simpleModeAdminConcurrencyUpgradeKey = "simple_mode_admin_concurrency_upgraded_30" + simpleModeLegacyAdminConcurrency = 5 + simpleModeTargetAdminConcurrency = 30 +) + +func ensureSimpleModeAdminConcurrency(ctx context.Context, client *dbent.Client) error { + if client == nil { + return fmt.Errorf("nil ent client") + } + + upgraded, err := client.Setting.Query().Where(setting.KeyEQ(simpleModeAdminConcurrencyUpgradeKey)).Exist(ctx) + if err != nil { + return fmt.Errorf("check admin concurrency upgrade marker: %w", err) + } + if upgraded { + return nil + } + + if _, err := client.User.Update(). + Where( + dbuser.RoleEQ(service.RoleAdmin), + dbuser.ConcurrencyEQ(simpleModeLegacyAdminConcurrency), + ). + SetConcurrency(simpleModeTargetAdminConcurrency). + Save(ctx); err != nil { + return fmt.Errorf("upgrade simple mode admin concurrency: %w", err) + } + + now := time.Now() + if err := client.Setting.Create(). + SetKey(simpleModeAdminConcurrencyUpgradeKey). + SetValue(now.Format(time.RFC3339)). + SetUpdatedAt(now). + OnConflictColumns(setting.FieldKey). + UpdateNewValues(). 
+ Exec(ctx); err != nil { + return fmt.Errorf("persist admin concurrency upgrade marker: %w", err) + } + + return nil +} diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index ff40e97d..b9207f34 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -22,7 +22,7 @@ import ( "github.com/lib/pq" ) -const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at" +const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, cache_ttl_overridden, created_at" // dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL var dateFormatWhitelist = map[string]string{ @@ -135,6 +135,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) image_count, image_size, media_type, + service_tier, reasoning_effort, cache_ttl_overridden, created_at @@ -144,7 +145,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) $8, $9, $10, 
$11, $12, $13, $14, $15, $16, $17, $18, $19, - $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35 + $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36 ) ON CONFLICT (request_id, api_key_id) DO NOTHING RETURNING id, created_at @@ -158,6 +159,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ipAddress := nullString(log.IPAddress) imageSize := nullString(log.ImageSize) mediaType := nullString(log.MediaType) + serviceTier := nullString(log.ServiceTier) reasoningEffort := nullString(log.ReasoningEffort) var requestIDArg any @@ -198,6 +200,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) log.ImageCount, imageSize, mediaType, + serviceTier, reasoningEffort, log.CacheTTLOverridden, createdAt, @@ -1363,7 +1366,8 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user COUNT(*) as requests, COALESCE(SUM(input_tokens), 0) as input_tokens, COALESCE(SUM(output_tokens), 0) as output_tokens, - COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens, + COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens, + COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens, COALESCE(SUM(total_cost), 0) as cost, COALESCE(SUM(actual_cost), 0) as actual_cost @@ -1401,6 +1405,8 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64 COUNT(*) as requests, COALESCE(SUM(input_tokens), 0) as input_tokens, COALESCE(SUM(output_tokens), 0) as output_tokens, + COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens, + COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens, COALESCE(SUM(total_cost), 0) as cost, COALESCE(SUM(actual_cost), 0) as actual_cost @@ -1473,7 
+1479,16 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat } whereClause := buildWhere(conditions) - logs, page, err := r.listUsageLogsWithPagination(ctx, whereClause, args, params) + var ( + logs []service.UsageLog + page *pagination.PaginationResult + err error + ) + if shouldUseFastUsageLogTotal(filters) { + logs, page, err = r.listUsageLogsWithFastPagination(ctx, whereClause, args, params) + } else { + logs, page, err = r.listUsageLogsWithPagination(ctx, whereClause, args, params) + } if err != nil { return nil, nil, err } @@ -1484,17 +1499,45 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat return logs, page, nil } +func shouldUseFastUsageLogTotal(filters UsageLogFilters) bool { + if filters.ExactTotal { + return false + } + // 强选择过滤下记录集通常较小,保留精确总数。 + return filters.UserID == 0 && filters.APIKeyID == 0 && filters.AccountID == 0 +} + // UsageStats represents usage statistics type UsageStats = usagestats.UsageStats // BatchUserUsageStats represents usage stats for a single user type BatchUserUsageStats = usagestats.BatchUserUsageStats +func normalizePositiveInt64IDs(ids []int64) []int64 { + if len(ids) == 0 { + return nil + } + seen := make(map[int64]struct{}, len(ids)) + out := make([]int64, 0, len(ids)) + for _, id := range ids { + if id <= 0 { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + out = append(out, id) + } + return out +} + // GetBatchUserUsageStats gets today and total actual_cost for multiple users within a time range. // If startTime is zero, defaults to 30 days ago. 
func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*BatchUserUsageStats, error) { result := make(map[int64]*BatchUserUsageStats) - if len(userIDs) == 0 { + normalizedUserIDs := normalizePositiveInt64IDs(userIDs) + if len(normalizedUserIDs) == 0 { return result, nil } @@ -1506,58 +1549,36 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs endTime = time.Now() } - for _, id := range userIDs { + for _, id := range normalizedUserIDs { result[id] = &BatchUserUsageStats{UserID: id} } query := ` - SELECT user_id, COALESCE(SUM(actual_cost), 0) as total_cost + SELECT + user_id, + COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $2 AND created_at < $3), 0) as total_cost, + COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $4), 0) as today_cost FROM usage_logs - WHERE user_id = ANY($1) AND created_at >= $2 AND created_at < $3 + WHERE user_id = ANY($1) + AND created_at >= LEAST($2, $4) GROUP BY user_id ` - rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs), startTime, endTime) + today := timezone.Today() + rows, err := r.sql.QueryContext(ctx, query, pq.Array(normalizedUserIDs), startTime, endTime, today) if err != nil { return nil, err } for rows.Next() { var userID int64 var total float64 - if err := rows.Scan(&userID, &total); err != nil { + var todayTotal float64 + if err := rows.Scan(&userID, &total, &todayTotal); err != nil { _ = rows.Close() return nil, err } if stats, ok := result[userID]; ok { stats.TotalActualCost = total - } - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - - today := timezone.Today() - todayQuery := ` - SELECT user_id, COALESCE(SUM(actual_cost), 0) as today_cost - FROM usage_logs - WHERE user_id = ANY($1) AND created_at >= $2 - GROUP BY user_id - ` - rows, err = r.sql.QueryContext(ctx, todayQuery, pq.Array(userIDs), today) - if err != nil { - 
return nil, err - } - for rows.Next() { - var userID int64 - var total float64 - if err := rows.Scan(&userID, &total); err != nil { - _ = rows.Close() - return nil, err - } - if stats, ok := result[userID]; ok { - stats.TodayActualCost = total + stats.TodayActualCost = todayTotal } } if err := rows.Close(); err != nil { @@ -1577,7 +1598,8 @@ type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats // If startTime is zero, defaults to 30 days ago. func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*BatchAPIKeyUsageStats, error) { result := make(map[int64]*BatchAPIKeyUsageStats) - if len(apiKeyIDs) == 0 { + normalizedAPIKeyIDs := normalizePositiveInt64IDs(apiKeyIDs) + if len(normalizedAPIKeyIDs) == 0 { return result, nil } @@ -1589,58 +1611,36 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe endTime = time.Now() } - for _, id := range apiKeyIDs { + for _, id := range normalizedAPIKeyIDs { result[id] = &BatchAPIKeyUsageStats{APIKeyID: id} } query := ` - SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost + SELECT + api_key_id, + COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $2 AND created_at < $3), 0) as total_cost, + COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $4), 0) as today_cost FROM usage_logs - WHERE api_key_id = ANY($1) AND created_at >= $2 AND created_at < $3 + WHERE api_key_id = ANY($1) + AND created_at >= LEAST($2, $4) GROUP BY api_key_id ` - rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs), startTime, endTime) + today := timezone.Today() + rows, err := r.sql.QueryContext(ctx, query, pq.Array(normalizedAPIKeyIDs), startTime, endTime, today) if err != nil { return nil, err } for rows.Next() { var apiKeyID int64 var total float64 - if err := rows.Scan(&apiKeyID, &total); err != nil { + var todayTotal float64 + if err := rows.Scan(&apiKeyID, &total, &todayTotal); err != nil { _ = rows.Close() 
return nil, err } if stats, ok := result[apiKeyID]; ok { stats.TotalActualCost = total - } - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - - today := timezone.Today() - todayQuery := ` - SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as today_cost - FROM usage_logs - WHERE api_key_id = ANY($1) AND created_at >= $2 - GROUP BY api_key_id - ` - rows, err = r.sql.QueryContext(ctx, todayQuery, pq.Array(apiKeyIDs), today) - if err != nil { - return nil, err - } - for rows.Next() { - var apiKeyID int64 - var total float64 - if err := rows.Scan(&apiKeyID, &total); err != nil { - _ = rows.Close() - return nil, err - } - if stats, ok := result[apiKeyID]; ok { - stats.TodayActualCost = total + stats.TodayActualCost = todayTotal } } if err := rows.Close(); err != nil { @@ -1670,7 +1670,8 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start COUNT(*) as requests, COALESCE(SUM(input_tokens), 0) as input_tokens, COALESCE(SUM(output_tokens), 0) as output_tokens, - COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens, + COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens, + COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens, COALESCE(SUM(total_cost), 0) as cost, COALESCE(SUM(actual_cost), 0) as actual_cost @@ -1753,7 +1754,8 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st total_requests as requests, input_tokens, output_tokens, - (cache_creation_tokens + cache_read_tokens) as cache_tokens, + cache_creation_tokens, + cache_read_tokens, (input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) as total_tokens, total_cost as cost, actual_cost @@ -1768,7 +1770,8 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st total_requests as requests, 
input_tokens, output_tokens, - (cache_creation_tokens + cache_read_tokens) as cache_tokens, + cache_creation_tokens, + cache_read_tokens, (input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) as total_tokens, total_cost as cost, actual_cost @@ -1812,6 +1815,8 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start COUNT(*) as requests, COALESCE(SUM(input_tokens), 0) as input_tokens, COALESCE(SUM(output_tokens), 0) as output_tokens, + COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens, + COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens, COALESCE(SUM(total_cost), 0) as cost, %s @@ -1868,7 +1873,7 @@ func (r *usageLogRepository) GetGroupStatsWithFilters(ctx context.Context, start query := ` SELECT COALESCE(ul.group_id, 0) as group_id, - COALESCE(g.name, '') as group_name, + COALESCE(g.name, '(无分组)') as group_name, COUNT(*) as requests, COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens, COALESCE(SUM(ul.total_cost), 0) as cost, @@ -2245,6 +2250,35 @@ func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, wh return logs, paginationResultFromTotal(total, params), nil } +func (r *usageLogRepository) listUsageLogsWithFastPagination(ctx context.Context, whereClause string, args []any, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { + limit := params.Limit() + offset := params.Offset() + + limitPos := len(args) + 1 + offsetPos := len(args) + 2 + listArgs := append(append([]any{}, args...), limit+1, offset) + query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY id DESC LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, limitPos, offsetPos) + + logs, err := r.queryUsageLogs(ctx, query, listArgs...) 
+ if err != nil { + return nil, nil, err + } + + hasMore := false + if len(logs) > limit { + hasMore = true + logs = logs[:limit] + } + + total := int64(offset) + int64(len(logs)) + if hasMore { + // 只保证“还有下一页”,避免对超大表做全量 COUNT(*)。 + total = int64(offset) + int64(limit) + 1 + } + + return logs, paginationResultFromTotal(total, params), nil +} + func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) (logs []service.UsageLog, err error) { rows, err := r.sql.QueryContext(ctx, query, args...) if err != nil { @@ -2474,6 +2508,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e imageCount int imageSize sql.NullString mediaType sql.NullString + serviceTier sql.NullString reasoningEffort sql.NullString cacheTTLOverridden bool createdAt time.Time @@ -2513,6 +2548,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e &imageCount, &imageSize, &mediaType, + &serviceTier, &reasoningEffort, &cacheTTLOverridden, &createdAt, @@ -2583,6 +2619,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e if mediaType.Valid { log.MediaType = &mediaType.String } + if serviceTier.Valid { + log.ServiceTier = &serviceTier.String + } if reasoningEffort.Valid { log.ReasoningEffort = &reasoningEffort.String } @@ -2599,7 +2638,8 @@ func scanTrendRows(rows *sql.Rows) ([]TrendDataPoint, error) { &row.Requests, &row.InputTokens, &row.OutputTokens, - &row.CacheTokens, + &row.CacheCreationTokens, + &row.CacheReadTokens, &row.TotalTokens, &row.Cost, &row.ActualCost, @@ -2623,6 +2663,8 @@ func scanModelStatsRows(rows *sql.Rows) ([]ModelStat, error) { &row.Requests, &row.InputTokens, &row.OutputTokens, + &row.CacheCreationTokens, + &row.CacheReadTokens, &row.TotalTokens, &row.Cost, &row.ActualCost, diff --git a/backend/internal/repository/usage_log_repo_request_type_test.go b/backend/internal/repository/usage_log_repo_request_type_test.go index 95cf2a2d..7d82b4d0 100644 --- 
a/backend/internal/repository/usage_log_repo_request_type_test.go +++ b/backend/internal/repository/usage_log_repo_request_type_test.go @@ -71,6 +71,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { log.ImageCount, sqlmock.AnyArg(), // image_size sqlmock.AnyArg(), // media_type + sqlmock.AnyArg(), // service_tier sqlmock.AnyArg(), // reasoning_effort log.CacheTTLOverridden, createdAt, @@ -81,12 +82,76 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { require.NoError(t, err) require.True(t, inserted) require.Equal(t, int64(99), log.ID) + require.Nil(t, log.ServiceTier) require.Equal(t, service.RequestTypeWSV2, log.RequestType) require.True(t, log.Stream) require.True(t, log.OpenAIWSMode) require.NoError(t, mock.ExpectationsWereMet()) } +func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageLogRepository{sql: db} + + createdAt := time.Date(2025, 1, 2, 12, 0, 0, 0, time.UTC) + serviceTier := "priority" + log := &service.UsageLog{ + UserID: 1, + APIKeyID: 2, + AccountID: 3, + RequestID: "req-service-tier", + Model: "gpt-5.4", + ServiceTier: &serviceTier, + CreatedAt: createdAt, + } + + mock.ExpectQuery("INSERT INTO usage_logs"). 
+ WithArgs( + log.UserID, + log.APIKeyID, + log.AccountID, + log.RequestID, + log.Model, + sqlmock.AnyArg(), + sqlmock.AnyArg(), + log.InputTokens, + log.OutputTokens, + log.CacheCreationTokens, + log.CacheReadTokens, + log.CacheCreation5mTokens, + log.CacheCreation1hTokens, + log.InputCost, + log.OutputCost, + log.CacheCreationCost, + log.CacheReadCost, + log.TotalCost, + log.ActualCost, + log.RateMultiplier, + log.AccountRateMultiplier, + log.BillingType, + int16(service.RequestTypeSync), + false, + false, + sqlmock.AnyArg(), + sqlmock.AnyArg(), + sqlmock.AnyArg(), + sqlmock.AnyArg(), + log.ImageCount, + sqlmock.AnyArg(), + sqlmock.AnyArg(), + serviceTier, + sqlmock.AnyArg(), + log.CacheTTLOverridden, + createdAt, + ). + WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt)) + + inserted, err := repo.Create(context.Background(), log) + require.NoError(t, err) + require.True(t, inserted) + require.NoError(t, mock.ExpectationsWereMet()) +} + func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) { db, mock := newSQLMock(t) repo := &usageLogRepository{sql: db} @@ -96,6 +161,7 @@ func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) { filters := usagestats.UsageLogFilters{ RequestType: &requestType, Stream: &stream, + ExactTotal: true, } mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_logs WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)"). @@ -124,7 +190,7 @@ func TestUsageLogRepositoryGetUsageTrendWithFiltersRequestTypePriority(t *testin mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE\\)\\)"). WithArgs(start, end, requestType). 
- WillReturnRows(sqlmock.NewRows([]string{"date", "requests", "input_tokens", "output_tokens", "cache_tokens", "total_tokens", "cost", "actual_cost"})) + WillReturnRows(sqlmock.NewRows([]string{"date", "requests", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "cost", "actual_cost"})) trend, err := repo.GetUsageTrendWithFilters(context.Background(), start, end, "day", 0, 0, 0, 0, "", &requestType, &stream, nil) require.NoError(t, err) @@ -143,7 +209,7 @@ func TestUsageLogRepositoryGetModelStatsWithFiltersRequestTypePriority(t *testin mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)"). WithArgs(start, end, requestType). - WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "total_tokens", "cost", "actual_cost"})) + WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "cost", "actual_cost"})) stats, err := repo.GetModelStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, &requestType, &stream, nil) require.NoError(t, err) @@ -279,11 +345,14 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { 0, sql.NullString{}, sql.NullString{}, + sql.NullString{Valid: true, String: "priority"}, sql.NullString{}, false, now, }}) require.NoError(t, err) + require.NotNil(t, log.ServiceTier) + require.Equal(t, "priority", *log.ServiceTier) require.Equal(t, service.RequestTypeWSV2, log.RequestType) require.True(t, log.Stream) require.True(t, log.OpenAIWSMode) @@ -315,13 +384,53 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { 0, sql.NullString{}, sql.NullString{}, + sql.NullString{Valid: true, String: "flex"}, sql.NullString{}, false, now, }}) require.NoError(t, err) + require.NotNil(t, log.ServiceTier) + require.Equal(t, "flex", *log.ServiceTier) require.Equal(t, service.RequestTypeStream, 
log.RequestType) require.True(t, log.Stream) require.False(t, log.OpenAIWSMode) }) + + t.Run("service_tier_is_scanned", func(t *testing.T) { + now := time.Now().UTC() + log, err := scanUsageLog(usageLogScannerStub{values: []any{ + int64(3), + int64(12), + int64(22), + int64(32), + sql.NullString{Valid: true, String: "req-3"}, + "gpt-5.4", + sql.NullInt64{}, + sql.NullInt64{}, + 1, 2, 3, 4, 5, 6, + 0.1, 0.2, 0.3, 0.4, 1.0, 0.9, + 1.0, + sql.NullFloat64{}, + int16(service.BillingTypeBalance), + int16(service.RequestTypeSync), + false, + false, + sql.NullInt64{}, + sql.NullInt64{}, + sql.NullString{}, + sql.NullString{}, + 0, + sql.NullString{}, + sql.NullString{}, + sql.NullString{Valid: true, String: "priority"}, + sql.NullString{}, + false, + now, + }}) + require.NoError(t, err) + require.NotNil(t, log.ServiceTier) + require.Equal(t, "priority", *log.ServiceTier) + }) + } diff --git a/backend/internal/repository/user_group_rate_repo.go b/backend/internal/repository/user_group_rate_repo.go index e3b11096..e794635d 100644 --- a/backend/internal/repository/user_group_rate_repo.go +++ b/backend/internal/repository/user_group_rate_repo.go @@ -95,6 +95,35 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in return result, nil } +// GetByGroupID 获取指定分组下所有用户的专属倍率 +func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int64) ([]service.UserGroupRateEntry, error) { + query := ` + SELECT ugr.user_id, u.email, ugr.rate_multiplier + FROM user_group_rate_multipliers ugr + JOIN users u ON u.id = ugr.user_id AND u.deleted_at IS NULL + WHERE ugr.group_id = $1 + ORDER BY ugr.user_id + ` + rows, err := r.sql.QueryContext(ctx, query, groupID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + var result []service.UserGroupRateEntry + for rows.Next() { + var entry service.UserGroupRateEntry + if err := rows.Scan(&entry.UserID, &entry.UserEmail, &entry.RateMultiplier); err != nil { + return nil, err + } 
+ result = append(result, entry) + } + if err := rows.Err(); err != nil { + return nil, err + } + return result, nil +} + // GetByUserAndGroup 获取用户在特定分组的专属倍率 func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) { query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2` diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index 05b68968..b56aaaf9 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -243,21 +243,24 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination. userMap[u.ID] = &outUsers[len(outUsers)-1] } - // Batch load active subscriptions with groups to avoid N+1. - subs, err := r.client.UserSubscription.Query(). - Where( - usersubscription.UserIDIn(userIDs...), - usersubscription.StatusEQ(service.SubscriptionStatusActive), - ). - WithGroup(). - All(ctx) - if err != nil { - return nil, nil, err - } + shouldLoadSubscriptions := filters.IncludeSubscriptions == nil || *filters.IncludeSubscriptions + if shouldLoadSubscriptions { + // Batch load active subscriptions with groups to avoid N+1. + subs, err := r.client.UserSubscription.Query(). + Where( + usersubscription.UserIDIn(userIDs...), + usersubscription.StatusEQ(service.SubscriptionStatusActive), + ). + WithGroup(). 
+ All(ctx) + if err != nil { + return nil, nil, err + } - for i := range subs { - if u, ok := userMap[subs[i].UserID]; ok { - u.Subscriptions = append(u.Subscriptions, *userSubscriptionEntityToService(subs[i])) + for i := range subs { + if u, ok := userMap[subs[i].UserID]; ok { + u.Subscriptions = append(u.Subscriptions, *userSubscriptionEntityToService(subs[i])) + } } } diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 2e35e0a0..5fe7a98e 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -53,7 +53,9 @@ var ProviderSet = wire.NewSet( NewAPIKeyRepository, NewGroupRepository, NewAccountRepository, - NewSoraAccountRepository, // Sora 账号扩展表仓储 + NewSoraAccountRepository, // Sora 账号扩展表仓储 + NewScheduledTestPlanRepository, // 定时测试计划仓储 + NewScheduledTestResultRepository, // 定时测试结果仓储 NewProxyRepository, NewRedeemCodeRepository, NewPromoCodeRepository, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index c7eb646c..0b36bf66 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -210,8 +210,10 @@ func TestAPIContracts(t *testing.T) { "sora_video_price_per_request": null, "sora_video_price_per_request_hd": null, "claude_code_only": false, + "allow_messages_dispatch": false, "fallback_group_id": null, "fallback_group_id_on_invalid_request": null, + "allow_messages_dispatch": false, "created_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z" } @@ -446,9 +448,10 @@ func TestAPIContracts(t *testing.T) { setup: func(t *testing.T, deps *contractDeps) { t.Helper() deps.settingRepo.SetAll(map[string]string{ - service.SettingKeyRegistrationEnabled: "true", - service.SettingKeyEmailVerifyEnabled: "false", - service.SettingKeyPromoCodeEnabled: "true", + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyEmailVerifyEnabled: "false", + 
service.SettingKeyRegistrationEmailSuffixWhitelist: "[]", + service.SettingKeyPromoCodeEnabled: "true", service.SettingKeySMTPHost: "smtp.example.com", service.SettingKeySMTPPort: "587", @@ -487,6 +490,7 @@ func TestAPIContracts(t *testing.T) { "data": { "registration_enabled": true, "email_verify_enabled": false, + "registration_email_suffix_whitelist": [], "promo_code_enabled": true, "password_reset_enabled": false, "totp_enabled": false, @@ -641,7 +645,7 @@ func newContractDeps(t *testing.T) *contractDeps { settingRepo := newStubSettingRepo() settingService := service.NewSettingService(settingRepo, cfg) - adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil) + adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) @@ -1094,6 +1098,14 @@ func (s *stubAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map return errors.New("not implemented") } +func (s *stubAccountRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) ResetQuotaUsed(ctx context.Context, id int64) error { + return errors.New("not implemented") +} + func (s *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) { s.bulkUpdateIDs = append([]int64{}, ids...) 
return int64(len(ids)), nil diff --git a/backend/internal/server/middleware/admin_auth_test.go b/backend/internal/server/middleware/admin_auth_test.go index 033a5b77..138663c4 100644 --- a/backend/internal/server/middleware/admin_auth_test.go +++ b/backend/internal/server/middleware/admin_auth_test.go @@ -19,7 +19,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}} - authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil) + authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil) admin := &service.User{ ID: 1, diff --git a/backend/internal/server/middleware/jwt_auth_test.go b/backend/internal/server/middleware/jwt_auth_test.go index f8839cfe..ad9c1b5b 100644 --- a/backend/internal/server/middleware/jwt_auth_test.go +++ b/backend/internal/server/middleware/jwt_auth_test.go @@ -40,7 +40,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer cfg.JWT.AccessTokenExpireMinutes = 60 userRepo := &stubJWTUserRepo{users: users} - authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) + authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) userSvc := service.NewUserService(userRepo, nil, nil) mw := NewJWTAuthMiddleware(authSvc, userSvc) diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index c36c36a0..dc5e8269 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -78,6 +78,9 @@ func RegisterAdminRoutes( // API Key 管理 registerAdminAPIKeyRoutes(admin, h) + + // 定时测试计划 + registerScheduledTestRoutes(admin, h) } } @@ -168,6 +171,7 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ops.GET("/system-logs/health", h.Admin.Ops.GetSystemLogIngestionHealth) // Dashboard (vNext - raw 
path for MVP) + ops.GET("/dashboard/snapshot-v2", h.Admin.Ops.GetDashboardSnapshotV2) ops.GET("/dashboard/overview", h.Admin.Ops.GetDashboardOverview) ops.GET("/dashboard/throughput-trend", h.Admin.Ops.GetDashboardThroughputTrend) ops.GET("/dashboard/latency-histogram", h.Admin.Ops.GetDashboardLatencyHistogram) @@ -180,6 +184,7 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) { dashboard := admin.Group("/dashboard") { + dashboard.GET("/snapshot-v2", h.Admin.Dashboard.GetSnapshotV2) dashboard.GET("/stats", h.Admin.Dashboard.GetStats) dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics) dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend) @@ -223,6 +228,7 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) { groups.PUT("/:id", h.Admin.Group.Update) groups.DELETE("/:id", h.Admin.Group.Delete) groups.GET("/:id/stats", h.Admin.Group.GetStats) + groups.GET("/:id/rate-multipliers", h.Admin.Group.GetGroupRateMultipliers) groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys) } } @@ -239,6 +245,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { accounts.PUT("/:id", h.Admin.Account.Update) accounts.DELETE("/:id", h.Admin.Account.Delete) accounts.POST("/:id/test", h.Admin.Account.Test) + accounts.POST("/:id/recover-state", h.Admin.Account.RecoverState) accounts.POST("/:id/refresh", h.Admin.Account.Refresh) accounts.POST("/:id/refresh-tier", h.Admin.Account.RefreshTier) accounts.GET("/:id/stats", h.Admin.Account.GetStats) @@ -247,6 +254,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats) accounts.POST("/today-stats/batch", h.Admin.Account.GetBatchTodayStats) accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit) + accounts.POST("/:id/reset-quota", h.Admin.Account.ResetQuota) 
accounts.GET("/:id/temp-unschedulable", h.Admin.Account.GetTempUnschedulable) accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable) accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable) @@ -257,6 +265,8 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials) accounts.POST("/batch-refresh-tier", h.Admin.Account.BatchRefreshTier) accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate) + accounts.POST("/batch-clear-error", h.Admin.Account.BatchClearError) + accounts.POST("/batch-refresh", h.Admin.Account.BatchRefresh) // Antigravity 默认模型映射 accounts.GET("/antigravity/default-model-mapping", h.Admin.Account.GetAntigravityDefaultModelMapping) @@ -386,6 +396,12 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { // 流超时处理配置 adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings) adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings) + // 请求整流器配置 + adminSettings.GET("/rectifier", h.Admin.Setting.GetRectifierSettings) + adminSettings.PUT("/rectifier", h.Admin.Setting.UpdateRectifierSettings) + // Beta 策略配置 + adminSettings.GET("/beta-policy", h.Admin.Setting.GetBetaPolicySettings) + adminSettings.PUT("/beta-policy", h.Admin.Setting.UpdateBetaPolicySettings) // Sora S3 存储配置 adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings) adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings) @@ -441,6 +457,7 @@ func registerSubscriptionRoutes(admin *gin.RouterGroup, h *handler.Handlers) { subscriptions.POST("/assign", h.Admin.Subscription.Assign) subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign) subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend) + subscriptions.POST("/:id/reset-quota", h.Admin.Subscription.ResetQuota) subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke) } @@ -476,6 +493,18 @@ func 
registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) { } } +func registerScheduledTestRoutes(admin *gin.RouterGroup, h *handler.Handlers) { + plans := admin.Group("/scheduled-test-plans") + { + plans.POST("", h.Admin.ScheduledTest.Create) + plans.PUT("/:id", h.Admin.ScheduledTest.Update) + plans.DELETE("/:id", h.Admin.ScheduledTest.Delete) + plans.GET("/:id/results", h.Admin.ScheduledTest.ListResults) + } + // Nested under accounts + admin.GET("/accounts/:id/scheduled-test-plans", h.Admin.ScheduledTest.ListByAccount) +} + func registerErrorPassthroughRoutes(admin *gin.RouterGroup, h *handler.Handlers) { rules := admin.Group("/error-passthrough-rules") { diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index c168820c..0efc9560 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -61,6 +61,12 @@ func RegisterAuthRoutes( }), h.Auth.ResetPassword) auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart) auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback) + auth.POST("/oauth/linuxdo/complete-registration", + rateLimiter.LimitWithOptions("oauth-linuxdo-complete", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.CompleteLinuxDoOAuthRegistration, + ) } // 公开设置(无需认证) diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go index 13f13320..ea40f2f1 100644 --- a/backend/internal/server/routes/gateway.go +++ b/backend/internal/server/routes/gateway.go @@ -43,22 +43,36 @@ func RegisterGatewayRoutes( gateway.Use(gin.HandlerFunc(apiKeyAuth)) gateway.Use(requireGroupAnthropic) { - gateway.POST("/messages", h.Gateway.Messages) - gateway.POST("/messages/count_tokens", h.Gateway.CountTokens) + // /v1/messages: auto-route based on group platform + gateway.POST("/messages", func(c *gin.Context) { + if getGroupPlatform(c) == service.PlatformOpenAI { + 
h.OpenAIGateway.Messages(c) + return + } + h.Gateway.Messages(c) + }) + // /v1/messages/count_tokens: OpenAI groups get 404 + gateway.POST("/messages/count_tokens", func(c *gin.Context) { + if getGroupPlatform(c) == service.PlatformOpenAI { + c.JSON(http.StatusNotFound, gin.H{ + "type": "error", + "error": gin.H{ + "type": "not_found_error", + "message": "Token counting is not supported for this platform", + }, + }) + return + } + h.Gateway.CountTokens(c) + }) gateway.GET("/models", h.Gateway.Models) gateway.GET("/usage", h.Gateway.Usage) // OpenAI Responses API gateway.POST("/responses", h.OpenAIGateway.Responses) + gateway.POST("/responses/*subpath", h.OpenAIGateway.Responses) gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket) - // 明确阻止旧协议入口:OpenAI 仅支持 Responses API,避免客户端误解为会自动路由到其它平台。 - gateway.POST("/chat/completions", func(c *gin.Context) { - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "type": "invalid_request_error", - "message": "Unsupported legacy protocol: /v1/chat/completions is not supported. 
Please use /v1/responses.", - }, - }) - }) + // OpenAI Chat Completions API + gateway.POST("/chat/completions", h.OpenAIGateway.ChatCompletions) } // Gemini 原生 API 兼容层(Gemini SDK/CLI 直连) @@ -77,7 +91,10 @@ func RegisterGatewayRoutes( // OpenAI Responses API(不带v1前缀的别名) r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses) + r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses) r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket) + // OpenAI Chat Completions API(不带v1前缀的别名) + r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ChatCompletions) // Antigravity 模型列表 r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels) @@ -132,3 +149,12 @@ func RegisterGatewayRoutes( // Sora 媒体代理(签名 URL,无需 API Key) r.GET("/sora/media-signed/*filepath", h.SoraGateway.MediaProxySigned) } + +// getGroupPlatform extracts the group platform from the API Key stored in context. 
+func getGroupPlatform(c *gin.Context) string { + apiKey, ok := middleware.GetAPIKeyFromContext(c) + if !ok || apiKey.Group == nil { + return "" + } + return apiKey.Group.Platform +} diff --git a/backend/internal/server/routes/gateway_test.go b/backend/internal/server/routes/gateway_test.go new file mode 100644 index 00000000..00edd31b --- /dev/null +++ b/backend/internal/server/routes/gateway_test.go @@ -0,0 +1,51 @@ +package routes + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/handler" + servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func newGatewayRoutesTestRouter() *gin.Engine { + gin.SetMode(gin.TestMode) + router := gin.New() + + RegisterGatewayRoutes( + router, + &handler.Handlers{ + Gateway: &handler.GatewayHandler{}, + OpenAIGateway: &handler.OpenAIGatewayHandler{}, + SoraGateway: &handler.SoraGatewayHandler{}, + }, + servermiddleware.APIKeyAuthMiddleware(func(c *gin.Context) { + c.Next() + }), + nil, + nil, + nil, + nil, + &config.Config{}, + ) + + return router +} + +func TestGatewayRoutesOpenAIResponsesCompactPathIsRegistered(t *testing.T) { + router := newGatewayRoutesTestRouter() + + for _, path := range []string{"/v1/responses/compact", "/responses/compact"} { + req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{"model":"gpt-5"}`)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + require.NotEqual(t, http.StatusNotFound, w.Code, "path=%s should hit OpenAI responses handler", path) + } +} diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 81e91aeb..9d4f73d4 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -28,6 +28,7 @@ type Account struct { // RateMultiplier 
账号计费倍率(>=0,允许 0 表示该账号计费为 0)。 // 使用指针用于兼容旧版本调度缓存(Redis)中缺字段的情况:nil 表示按 1.0 处理。 RateMultiplier *float64 + LoadFactor *int // 调度负载因子;nil 表示使用 Concurrency Status string ErrorMessage string LastUsedAt *time.Time @@ -88,6 +89,19 @@ func (a *Account) BillingRateMultiplier() float64 { return *a.RateMultiplier } +func (a *Account) EffectiveLoadFactor() int { + if a == nil { + return 1 + } + if a.LoadFactor != nil && *a.LoadFactor > 0 { + return *a.LoadFactor + } + if a.Concurrency > 0 { + return a.Concurrency + } + return 1 +} + func (a *Account) IsSchedulable() bool { if !a.IsActive() || !a.Schedulable { return false @@ -633,6 +647,75 @@ func (a *Account) IsCustomErrorCodesEnabled() bool { return false } +// IsPoolMode 检查 API Key 账号是否启用池模式。 +// 池模式下,上游错误不标记本地账号状态,而是在同一账号上重试。 +func (a *Account) IsPoolMode() bool { + if a.Type != AccountTypeAPIKey || a.Credentials == nil { + return false + } + if v, ok := a.Credentials["pool_mode"]; ok { + if enabled, ok := v.(bool); ok { + return enabled + } + } + return false +} + +const ( + defaultPoolModeRetryCount = 3 + maxPoolModeRetryCount = 10 +) + +// GetPoolModeRetryCount 返回池模式同账号重试次数。 +// 未配置或配置非法时回退为默认值 3;小于 0 按 0 处理;过大则截断到 10。 +func (a *Account) GetPoolModeRetryCount() int { + if a == nil || !a.IsPoolMode() || a.Credentials == nil { + return defaultPoolModeRetryCount + } + raw, ok := a.Credentials["pool_mode_retry_count"] + if !ok || raw == nil { + return defaultPoolModeRetryCount + } + count := parsePoolModeRetryCount(raw) + if count < 0 { + return 0 + } + if count > maxPoolModeRetryCount { + return maxPoolModeRetryCount + } + return count +} + +func parsePoolModeRetryCount(value any) int { + switch v := value.(type) { + case int: + return v + case int64: + return int(v) + case float64: + return int(v) + case json.Number: + if i, err := v.Int64(); err == nil { + return int(i) + } + case string: + if i, err := strconv.Atoi(strings.TrimSpace(v)); err == nil { + return i + } + } + return defaultPoolModeRetryCount +} + +// 
isPoolModeRetryableStatus 池模式下应触发同账号重试的状态码 +func isPoolModeRetryableStatus(statusCode int) bool { + switch statusCode { + case 401, 403, 429: + return true + default: + return false + } +} + func (a *Account) GetCustomErrorCodes() []int { if a.Credentials == nil { return nil @@ -853,15 +936,21 @@ func (a *Account) IsOpenAIResponsesWebSocketV2Enabled() bool { } const ( - OpenAIWSIngressModeOff = "off" - OpenAIWSIngressModeShared = "shared" - OpenAIWSIngressModeDedicated = "dedicated" + OpenAIWSIngressModeOff = "off" + OpenAIWSIngressModeShared = "shared" + OpenAIWSIngressModeDedicated = "dedicated" + OpenAIWSIngressModeCtxPool = "ctx_pool" + OpenAIWSIngressModePassthrough = "passthrough" ) func normalizeOpenAIWSIngressMode(mode string) string { switch strings.ToLower(strings.TrimSpace(mode)) { case OpenAIWSIngressModeOff: return OpenAIWSIngressModeOff + case OpenAIWSIngressModeCtxPool: + return OpenAIWSIngressModeCtxPool + case OpenAIWSIngressModePassthrough: + return OpenAIWSIngressModePassthrough case OpenAIWSIngressModeShared: return OpenAIWSIngressModeShared case OpenAIWSIngressModeDedicated: @@ -873,18 +962,21 @@ func normalizeOpenAIWSIngressMode(mode string) string { func normalizeOpenAIWSIngressDefaultMode(mode string) string { if normalized := normalizeOpenAIWSIngressMode(mode); normalized != "" { + if normalized == OpenAIWSIngressModeShared || normalized == OpenAIWSIngressModeDedicated { + return OpenAIWSIngressModeCtxPool + } return normalized } - return OpenAIWSIngressModeShared + return OpenAIWSIngressModeCtxPool } -// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/shared/dedicated)。 +// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/ctx_pool/passthrough)。 // // 优先级: // 1. 分类型 mode 新字段(string) // 2. 分类型 enabled 旧字段(bool) // 3. 兼容 enabled 旧字段(bool) -// 4. defaultMode(非法时回退 shared) +// 4. 
defaultMode(非法时回退 ctx_pool) func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) string { resolvedDefault := normalizeOpenAIWSIngressDefaultMode(defaultMode) if a == nil || !a.IsOpenAI() { @@ -919,7 +1011,7 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri return "", false } if enabled { - return OpenAIWSIngressModeShared, true + return OpenAIWSIngressModeCtxPool, true } return OpenAIWSIngressModeOff, true } @@ -946,6 +1038,10 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri if mode, ok := resolveBoolMode("openai_ws_enabled"); ok { return mode } + // 兼容旧值:shared/dedicated 语义都归并到 ctx_pool。 + if resolvedDefault == OpenAIWSIngressModeShared || resolvedDefault == OpenAIWSIngressModeDedicated { + return OpenAIWSIngressModeCtxPool + } return resolvedDefault } @@ -1104,6 +1200,102 @@ func (a *Account) GetCacheTTLOverrideTarget() string { return "5m" } +// GetQuotaLimit 获取 API Key 账号的配额限制(美元) +// 返回 0 表示未启用 +func (a *Account) GetQuotaLimit() float64 { + return a.getExtraFloat64("quota_limit") +} + +// GetQuotaUsed 获取 API Key 账号的已用配额(美元) +func (a *Account) GetQuotaUsed() float64 { + return a.getExtraFloat64("quota_used") +} + +// GetQuotaDailyLimit 获取日额度限制(美元),0 表示未启用 +func (a *Account) GetQuotaDailyLimit() float64 { + return a.getExtraFloat64("quota_daily_limit") +} + +// GetQuotaDailyUsed 获取当日已用额度(美元) +func (a *Account) GetQuotaDailyUsed() float64 { + return a.getExtraFloat64("quota_daily_used") +} + +// GetQuotaWeeklyLimit 获取周额度限制(美元),0 表示未启用 +func (a *Account) GetQuotaWeeklyLimit() float64 { + return a.getExtraFloat64("quota_weekly_limit") +} + +// GetQuotaWeeklyUsed 获取本周已用额度(美元) +func (a *Account) GetQuotaWeeklyUsed() float64 { + return a.getExtraFloat64("quota_weekly_used") +} + +// getExtraFloat64 从 Extra 中读取指定 key 的 float64 值 +func (a *Account) getExtraFloat64(key string) float64 { + if a.Extra == nil { + return 0 + } + if v, ok := a.Extra[key]; ok { + return 
parseExtraFloat64(v) + } + return 0 +} + +// getExtraTime 从 Extra 中读取 RFC3339 时间戳 +func (a *Account) getExtraTime(key string) time.Time { + if a.Extra == nil { + return time.Time{} + } + if v, ok := a.Extra[key]; ok { + if s, ok := v.(string); ok { + if t, err := time.Parse(time.RFC3339Nano, s); err == nil { + return t + } + if t, err := time.Parse(time.RFC3339, s); err == nil { + return t + } + } + } + return time.Time{} +} + +// HasAnyQuotaLimit 检查是否配置了任一维度的配额限制 +func (a *Account) HasAnyQuotaLimit() bool { + return a.GetQuotaLimit() > 0 || a.GetQuotaDailyLimit() > 0 || a.GetQuotaWeeklyLimit() > 0 +} + +// isPeriodExpired 检查指定周期(自 periodStart 起经过 dur)是否已过期 +func isPeriodExpired(periodStart time.Time, dur time.Duration) bool { + if periodStart.IsZero() { + return true // 从未使用过,视为过期(下次 increment 会初始化) + } + return time.Since(periodStart) >= dur +} + +// IsQuotaExceeded 检查 API Key 账号配额是否已超限(任一维度超限即返回 true) +func (a *Account) IsQuotaExceeded() bool { + // 总额度 + if limit := a.GetQuotaLimit(); limit > 0 && a.GetQuotaUsed() >= limit { + return true + } + // 日额度(周期过期视为未超限,下次 increment 会重置) + if limit := a.GetQuotaDailyLimit(); limit > 0 { + start := a.getExtraTime("quota_daily_start") + if !isPeriodExpired(start, 24*time.Hour) && a.GetQuotaDailyUsed() >= limit { + return true + } + } + // 周额度 + if limit := a.GetQuotaWeeklyLimit(); limit > 0 { + start := a.getExtraTime("quota_weekly_start") + if !isPeriodExpired(start, 7*24*time.Hour) && a.GetQuotaWeeklyUsed() >= limit { + return true + } + } + return false +} + // GetWindowCostLimit 获取 5h 窗口费用阈值(美元) // 返回 0 表示未启用 func (a *Account) GetWindowCostLimit() float64 { diff --git a/backend/internal/service/account_load_factor_test.go b/backend/internal/service/account_load_factor_test.go new file mode 100644 index 00000000..a4d78a4b --- /dev/null +++ b/backend/internal/service/account_load_factor_test.go @@ -0,0 +1,46 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + 
+func intPtrHelper(v int) *int { return &v } + +func TestEffectiveLoadFactor_NilAccount(t *testing.T) { + var a *Account + require.Equal(t, 1, a.EffectiveLoadFactor()) +} + +func TestEffectiveLoadFactor_NilLoadFactor_PositiveConcurrency(t *testing.T) { + a := &Account{Concurrency: 5} + require.Equal(t, 5, a.EffectiveLoadFactor()) +} + +func TestEffectiveLoadFactor_NilLoadFactor_ZeroConcurrency(t *testing.T) { + a := &Account{Concurrency: 0} + require.Equal(t, 1, a.EffectiveLoadFactor()) +} + +func TestEffectiveLoadFactor_PositiveLoadFactor(t *testing.T) { + a := &Account{Concurrency: 5, LoadFactor: intPtrHelper(20)} + require.Equal(t, 20, a.EffectiveLoadFactor()) +} + +func TestEffectiveLoadFactor_ZeroLoadFactor_FallbackToConcurrency(t *testing.T) { + a := &Account{Concurrency: 5, LoadFactor: intPtrHelper(0)} + require.Equal(t, 5, a.EffectiveLoadFactor()) +} + +func TestEffectiveLoadFactor_NegativeLoadFactor_FallbackToConcurrency(t *testing.T) { + a := &Account{Concurrency: 3, LoadFactor: intPtrHelper(-1)} + require.Equal(t, 3, a.EffectiveLoadFactor()) +} + +func TestEffectiveLoadFactor_ZeroLoadFactor_ZeroConcurrency(t *testing.T) { + a := &Account{Concurrency: 0, LoadFactor: intPtrHelper(0)} + require.Equal(t, 1, a.EffectiveLoadFactor()) +} diff --git a/backend/internal/service/account_openai_passthrough_test.go b/backend/internal/service/account_openai_passthrough_test.go index a85c68ec..50c2b7cb 100644 --- a/backend/internal/service/account_openai_passthrough_test.go +++ b/backend/internal/service/account_openai_passthrough_test.go @@ -206,14 +206,14 @@ func TestAccount_IsOpenAIResponsesWebSocketV2Enabled(t *testing.T) { } func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) { - t.Run("default fallback to shared", func(t *testing.T) { + t.Run("default fallback to ctx_pool", func(t *testing.T) { account := &Account{ Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{}, } - require.Equal(t, OpenAIWSIngressModeShared, 
account.ResolveOpenAIResponsesWebSocketV2Mode("")) - require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid")) + require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode("")) + require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid")) }) t.Run("oauth mode field has highest priority", func(t *testing.T) { @@ -221,15 +221,15 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) { Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{ - "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated, + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough, "openai_oauth_responses_websockets_v2_enabled": false, "responses_websockets_v2_enabled": false, }, } - require.Equal(t, OpenAIWSIngressModeDedicated, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared)) + require.Equal(t, OpenAIWSIngressModePassthrough, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool)) }) - t.Run("legacy enabled maps to shared", func(t *testing.T) { + t.Run("legacy enabled maps to ctx_pool", func(t *testing.T) { account := &Account{ Platform: PlatformOpenAI, Type: AccountTypeAPIKey, @@ -237,7 +237,28 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) { "responses_websockets_v2_enabled": true, }, } - require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff)) + require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff)) + }) + + t.Run("shared/dedicated mode strings are compatible with ctx_pool", func(t *testing.T) { + shared := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared, + }, + } + dedicated := &Account{ + Platform: 
PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated, + }, + } + require.Equal(t, OpenAIWSIngressModeShared, shared.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff)) + require.Equal(t, OpenAIWSIngressModeDedicated, dedicated.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff)) + require.Equal(t, OpenAIWSIngressModeCtxPool, normalizeOpenAIWSIngressDefaultMode(OpenAIWSIngressModeShared)) + require.Equal(t, OpenAIWSIngressModeCtxPool, normalizeOpenAIWSIngressDefaultMode(OpenAIWSIngressModeDedicated)) }) t.Run("legacy disabled maps to off", func(t *testing.T) { @@ -249,7 +270,7 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) { "responses_websockets_v2_enabled": true, }, } - require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared)) + require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool)) }) t.Run("non openai always off", func(t *testing.T) { diff --git a/backend/internal/service/account_pool_mode_test.go b/backend/internal/service/account_pool_mode_test.go new file mode 100644 index 00000000..98429bb1 --- /dev/null +++ b/backend/internal/service/account_pool_mode_test.go @@ -0,0 +1,117 @@ +//go:build unit + +package service + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetPoolModeRetryCount(t *testing.T) { + tests := []struct { + name string + account *Account + expected int + }{ + { + name: "default_when_not_pool_mode", + account: &Account{ + Type: AccountTypeAPIKey, + Platform: PlatformOpenAI, + Credentials: map[string]any{}, + }, + expected: defaultPoolModeRetryCount, + }, + { + name: "default_when_missing_retry_count", + account: &Account{ + Type: AccountTypeAPIKey, + Platform: PlatformOpenAI, + Credentials: map[string]any{ + "pool_mode": true, + }, + }, + 
expected: defaultPoolModeRetryCount, + }, + { + name: "supports_float64_from_json_credentials", + account: &Account{ + Type: AccountTypeAPIKey, + Platform: PlatformOpenAI, + Credentials: map[string]any{ + "pool_mode": true, + "pool_mode_retry_count": float64(5), + }, + }, + expected: 5, + }, + { + name: "supports_json_number", + account: &Account{ + Type: AccountTypeAPIKey, + Platform: PlatformOpenAI, + Credentials: map[string]any{ + "pool_mode": true, + "pool_mode_retry_count": json.Number("4"), + }, + }, + expected: 4, + }, + { + name: "supports_string_value", + account: &Account{ + Type: AccountTypeAPIKey, + Platform: PlatformOpenAI, + Credentials: map[string]any{ + "pool_mode": true, + "pool_mode_retry_count": "2", + }, + }, + expected: 2, + }, + { + name: "negative_value_is_clamped_to_zero", + account: &Account{ + Type: AccountTypeAPIKey, + Platform: PlatformOpenAI, + Credentials: map[string]any{ + "pool_mode": true, + "pool_mode_retry_count": -1, + }, + }, + expected: 0, + }, + { + name: "oversized_value_is_clamped_to_max", + account: &Account{ + Type: AccountTypeAPIKey, + Platform: PlatformOpenAI, + Credentials: map[string]any{ + "pool_mode": true, + "pool_mode_retry_count": 99, + }, + }, + expected: maxPoolModeRetryCount, + }, + { + name: "invalid_value_falls_back_to_default", + account: &Account{ + Type: AccountTypeAPIKey, + Platform: PlatformOpenAI, + Credentials: map[string]any{ + "pool_mode": true, + "pool_mode_retry_count": "oops", + }, + }, + expected: defaultPoolModeRetryCount, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.expected, tt.account.GetPoolModeRetryCount()) + }) + } +} diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index 18a70c5c..a06d8048 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -68,6 +68,10 @@ type AccountRepository interface { UpdateSessionWindow(ctx 
context.Context, id int64, start, end *time.Time, status string) error UpdateExtra(ctx context.Context, id int64, updates map[string]any) error BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error) + // IncrementQuotaUsed 原子递增 API Key 账号的配额用量(总/日/周) + IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error + // ResetQuotaUsed 重置 API Key 账号所有维度的配额用量为 0 + ResetQuotaUsed(ctx context.Context, id int64) error } // AccountBulkUpdate describes the fields that can be updated in a bulk operation. @@ -78,6 +82,7 @@ type AccountBulkUpdate struct { Concurrency *int Priority *int RateMultiplier *float64 + LoadFactor *int Status *string Schedulable *bool Credentials map[string]any diff --git a/backend/internal/service/account_service_delete_test.go b/backend/internal/service/account_service_delete_test.go index 768cf7b7..c96b436f 100644 --- a/backend/internal/service/account_service_delete_test.go +++ b/backend/internal/service/account_service_delete_test.go @@ -199,6 +199,14 @@ func (s *accountRepoStub) BulkUpdate(ctx context.Context, ids []int64, updates A panic("unexpected BulkUpdate call") } +func (s *accountRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error { + return nil +} + +func (s *accountRepoStub) ResetQuotaUsed(ctx context.Context, id int64) error { + return nil +} + // TestAccountService_Delete_NotFound 测试删除不存在的账号时返回正确的错误。 // 预期行为: // - ExistsByID 返回 false(账号不存在) diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index c55e418d..472551cf 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -12,6 +12,7 @@ import ( "io" "log" "net/http" + "net/http/httptest" "net/url" "regexp" "strings" @@ -33,7 +34,7 @@ import ( var sseDataPrefix = regexp.MustCompile(`^data:\s*`) const ( - testClaudeAPIURL = "https://api.anthropic.com/v1/messages" + testClaudeAPIURL = 
"https://api.anthropic.com/v1/messages?beta=true" chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses" soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接 soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions" @@ -44,16 +45,23 @@ const ( // TestEvent represents a SSE event for account testing type TestEvent struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - Model string `json:"model,omitempty"` - Status string `json:"status,omitempty"` - Code string `json:"code,omitempty"` - Data any `json:"data,omitempty"` - Success bool `json:"success,omitempty"` - Error string `json:"error,omitempty"` + Type string `json:"type"` + Text string `json:"text,omitempty"` + Model string `json:"model,omitempty"` + Status string `json:"status,omitempty"` + Code string `json:"code,omitempty"` + ImageURL string `json:"image_url,omitempty"` + MimeType string `json:"mime_type,omitempty"` + Data any `json:"data,omitempty"` + Success bool `json:"success,omitempty"` + Error string `json:"error,omitempty"` } +const ( + defaultGeminiTextTestPrompt = "hi" + defaultGeminiImageTestPrompt = "Generate a cute orange cat astronaut sticker on a clean pastel background." 
+) + // AccountTestService handles account testing operations type AccountTestService struct { accountRepo AccountRepository @@ -160,7 +168,7 @@ func createTestPayload(modelID string) (map[string]any, error) { // TestAccountConnection tests an account's connection by sending a test request // All account types use full Claude Code client characteristics, only auth header differs // modelID is optional - if empty, defaults to claude.DefaultTestModel -func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string) error { +func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string, prompt string) error { ctx := c.Request.Context() // Get account @@ -175,11 +183,11 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int } if account.IsGemini() { - return s.testGeminiAccountConnection(c, account, modelID) + return s.testGeminiAccountConnection(c, account, modelID, prompt) } if account.Platform == PlatformAntigravity { - return s.testAntigravityAccountConnection(c, account, modelID) + return s.routeAntigravityTest(c, account, modelID, prompt) } if account.Platform == PlatformSora { @@ -238,7 +246,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account if err != nil { return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error())) } - apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/messages" + apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/messages?beta=true" } else { return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type)) } @@ -405,8 +413,27 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account } defer func() { _ = resp.Body.Close() }() + if isOAuth && s.accountRepo != nil { + if updates, err := extractOpenAICodexProbeUpdates(resp); err == nil && len(updates) > 0 { + _ = s.accountRepo.UpdateExtra(ctx, account.ID, updates) + 
mergeAccountExtra(account, updates) + } + if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { + if resetAt := codexRateLimitResetAtFromSnapshot(snapshot, time.Now()); resetAt != nil { + _ = s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt) + account.RateLimitResetAt = resetAt + } + } + } + if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) + if isOAuth && s.accountRepo != nil { + if resetAt := (&RateLimitService{}).calculateOpenAI429ResetTime(resp.Header); resetAt != nil { + _ = s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt) + account.RateLimitResetAt = resetAt + } + } return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))) } @@ -415,7 +442,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account } // testGeminiAccountConnection tests a Gemini account's connection -func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string) error { +func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error { ctx := c.Request.Context() // Determine the model to use @@ -442,7 +469,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account c.Writer.Flush() // Create test payload (Gemini format) - payload := createGeminiTestPayload() + payload := createGeminiTestPayload(testModelID, prompt) // Build request based on account type var req *http.Request @@ -1176,6 +1203,18 @@ func truncateSoraErrorBody(body []byte, max int) string { return soraerror.TruncateBody(body, max) } +// routeAntigravityTest 路由 Antigravity 账号的测试请求。 +// APIKey 类型走原生协议(与 gateway_handler 路由一致),OAuth/Upstream 走 CRS 中转。 +func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string, prompt string) error { + if account.Type == AccountTypeAPIKey { + if strings.HasPrefix(modelID, "gemini-") { + return 
s.testGeminiAccountConnection(c, account, modelID, prompt) + } + return s.testClaudeAccountConnection(c, account, modelID) + } + return s.testAntigravityAccountConnection(c, account, modelID) +} + // testAntigravityAccountConnection tests an Antigravity account's connection // 支持 Claude 和 Gemini 两种协议,使用非流式请求 func (s *AccountTestService) testAntigravityAccountConnection(c *gin.Context, account *Account, modelID string) error { @@ -1317,14 +1356,46 @@ func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessT return req, nil } -// createGeminiTestPayload creates a minimal test payload for Gemini API -func createGeminiTestPayload() []byte { +// createGeminiTestPayload creates a minimal test payload for Gemini API. +// Image models use the image-generation path so the frontend can preview the returned image. +func createGeminiTestPayload(modelID string, prompt string) []byte { + if isImageGenerationModel(modelID) { + imagePrompt := strings.TrimSpace(prompt) + if imagePrompt == "" { + imagePrompt = defaultGeminiImageTestPrompt + } + + payload := map[string]any{ + "contents": []map[string]any{ + { + "role": "user", + "parts": []map[string]any{ + {"text": imagePrompt}, + }, + }, + }, + "generationConfig": map[string]any{ + "responseModalities": []string{"TEXT", "IMAGE"}, + "imageConfig": map[string]any{ + "aspectRatio": "1:1", + }, + }, + } + bytes, _ := json.Marshal(payload) + return bytes + } + + textPrompt := strings.TrimSpace(prompt) + if textPrompt == "" { + textPrompt = defaultGeminiTextTestPrompt + } + payload := map[string]any{ "contents": []map[string]any{ { "role": "user", "parts": []map[string]any{ - {"text": "hi"}, + {"text": textPrompt}, }, }, }, @@ -1384,6 +1455,17 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader) if text, ok := partMap["text"].(string); ok && text != "" { s.sendEvent(c, TestEvent{Type: "content", Text: text}) } + if inlineData, ok := partMap["inlineData"].(map[string]any); ok { + 
mimeType, _ := inlineData["mimeType"].(string) + data, _ := inlineData["data"].(string) + if strings.HasPrefix(strings.ToLower(mimeType), "image/") && data != "" { + s.sendEvent(c, TestEvent{ + Type: "image", + ImageURL: fmt.Sprintf("data:%s;base64,%s", mimeType, data), + MimeType: mimeType, + }) + } + } } } } @@ -1560,3 +1642,62 @@ func (s *AccountTestService) sendErrorAndEnd(c *gin.Context, errorMsg string) er s.sendEvent(c, TestEvent{Type: "error", Error: errorMsg}) return fmt.Errorf("%s", errorMsg) } + +// RunTestBackground executes an account test in-memory (no real HTTP client), +// capturing SSE output via httptest.NewRecorder, then parses the result. +func (s *AccountTestService) RunTestBackground(ctx context.Context, accountID int64, modelID string) (*ScheduledTestResult, error) { + startedAt := time.Now() + + w := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(w) + ginCtx.Request = (&http.Request{}).WithContext(ctx) + + testErr := s.TestAccountConnection(ginCtx, accountID, modelID, "") + + finishedAt := time.Now() + body := w.Body.String() + responseText, errMsg := parseTestSSEOutput(body) + + status := "success" + if testErr != nil || errMsg != "" { + status = "failed" + if errMsg == "" && testErr != nil { + errMsg = testErr.Error() + } + } + + return &ScheduledTestResult{ + Status: status, + ResponseText: responseText, + ErrorMessage: errMsg, + LatencyMs: finishedAt.Sub(startedAt).Milliseconds(), + StartedAt: startedAt, + FinishedAt: finishedAt, + }, nil +} + +// parseTestSSEOutput extracts response text and error message from captured SSE output. 
+func parseTestSSEOutput(body string) (responseText, errMsg string) { + var texts []string + for _, line := range strings.Split(body, "\n") { + line = strings.TrimSpace(line) + if !strings.HasPrefix(line, "data: ") { + continue + } + jsonStr := strings.TrimPrefix(line, "data: ") + var event TestEvent + if err := json.Unmarshal([]byte(jsonStr), &event); err != nil { + continue + } + switch event.Type { + case "content": + if event.Text != "" { + texts = append(texts, event.Text) + } + case "error": + errMsg = event.Error + } + } + responseText = strings.Join(texts, "") + return +} diff --git a/backend/internal/service/account_test_service_gemini_test.go b/backend/internal/service/account_test_service_gemini_test.go new file mode 100644 index 00000000..5ba04c69 --- /dev/null +++ b/backend/internal/service/account_test_service_gemini_test.go @@ -0,0 +1,59 @@ +//go:build unit + +package service + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestCreateGeminiTestPayload_ImageModel(t *testing.T) { + t.Parallel() + + payload := createGeminiTestPayload("gemini-2.5-flash-image", "draw a tiny robot") + + var parsed struct { + Contents []struct { + Parts []struct { + Text string `json:"text"` + } `json:"parts"` + } `json:"contents"` + GenerationConfig struct { + ResponseModalities []string `json:"responseModalities"` + ImageConfig struct { + AspectRatio string `json:"aspectRatio"` + } `json:"imageConfig"` + } `json:"generationConfig"` + } + + require.NoError(t, json.Unmarshal(payload, &parsed)) + require.Len(t, parsed.Contents, 1) + require.Len(t, parsed.Contents[0].Parts, 1) + require.Equal(t, "draw a tiny robot", parsed.Contents[0].Parts[0].Text) + require.Equal(t, []string{"TEXT", "IMAGE"}, parsed.GenerationConfig.ResponseModalities) + require.Equal(t, "1:1", parsed.GenerationConfig.ImageConfig.AspectRatio) +} + +func TestProcessGeminiStream_EmitsImageEvent(t *testing.T) { + 
t.Parallel() + gin.SetMode(gin.TestMode) + + ctx, recorder := newSoraTestContext() + svc := &AccountTestService{} + + stream := strings.NewReader("data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"},{\"inlineData\":{\"mimeType\":\"image/png\",\"data\":\"QUJD\"}}]}}]}\n\ndata: [DONE]\n\n") + + err := svc.processGeminiStream(ctx, stream) + require.NoError(t, err) + + body := recorder.Body.String() + require.Contains(t, body, "\"type\":\"content\"") + require.Contains(t, body, "\"text\":\"ok\"") + require.Contains(t, body, "\"type\":\"image\"") + require.Contains(t, body, "\"image_url\":\"data:image/png;base64,QUJD\"") + require.Contains(t, body, "\"mime_type\":\"image/png\"") +} diff --git a/backend/internal/service/account_test_service_openai_test.go b/backend/internal/service/account_test_service_openai_test.go new file mode 100644 index 00000000..efa6f7da --- /dev/null +++ b/backend/internal/service/account_test_service_openai_test.go @@ -0,0 +1,102 @@ +//go:build unit + +package service + +import ( + "context" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type openAIAccountTestRepo struct { + mockAccountRepoForGemini + updatedExtra map[string]any + rateLimitedID int64 + rateLimitedAt *time.Time +} + +func (r *openAIAccountTestRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error { + r.updatedExtra = updates + return nil +} + +func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, resetAt time.Time) error { + r.rateLimitedID = id + r.rateLimitedAt = &resetAt + return nil +} + +func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + ctx, recorder := newSoraTestContext() + + resp := newJSONResponse(http.StatusOK, "") + resp.Body = io.NopCloser(strings.NewReader(`data: {"type":"response.completed"} + +`)) + resp.Header.Set("x-codex-primary-used-percent", "88") + 
resp.Header.Set("x-codex-primary-reset-after-seconds", "604800") + resp.Header.Set("x-codex-primary-window-minutes", "10080") + resp.Header.Set("x-codex-secondary-used-percent", "42") + resp.Header.Set("x-codex-secondary-reset-after-seconds", "18000") + resp.Header.Set("x-codex-secondary-window-minutes", "300") + + repo := &openAIAccountTestRepo{} + upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}} + svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream} + account := &Account{ + ID: 89, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "test-token"}, + } + + err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4") + require.NoError(t, err) + require.NotEmpty(t, repo.updatedExtra) + require.Equal(t, 42.0, repo.updatedExtra["codex_5h_used_percent"]) + require.Equal(t, 88.0, repo.updatedExtra["codex_7d_used_percent"]) + require.Contains(t, recorder.Body.String(), "test_complete") +} + +func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimit(t *testing.T) { + gin.SetMode(gin.TestMode) + ctx, _ := newSoraTestContext() + + resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`) + resp.Header.Set("x-codex-primary-used-percent", "100") + resp.Header.Set("x-codex-primary-reset-after-seconds", "604800") + resp.Header.Set("x-codex-primary-window-minutes", "10080") + resp.Header.Set("x-codex-secondary-used-percent", "100") + resp.Header.Set("x-codex-secondary-reset-after-seconds", "18000") + resp.Header.Set("x-codex-secondary-window-minutes", "300") + + repo := &openAIAccountTestRepo{} + upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}} + svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream} + account := &Account{ + ID: 88, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "test-token"}, + } + + 
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4") + require.Error(t, err) + require.NotEmpty(t, repo.updatedExtra) + require.Equal(t, 100.0, repo.updatedExtra["codex_5h_used_percent"]) + require.Equal(t, int64(88), repo.rateLimitedID) + require.NotNil(t, repo.rateLimitedAt) + require.NotNil(t, account.RateLimitResetAt) + if account.RateLimitResetAt != nil && repo.rateLimitedAt != nil { + require.WithinDuration(t, *repo.rateLimitedAt, *account.RateLimitResetAt, time.Second) + } +} diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index 6dee6c13..e4245133 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -1,17 +1,24 @@ package service import ( + "bytes" "context" + "encoding/json" "fmt" "log" + "math/rand/v2" + "net/http" "strings" "sync" "time" + httppool "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" + openaipkg "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "golang.org/x/sync/errgroup" + "golang.org/x/sync/singleflight" ) type UsageLogRepository interface { @@ -70,8 +77,10 @@ type accountWindowStatsBatchReader interface { } // apiUsageCache 缓存从 Anthropic API 获取的使用率数据(utilization, resets_at) +// 同时支持缓存错误响应(负缓存),防止 429 等错误导致的重试风暴 type apiUsageCache struct { response *ClaudeUsageResponse + err error // 非 nil 表示缓存的错误(负缓存) timestamp time.Time } @@ -88,15 +97,21 @@ type antigravityUsageCache struct { } const ( - apiCacheTTL = 3 * time.Minute - windowStatsCacheTTL = 1 * time.Minute + apiCacheTTL = 3 * time.Minute + apiErrorCacheTTL = 1 * time.Minute // 负缓存 TTL:429 等错误缓存 1 分钟 + apiQueryMaxJitter = 800 * time.Millisecond // 用量查询最大随机延迟 + windowStatsCacheTTL = 1 * time.Minute + openAIProbeCacheTTL = 10 * time.Minute + openAICodexProbeVersion = "0.104.0" ) // 
UsageCache 封装账户使用量相关的缓存 type UsageCache struct { - apiCache sync.Map // accountID -> *apiUsageCache - windowStatsCache sync.Map // accountID -> *windowStatsCache - antigravityCache sync.Map // accountID -> *antigravityUsageCache + apiCache sync.Map // accountID -> *apiUsageCache + windowStatsCache sync.Map // accountID -> *windowStatsCache + antigravityCache sync.Map // accountID -> *antigravityUsageCache + apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存 + openAIProbeCache sync.Map // accountID -> time.Time } // NewUsageCache 创建 UsageCache 实例 @@ -224,6 +239,14 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U return nil, fmt.Errorf("get account failed: %w", err) } + if account.Platform == PlatformOpenAI && account.Type == AccountTypeOAuth { + usage, err := s.getOpenAIUsage(ctx, account) + if err == nil { + s.tryClearRecoverableAccountError(ctx, account) + } + return usage, err + } + if account.Platform == PlatformGemini { usage, err := s.getGeminiUsage(ctx, account) if err == nil { @@ -245,24 +268,65 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U if account.CanGetUsage() { var apiResp *ClaudeUsageResponse - // 1. 检查 API 缓存(10 分钟) + // 1. 检查缓存(成功响应 3 分钟 / 错误响应 1 分钟) if cached, ok := s.cache.apiCache.Load(accountID); ok { - if cache, ok := cached.(*apiUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL { - apiResp = cache.response + if cache, ok := cached.(*apiUsageCache); ok { + age := time.Since(cache.timestamp) + if cache.err != nil && age < apiErrorCacheTTL { + // 负缓存命中:返回缓存的错误,避免重试风暴 + return nil, cache.err + } + if cache.response != nil && age < apiCacheTTL { + apiResp = cache.response + } } } - // 2. 如果没有缓存,从 API 获取 + // 2. 
如果没有有效缓存,通过 singleflight 从 API 获取(防止并发击穿) if apiResp == nil { - apiResp, err = s.fetchOAuthUsageRaw(ctx, account) - if err != nil { - return nil, err + // 随机延迟:打散多账号并发请求,避免同一时刻大量相同 TLS 指纹请求 + // 触发上游反滥用检测。延迟范围 0~800ms,仅在缓存未命中时生效。 + jitter := time.Duration(rand.Int64N(int64(apiQueryMaxJitter))) + select { + case <-time.After(jitter): + case <-ctx.Done(): + return nil, ctx.Err() } - // 缓存 API 响应 - s.cache.apiCache.Store(accountID, &apiUsageCache{ - response: apiResp, - timestamp: time.Now(), + + flightKey := fmt.Sprintf("usage:%d", accountID) + result, flightErr, _ := s.cache.apiFlight.Do(flightKey, func() (any, error) { + // 再次检查缓存(可能在等待 singleflight 期间被其他请求填充) + if cached, ok := s.cache.apiCache.Load(accountID); ok { + if cache, ok := cached.(*apiUsageCache); ok { + age := time.Since(cache.timestamp) + if cache.err != nil && age < apiErrorCacheTTL { + return nil, cache.err + } + if cache.response != nil && age < apiCacheTTL { + return cache.response, nil + } + } + } + resp, fetchErr := s.fetchOAuthUsageRaw(ctx, account) + if fetchErr != nil { + // 负缓存:缓存错误响应,防止后续请求重复触发 429 + s.cache.apiCache.Store(accountID, &apiUsageCache{ + err: fetchErr, + timestamp: time.Now(), + }) + return nil, fetchErr + } + // 缓存成功响应 + s.cache.apiCache.Store(accountID, &apiUsageCache{ + response: resp, + timestamp: time.Now(), + }) + return resp, nil }) + if flightErr != nil { + return nil, flightErr + } + apiResp, _ = result.(*ClaudeUsageResponse) } // 3. 
构建 UsageInfo(每次都重新计算 RemainingSeconds) @@ -288,6 +352,237 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U return nil, fmt.Errorf("account type %s does not support usage query", account.Type) } +func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Account) (*UsageInfo, error) { + now := time.Now() + usage := &UsageInfo{UpdatedAt: &now} + + if account == nil { + return usage, nil + } + syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, account, now) + + if progress := buildCodexUsageProgressFromExtra(account.Extra, "5h", now); progress != nil { + usage.FiveHour = progress + } + if progress := buildCodexUsageProgressFromExtra(account.Extra, "7d", now); progress != nil { + usage.SevenDay = progress + } + + if shouldRefreshOpenAICodexSnapshot(account, usage, now) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) { + if updates, resetAt, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && (len(updates) > 0 || resetAt != nil) { + mergeAccountExtra(account, updates) + if resetAt != nil { + account.RateLimitResetAt = resetAt + } + if usage.UpdatedAt == nil { + usage.UpdatedAt = &now + } + if progress := buildCodexUsageProgressFromExtra(account.Extra, "5h", now); progress != nil { + usage.FiveHour = progress + } + if progress := buildCodexUsageProgressFromExtra(account.Extra, "7d", now); progress != nil { + usage.SevenDay = progress + } + } + } + + if s.usageLogRepo == nil { + return usage, nil + } + + if stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, now.Add(-5*time.Hour)); err == nil { + windowStats := windowStatsFromAccountStats(stats) + if hasMeaningfulWindowStats(windowStats) { + if usage.FiveHour == nil { + usage.FiveHour = &UsageProgress{Utilization: 0} + } + usage.FiveHour.WindowStats = windowStats + } + } + + if stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, now.Add(-7*24*time.Hour)); err == nil { + windowStats := windowStatsFromAccountStats(stats) + 
if hasMeaningfulWindowStats(windowStats) { + if usage.SevenDay == nil { + usage.SevenDay = &UsageProgress{Utilization: 0} + } + usage.SevenDay.WindowStats = windowStats + } + } + + return usage, nil +} + +func shouldRefreshOpenAICodexSnapshot(account *Account, usage *UsageInfo, now time.Time) bool { + if account == nil { + return false + } + if usage == nil { + return true + } + if usage.FiveHour == nil || usage.SevenDay == nil { + return true + } + if account.IsRateLimited() { + return true + } + return isOpenAICodexSnapshotStale(account, now) +} + +func isOpenAICodexSnapshotStale(account *Account, now time.Time) bool { + if account == nil || !account.IsOpenAIOAuth() || !account.IsOpenAIResponsesWebSocketV2Enabled() { + return false + } + if account.Extra == nil { + return true + } + raw, ok := account.Extra["codex_usage_updated_at"] + if !ok { + return true + } + ts, err := parseTime(fmt.Sprint(raw)) + if err != nil { + return true + } + return now.Sub(ts) >= openAIProbeCacheTTL +} + +func (s *AccountUsageService) shouldProbeOpenAICodexSnapshot(accountID int64, now time.Time) bool { + if s == nil || s.cache == nil || accountID <= 0 { + return true + } + if cached, ok := s.cache.openAIProbeCache.Load(accountID); ok { + if ts, ok := cached.(time.Time); ok && now.Sub(ts) < openAIProbeCacheTTL { + return false + } + } + s.cache.openAIProbeCache.Store(accountID, now) + return true +} + +func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, *time.Time, error) { + if account == nil || !account.IsOAuth() { + return nil, nil, nil + } + accessToken := account.GetOpenAIAccessToken() + if accessToken == "" { + return nil, nil, fmt.Errorf("no access token available") + } + modelID := openaipkg.DefaultTestModel + payload := createOpenAITestPayload(modelID, true) + payloadBytes, err := json.Marshal(payload) + if err != nil { + return nil, nil, fmt.Errorf("marshal openai probe payload: %w", err) + } + + reqCtx, cancel := 
context.WithTimeout(ctx, 15*time.Second) + defer cancel() + req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, chatgptCodexURL, bytes.NewReader(payloadBytes)) + if err != nil { + return nil, nil, fmt.Errorf("create openai probe request: %w", err) + } + req.Host = "chatgpt.com" + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("OpenAI-Beta", "responses=experimental") + req.Header.Set("Originator", "codex_cli_rs") + req.Header.Set("Version", openAICodexProbeVersion) + req.Header.Set("User-Agent", codexCLIUserAgent) + if s.identityCache != nil { + if fp, fpErr := s.identityCache.GetFingerprint(reqCtx, account.ID); fpErr == nil && fp != nil && strings.TrimSpace(fp.UserAgent) != "" { + req.Header.Set("User-Agent", strings.TrimSpace(fp.UserAgent)) + } + } + if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" { + req.Header.Set("chatgpt-account-id", chatgptAccountID) + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + client, err := httppool.GetClient(httppool.Options{ + ProxyURL: proxyURL, + Timeout: 15 * time.Second, + ResponseHeaderTimeout: 10 * time.Second, + }) + if err != nil { + return nil, nil, fmt.Errorf("build openai probe client: %w", err) + } + resp, err := client.Do(req) + if err != nil { + return nil, nil, fmt.Errorf("openai codex probe request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + updates, resetAt, err := extractOpenAICodexProbeSnapshot(resp) + if err != nil { + return nil, nil, err + } + if len(updates) > 0 || resetAt != nil { + s.persistOpenAICodexProbeSnapshot(account.ID, updates, resetAt) + return updates, resetAt, nil + } + return nil, nil, nil +} + +func (s *AccountUsageService) persistOpenAICodexProbeSnapshot(accountID int64, updates map[string]any, resetAt *time.Time) { + if s == nil || 
s.accountRepo == nil || accountID <= 0 { + return + } + if len(updates) == 0 && resetAt == nil { + return + } + + go func() { + updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer updateCancel() + if len(updates) > 0 { + _ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates) + } + if resetAt != nil { + _ = s.accountRepo.SetRateLimited(updateCtx, accountID, *resetAt) + } + }() +} + +func extractOpenAICodexProbeSnapshot(resp *http.Response) (map[string]any, *time.Time, error) { + if resp == nil { + return nil, nil, nil + } + if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { + baseTime := time.Now() + updates := buildCodexUsageExtraUpdates(snapshot, baseTime) + resetAt := codexRateLimitResetAtFromSnapshot(snapshot, baseTime) + if len(updates) > 0 { + return updates, resetAt, nil + } + return nil, resetAt, nil + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode) + } + return nil, nil, nil +} + +func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) { + updates, _, err := extractOpenAICodexProbeSnapshot(resp) + return updates, err +} + +func mergeAccountExtra(account *Account, updates map[string]any) { + if account == nil || len(updates) == 0 { + return + } + if account.Extra == nil { + account.Extra = make(map[string]any, len(updates)) + } + for k, v := range updates { + account.Extra[k] = v + } +} + func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Account) (*UsageInfo, error) { now := time.Now() usage := &UsageInfo{ @@ -519,6 +814,72 @@ func windowStatsFromAccountStats(stats *usagestats.AccountStats) *WindowStats { } } +func hasMeaningfulWindowStats(stats *WindowStats) bool { + if stats == nil { + return false + } + return stats.Requests > 0 || stats.Tokens > 0 || stats.Cost > 0 || stats.StandardCost > 0 || stats.UserCost > 0 +} + +func 
buildCodexUsageProgressFromExtra(extra map[string]any, window string, now time.Time) *UsageProgress { + if len(extra) == 0 { + return nil + } + + var ( + usedPercentKey string + resetAfterKey string + resetAtKey string + ) + + switch window { + case "5h": + usedPercentKey = "codex_5h_used_percent" + resetAfterKey = "codex_5h_reset_after_seconds" + resetAtKey = "codex_5h_reset_at" + case "7d": + usedPercentKey = "codex_7d_used_percent" + resetAfterKey = "codex_7d_reset_after_seconds" + resetAtKey = "codex_7d_reset_at" + default: + return nil + } + + usedRaw, ok := extra[usedPercentKey] + if !ok { + return nil + } + + progress := &UsageProgress{Utilization: parseExtraFloat64(usedRaw)} + if resetAtRaw, ok := extra[resetAtKey]; ok { + if resetAt, err := parseTime(fmt.Sprint(resetAtRaw)); err == nil { + progress.ResetsAt = &resetAt + progress.RemainingSeconds = int(time.Until(resetAt).Seconds()) + if progress.RemainingSeconds < 0 { + progress.RemainingSeconds = 0 + } + } + } + if progress.ResetsAt == nil { + if resetAfterSeconds := parseExtraInt(extra[resetAfterKey]); resetAfterSeconds > 0 { + base := now + if updatedAtRaw, ok := extra["codex_usage_updated_at"]; ok { + if updatedAt, err := parseTime(fmt.Sprint(updatedAtRaw)); err == nil { + base = updatedAt + } + } + resetAt := base.Add(time.Duration(resetAfterSeconds) * time.Second) + progress.ResetsAt = &resetAt + progress.RemainingSeconds = int(time.Until(resetAt).Seconds()) + if progress.RemainingSeconds < 0 { + progress.RemainingSeconds = 0 + } + } + } + + return progress +} + func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) { stats, err := s.usageLogRepo.GetAccountUsageStats(ctx, accountID, startTime, endTime) if err != nil { @@ -666,15 +1027,30 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn remaining = 0 } - // 根据状态估算使用率 (百分比形式,100 = 100%) + // 优先使用响应头中存储的真实 
utilization 值(0-1 小数,转为 0-100 百分比) var utilization float64 - switch account.SessionWindowStatus { - case "rejected": - utilization = 100.0 - case "allowed_warning": - utilization = 80.0 - default: - utilization = 0.0 + var found bool + if stored, ok := account.Extra["session_window_utilization"]; ok { + switch v := stored.(type) { + case float64: + utilization = v * 100 + found = true + case json.Number: + if f, err := v.Float64(); err == nil { + utilization = f * 100 + found = true + } + } + } + + // 如果没有存储的 utilization,回退到状态估算 + if !found { + switch account.SessionWindowStatus { + case "rejected": + utilization = 100.0 + case "allowed_warning": + utilization = 80.0 + } } info.FiveHour = &UsageProgress{ diff --git a/backend/internal/service/account_usage_service_test.go b/backend/internal/service/account_usage_service_test.go new file mode 100644 index 00000000..a063fe26 --- /dev/null +++ b/backend/internal/service/account_usage_service_test.go @@ -0,0 +1,150 @@ +package service + +import ( + "context" + "net/http" + "testing" + "time" +) + +type accountUsageCodexProbeRepo struct { + stubOpenAIAccountRepo + updateExtraCh chan map[string]any + rateLimitCh chan time.Time +} + +func (r *accountUsageCodexProbeRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error { + if r.updateExtraCh != nil { + copied := make(map[string]any, len(updates)) + for k, v := range updates { + copied[k] = v + } + r.updateExtraCh <- copied + } + return nil +} + +func (r *accountUsageCodexProbeRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error { + if r.rateLimitCh != nil { + r.rateLimitCh <- resetAt + } + return nil +} + +func TestShouldRefreshOpenAICodexSnapshot(t *testing.T) { + t.Parallel() + + rateLimitedUntil := time.Now().Add(5 * time.Minute) + now := time.Now() + usage := &UsageInfo{ + FiveHour: &UsageProgress{Utilization: 0}, + SevenDay: &UsageProgress{Utilization: 0}, + } + + if !shouldRefreshOpenAICodexSnapshot(&Account{RateLimitResetAt: 
&rateLimitedUntil}, usage, now) { + t.Fatal("expected rate-limited account to force codex snapshot refresh") + } + + if shouldRefreshOpenAICodexSnapshot(&Account{}, usage, now) { + t.Fatal("expected complete non-rate-limited usage to skip codex snapshot refresh") + } + + if !shouldRefreshOpenAICodexSnapshot(&Account{}, &UsageInfo{FiveHour: nil, SevenDay: &UsageProgress{}}, now) { + t.Fatal("expected missing 5h snapshot to require refresh") + } + + staleAt := now.Add(-(openAIProbeCacheTTL + time.Minute)).Format(time.RFC3339) + if !shouldRefreshOpenAICodexSnapshot(&Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_enabled": true, + "codex_usage_updated_at": staleAt, + }, + }, usage, now) { + t.Fatal("expected stale ws snapshot to trigger refresh") + } +} + +func TestExtractOpenAICodexProbeUpdatesAccepts429WithCodexHeaders(t *testing.T) { + t.Parallel() + + headers := make(http.Header) + headers.Set("x-codex-primary-used-percent", "100") + headers.Set("x-codex-primary-reset-after-seconds", "604800") + headers.Set("x-codex-primary-window-minutes", "10080") + headers.Set("x-codex-secondary-used-percent", "100") + headers.Set("x-codex-secondary-reset-after-seconds", "18000") + headers.Set("x-codex-secondary-window-minutes", "300") + + updates, err := extractOpenAICodexProbeUpdates(&http.Response{StatusCode: http.StatusTooManyRequests, Header: headers}) + if err != nil { + t.Fatalf("extractOpenAICodexProbeUpdates() error = %v", err) + } + if len(updates) == 0 { + t.Fatal("expected codex probe updates from 429 headers") + } + if got := updates["codex_5h_used_percent"]; got != 100.0 { + t.Fatalf("codex_5h_used_percent = %v, want 100", got) + } + if got := updates["codex_7d_used_percent"]; got != 100.0 { + t.Fatalf("codex_7d_used_percent = %v, want 100", got) + } +} + +func TestExtractOpenAICodexProbeSnapshotAccepts429WithResetAt(t *testing.T) { + t.Parallel() + + headers := make(http.Header) + 
headers.Set("x-codex-primary-used-percent", "100") + headers.Set("x-codex-primary-reset-after-seconds", "604800") + headers.Set("x-codex-primary-window-minutes", "10080") + headers.Set("x-codex-secondary-used-percent", "100") + headers.Set("x-codex-secondary-reset-after-seconds", "18000") + headers.Set("x-codex-secondary-window-minutes", "300") + + updates, resetAt, err := extractOpenAICodexProbeSnapshot(&http.Response{StatusCode: http.StatusTooManyRequests, Header: headers}) + if err != nil { + t.Fatalf("extractOpenAICodexProbeSnapshot() error = %v", err) + } + if len(updates) == 0 { + t.Fatal("expected codex probe updates from 429 headers") + } + if resetAt == nil { + t.Fatal("expected resetAt from exhausted codex headers") + } +} + +func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *testing.T) { + t.Parallel() + + repo := &accountUsageCodexProbeRepo{ + updateExtraCh: make(chan map[string]any, 1), + rateLimitCh: make(chan time.Time, 1), + } + svc := &AccountUsageService{accountRepo: repo} + resetAt := time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second) + + svc.persistOpenAICodexProbeSnapshot(321, map[string]any{ + "codex_7d_used_percent": 100.0, + "codex_7d_reset_at": resetAt.Format(time.RFC3339), + }, &resetAt) + + select { + case updates := <-repo.updateExtraCh: + if got := updates["codex_7d_used_percent"]; got != 100.0 { + t.Fatalf("codex_7d_used_percent = %v, want 100", got) + } + case <-time.After(2 * time.Second): + t.Fatal("waiting for codex probe extra persistence timed out") + } + + select { + case got := <-repo.rateLimitCh: + if got.Before(resetAt.Add(-time.Second)) || got.After(resetAt.Add(time.Second)) { + t.Fatalf("rate limit resetAt = %v, want around %v", got, resetAt) + } + case <-time.After(2 * time.Second): + t.Fatal("waiting for codex probe rate limit persistence timed out") + } +} diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 67e7c783..f2e7bd9b 100644 --- 
a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -42,6 +42,7 @@ type AdminService interface { UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) DeleteGroup(ctx context.Context, id int64) error GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error) + GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error // API Key management (admin) @@ -84,6 +85,7 @@ type AdminService interface { DeleteRedeemCode(ctx context.Context, id int64) error BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) + ResetAccountQuota(ctx context.Context, id int64) error } // CreateUserInput represents input for creating a new user via admin operations. @@ -137,13 +139,17 @@ type CreateGroupInput struct { // 无效请求兜底分组 ID(仅 anthropic 平台使用) FallbackGroupIDOnInvalidRequest *int64 // 模型路由配置(仅 anthropic 平台使用) - ModelRouting map[string][]int64 - ModelRoutingEnabled bool // 是否启用模型路由 - MCPXMLInject *bool + ModelRouting map[string][]int64 + ModelRoutingEnabled bool // 是否启用模型路由 + MCPXMLInject *bool + SimulateClaudeMaxEnabled *bool // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes []string // Sora 存储配额 SoraStorageQuotaBytes int64 + // OpenAI Messages 调度配置(仅 openai 平台使用) + AllowMessagesDispatch bool + DefaultMappedModel string // 从指定分组复制账号(创建分组后在同一事务内绑定) CopyAccountsFromGroupIDs []int64 } @@ -173,13 +179,17 @@ type UpdateGroupInput struct { // 无效请求兜底分组 ID(仅 anthropic 平台使用) FallbackGroupIDOnInvalidRequest *int64 // 模型路由配置(仅 anthropic 平台使用) - ModelRouting map[string][]int64 - ModelRoutingEnabled *bool // 是否启用模型路由 - MCPXMLInject *bool + ModelRouting map[string][]int64 + ModelRoutingEnabled *bool // 是否启用模型路由 + MCPXMLInject *bool + SimulateClaudeMaxEnabled *bool // 支持的模型系列(仅 antigravity 
平台使用) SupportedModelScopes *[]string // Sora 存储配额 SoraStorageQuotaBytes *int64 + // OpenAI Messages 调度配置(仅 openai 平台使用) + AllowMessagesDispatch *bool + DefaultMappedModel *string // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 } @@ -195,6 +205,7 @@ type CreateAccountInput struct { Concurrency int Priority int RateMultiplier *float64 // 账号计费倍率(>=0,允许 0) + LoadFactor *int GroupIDs []int64 ExpiresAt *int64 AutoPauseOnExpired *bool @@ -215,6 +226,7 @@ type UpdateAccountInput struct { Concurrency *int // 使用指针区分"未提供"和"设置为0" Priority *int // 使用指针区分"未提供"和"设置为0" RateMultiplier *float64 // 账号计费倍率(>=0,允许 0) + LoadFactor *int Status string GroupIDs *[]int64 ExpiresAt *int64 @@ -230,6 +242,7 @@ type BulkUpdateAccountsInput struct { Concurrency *int Priority *int RateMultiplier *float64 // 账号计费倍率(>=0,允许 0) + LoadFactor *int Status string Schedulable *bool GroupIDs *[]int64 @@ -353,6 +366,10 @@ type ProxyExitInfoProber interface { ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error) } +type groupExistenceBatchReader interface { + ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error) +} + type proxyQualityTarget struct { Target string URL string @@ -422,16 +439,13 @@ type adminServiceImpl struct { entClient *dbent.Client // 用于开启数据库事务 settingService *SettingService defaultSubAssigner DefaultSubscriptionAssigner + userSubRepo UserSubscriptionRepository } type userGroupRateBatchReader interface { GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error) } -type groupExistenceBatchReader interface { - ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error) -} - // NewAdminService creates a new AdminService func NewAdminService( userRepo UserRepository, @@ -449,6 +463,7 @@ func NewAdminService( entClient *dbent.Client, settingService *SettingService, defaultSubAssigner DefaultSubscriptionAssigner, + userSubRepo UserSubscriptionRepository, ) AdminService { return 
&adminServiceImpl{ userRepo: userRepo, @@ -466,6 +481,7 @@ func NewAdminService( entClient: entClient, settingService: settingService, defaultSubAssigner: defaultSubAssigner, + userSubRepo: userSubRepo, } } @@ -847,6 +863,13 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn if input.MCPXMLInject != nil { mcpXMLInject = *input.MCPXMLInject } + simulateClaudeMaxEnabled := false + if input.SimulateClaudeMaxEnabled != nil { + if platform != PlatformAnthropic && *input.SimulateClaudeMaxEnabled { + return nil, fmt.Errorf("simulate_claude_max_enabled only supported for anthropic groups") + } + simulateClaudeMaxEnabled = *input.SimulateClaudeMaxEnabled + } // 如果指定了复制账号的源分组,先获取账号 ID 列表 var accountIDsToCopy []int64 @@ -903,8 +926,11 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest, ModelRouting: input.ModelRouting, MCPXMLInject: mcpXMLInject, + SimulateClaudeMaxEnabled: simulateClaudeMaxEnabled, SupportedModelScopes: input.SupportedModelScopes, SoraStorageQuotaBytes: input.SoraStorageQuotaBytes, + AllowMessagesDispatch: input.AllowMessagesDispatch, + DefaultMappedModel: input.DefaultMappedModel, } if err := s.groupRepo.Create(ctx, group); err != nil { return nil, err @@ -1112,12 +1138,29 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd if input.MCPXMLInject != nil { group.MCPXMLInject = *input.MCPXMLInject } + if input.SimulateClaudeMaxEnabled != nil { + if group.Platform != PlatformAnthropic && *input.SimulateClaudeMaxEnabled { + return nil, fmt.Errorf("simulate_claude_max_enabled only supported for anthropic groups") + } + group.SimulateClaudeMaxEnabled = *input.SimulateClaudeMaxEnabled + } + if group.Platform != PlatformAnthropic { + group.SimulateClaudeMaxEnabled = false + } // 支持的模型系列(仅 antigravity 平台使用) if input.SupportedModelScopes != nil { group.SupportedModelScopes = *input.SupportedModelScopes } + // 
OpenAI Messages 调度配置 + if input.AllowMessagesDispatch != nil { + group.AllowMessagesDispatch = *input.AllowMessagesDispatch + } + if input.DefaultMappedModel != nil { + group.DefaultMappedModel = *input.DefaultMappedModel + } + if err := s.groupRepo.Update(ctx, group); err != nil { return nil, err } @@ -1221,6 +1264,13 @@ func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, p return keys, result.Total, nil } +func (s *adminServiceImpl) GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error) { + if s.userGroupRateRepo == nil { + return nil, nil + } + return s.userGroupRateRepo.GetByGroupID(ctx, groupID) +} + func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error { return s.groupRepo.UpdateSortOrders(ctx, updates) } @@ -1257,9 +1307,17 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i if group.Status != StatusActive { return nil, infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active") } - // 订阅类型分组:不允许通过此 API 直接绑定,需通过订阅管理流程 + // 订阅类型分组:用户须持有该分组的有效订阅才可绑定 if group.IsSubscriptionType() { - return nil, infraerrors.BadRequest("SUBSCRIPTION_GROUP_NOT_ALLOWED", "subscription groups must be managed through the subscription workflow") + if s.userSubRepo == nil { + return nil, infraerrors.InternalServer("SUBSCRIPTION_REPOSITORY_UNAVAILABLE", "subscription repository is not configured") + } + if _, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, apiKey.UserID, *groupID); err != nil { + if errors.Is(err, ErrSubscriptionNotFound) { + return nil, infraerrors.BadRequest("SUBSCRIPTION_REQUIRED", "user does not have an active subscription for this group") + } + return nil, err + } } gid := *groupID @@ -1267,7 +1325,7 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i apiKey.Group = group // 专属标准分组:使用事务保证「添加分组权限」与「更新 API Key」的原子性 - if group.IsExclusive { + if group.IsExclusive && 
!group.IsSubscriptionType() { opCtx := ctx var tx *dbent.Tx if s.entClient == nil { @@ -1329,6 +1387,10 @@ func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, if err != nil { return nil, 0, err } + now := time.Now() + for i := range accounts { + syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, &accounts[i], now) + } return accounts, result.Total, nil } @@ -1413,6 +1475,12 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou } account.RateMultiplier = input.RateMultiplier } + if input.LoadFactor != nil && *input.LoadFactor > 0 { + if *input.LoadFactor > 10000 { + return nil, errors.New("load_factor must be <= 10000") + } + account.LoadFactor = input.LoadFactor + } if err := s.accountRepo.Create(ctx, account); err != nil { return nil, err } @@ -1458,6 +1526,12 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U account.Credentials = input.Credentials } if len(input.Extra) > 0 { + // 保留配额用量字段,防止编辑账号时意外重置 + for _, key := range []string{"quota_used", "quota_daily_used", "quota_daily_start", "quota_weekly_used", "quota_weekly_start"} { + if v, ok := account.Extra[key]; ok { + input.Extra[key] = v + } + } account.Extra = input.Extra } if input.ProxyID != nil { @@ -1483,6 +1557,15 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U } account.RateMultiplier = input.RateMultiplier } + if input.LoadFactor != nil { + if *input.LoadFactor <= 0 { + account.LoadFactor = nil // 0 或负数表示清除 + } else if *input.LoadFactor > 10000 { + return nil, errors.New("load_factor must be <= 10000") + } else { + account.LoadFactor = input.LoadFactor + } + } if input.Status != "" { account.Status = input.Status } @@ -1616,6 +1699,15 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp if input.RateMultiplier != nil { repoUpdates.RateMultiplier = input.RateMultiplier } + if input.LoadFactor != nil { + if *input.LoadFactor <= 0 { + 
repoUpdates.LoadFactor = nil // 0 或负数表示清除 + } else if *input.LoadFactor > 10000 { + return nil, errors.New("load_factor must be <= 10000") + } else { + repoUpdates.LoadFactor = input.LoadFactor + } + } if input.Status != "" { repoUpdates.Status = &input.Status } @@ -1669,16 +1761,10 @@ func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int } func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Account, error) { - account, err := s.accountRepo.GetByID(ctx, id) - if err != nil { + if err := s.accountRepo.ClearError(ctx, id); err != nil { return nil, err } - account.Status = StatusActive - account.ErrorMessage = "" - if err := s.accountRepo.Update(ctx, account); err != nil { - return nil, err - } - return account, nil + return s.accountRepo.GetByID(ctx, id) } func (s *adminServiceImpl) SetAccountError(ctx context.Context, id int64, errorMsg string) error { @@ -2439,3 +2525,7 @@ func (e *MixedChannelError) Error() string { return fmt.Sprintf("mixed_channel_warning: Group '%s' contains both %s and %s accounts. 
Using mixed channels in the same context may cause thinking block signature validation issues, which will fallback to non-thinking mode for historical messages.", e.GroupName, e.CurrentPlatform, e.OtherPlatform) } + +func (s *adminServiceImpl) ResetAccountQuota(ctx context.Context, id int64) error { + return s.accountRepo.ResetQuotaUsed(ctx, id) +} diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go index a6d12f97..88d2f492 100644 --- a/backend/internal/service/admin_service_apikey_test.go +++ b/backend/internal/service/admin_service_apikey_test.go @@ -32,28 +32,44 @@ func (s *userRepoStubForGroupUpdate) AddGroupToAllowedGroups(_ context.Context, return s.addGroupErr } -func (s *userRepoStubForGroupUpdate) Create(context.Context, *User) error { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) GetByID(context.Context, int64) (*User, error) { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) GetByEmail(context.Context, string) (*User, error) { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, error) { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) Create(context.Context, *User) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) GetByID(context.Context, int64) (*User, error) { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) GetByEmail(context.Context, string) (*User, error) { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, error) { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { 
panic("unexpected") } func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { panic("unexpected") } func (s *userRepoStubForGroupUpdate) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) UpdateBalance(context.Context, int64, float64) error { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) DeductBalance(context.Context, int64, float64) error { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) UpdateBalance(context.Context, int64, float64) error { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) DeductBalance(context.Context, int64, float64) error { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) { + panic("unexpected") +} func (s *userRepoStubForGroupUpdate) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") } +func (s 
*userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") } // apiKeyRepoStubForGroupUpdate implements APIKeyRepository for AdminUpdateAPIKeyGroupID tests. type apiKeyRepoStubForGroupUpdate struct { @@ -194,6 +210,29 @@ func (s *groupRepoStubForGroupUpdate) UpdateSortOrders(context.Context, []GroupS panic("unexpected") } +type userSubRepoStubForGroupUpdate struct { + userSubRepoNoop + getActiveSub *UserSubscription + getActiveErr error + called bool + calledUserID int64 + calledGroupID int64 +} + +func (s *userSubRepoStubForGroupUpdate) GetActiveByUserIDAndGroupID(_ context.Context, userID, groupID int64) (*UserSubscription, error) { + s.called = true + s.calledUserID = userID + s.calledGroupID = groupID + if s.getActiveErr != nil { + return nil, s.getActiveErr + } + if s.getActiveSub == nil { + return nil, ErrSubscriptionNotFound + } + clone := *s.getActiveSub + return &clone, nil +} + // --------------------------------------------------------------------------- // Tests // --------------------------------------------------------------------------- @@ -386,14 +425,49 @@ func TestAdminService_AdminUpdateAPIKeyGroupID_NonExclusiveGroup_NoAllowedGroupU func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_Blocked(t *testing.T) { existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil} apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} - groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}} + groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeSubscription}} + userRepo := &userRepoStubForGroupUpdate{} + userSubRepo := &userSubRepoStubForGroupUpdate{getActiveErr: ErrSubscriptionNotFound} + svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, userSubRepo: 
userSubRepo} + + // 无有效订阅时应拒绝绑定 + _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10)) + require.Error(t, err) + require.Equal(t, "SUBSCRIPTION_REQUIRED", infraerrors.Reason(err)) + require.True(t, userSubRepo.called) + require.Equal(t, int64(42), userSubRepo.calledUserID) + require.Equal(t, int64(10), userSubRepo.calledGroupID) + require.False(t, userRepo.addGroupCalled) +} + +func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_RequiresRepo(t *testing.T) { + existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil} + apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} + groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeSubscription}} userRepo := &userRepoStubForGroupUpdate{} svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo} - // 订阅类型分组应被阻止绑定 _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10)) require.Error(t, err) - require.Equal(t, "SUBSCRIPTION_GROUP_NOT_ALLOWED", infraerrors.Reason(err)) + require.Equal(t, "SUBSCRIPTION_REPOSITORY_UNAVAILABLE", infraerrors.Reason(err)) + require.False(t, userRepo.addGroupCalled) +} + +func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_AllowsActiveSubscription(t *testing.T) { + existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil} + apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} + groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}} + userRepo := &userRepoStubForGroupUpdate{} + userSubRepo := &userSubRepoStubForGroupUpdate{ + getActiveSub: &UserSubscription{ID: 99, UserID: 42, GroupID: 10}, + } + svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, userSubRepo: userSubRepo} + + got, err := 
svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10)) + require.NoError(t, err) + require.True(t, userSubRepo.called) + require.NotNil(t, got.APIKey.GroupID) + require.Equal(t, int64(10), *got.APIKey.GroupID) require.False(t, userRepo.addGroupCalled) } diff --git a/backend/internal/service/admin_service_bulk_update_test.go b/backend/internal/service/admin_service_bulk_update_test.go index 4845d87c..e90ec93a 100644 --- a/backend/internal/service/admin_service_bulk_update_test.go +++ b/backend/internal/service/admin_service_bulk_update_test.go @@ -43,6 +43,16 @@ func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID i return nil } +func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID int64) ([]Account, error) { + if err, ok := s.listByGroupErr[groupID]; ok { + return nil, err + } + if rows, ok := s.listByGroupData[groupID]; ok { + return rows, nil + } + return nil, nil +} + func (s *accountRepoStubForBulkUpdate) GetByIDs(_ context.Context, ids []int64) ([]*Account, error) { s.getByIDsCalled = true s.getByIDsIDs = append([]int64{}, ids...) 
@@ -63,16 +73,6 @@ func (s *accountRepoStubForBulkUpdate) GetByID(_ context.Context, id int64) (*Ac return nil, errors.New("account not found") } -func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID int64) ([]Account, error) { - if err, ok := s.listByGroupErr[groupID]; ok { - return nil, err - } - if rows, ok := s.listByGroupData[groupID]; ok { - return rows, nil - } - return nil, nil -} - // TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。 func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) { repo := &accountRepoStubForBulkUpdate{} diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go index ef77a980..0e6fe084 100644 --- a/backend/internal/service/admin_service_group_test.go +++ b/backend/internal/service/admin_service_group_test.go @@ -785,3 +785,57 @@ func TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity(t *tes require.NotNil(t, repo.updated) require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest) } + +func TestAdminService_CreateGroup_SimulateClaudeMaxRequiresAnthropic(t *testing.T) { + repo := &groupRepoStubForAdmin{} + svc := &adminServiceImpl{groupRepo: repo} + + enabled := true + _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ + Name: "openai-group", + Platform: PlatformOpenAI, + SimulateClaudeMaxEnabled: &enabled, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "simulate_claude_max_enabled only supported for anthropic groups") + require.Nil(t, repo.created) +} + +func TestAdminService_UpdateGroup_SimulateClaudeMaxRequiresAnthropic(t *testing.T) { + existingGroup := &Group{ + ID: 1, + Name: "openai-group", + Platform: PlatformOpenAI, + Status: StatusActive, + } + repo := &groupRepoStubForAdmin{getByID: existingGroup} + svc := &adminServiceImpl{groupRepo: repo} + + enabled := true + _, err := svc.UpdateGroup(context.Background(), 1, 
&UpdateGroupInput{ + SimulateClaudeMaxEnabled: &enabled, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "simulate_claude_max_enabled only supported for anthropic groups") + require.Nil(t, repo.updated) +} + +func TestAdminService_UpdateGroup_ClearsSimulateClaudeMaxWhenPlatformChanges(t *testing.T) { + existingGroup := &Group{ + ID: 1, + Name: "anthropic-group", + Platform: PlatformAnthropic, + Status: StatusActive, + SimulateClaudeMaxEnabled: true, + } + repo := &groupRepoStubForAdmin{getByID: existingGroup} + svc := &adminServiceImpl{groupRepo: repo} + + group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{ + Platform: PlatformOpenAI, + }) + require.NoError(t, err) + require.NotNil(t, group) + require.NotNil(t, repo.updated) + require.False(t, repo.updated.SimulateClaudeMaxEnabled) +} diff --git a/backend/internal/service/admin_service_list_users_test.go b/backend/internal/service/admin_service_list_users_test.go index 8b50530a..579fa981 100644 --- a/backend/internal/service/admin_service_list_users_test.go +++ b/backend/internal/service/admin_service_list_users_test.go @@ -68,6 +68,10 @@ func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context panic("unexpected SyncUserGroupRates call") } +func (s *userGroupRateRepoStubForListUsers) GetByGroupID(_ context.Context, groupID int64) ([]UserGroupRateEntry, error) { + panic("unexpected GetByGroupID call") +} + func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, groupID int64) error { panic("unexpected DeleteByGroupID call") } diff --git a/backend/internal/service/announcement.go b/backend/internal/service/announcement.go index 2ba5af5d..25c66eb4 100644 --- a/backend/internal/service/announcement.go +++ b/backend/internal/service/announcement.go @@ -14,6 +14,11 @@ const ( AnnouncementStatusArchived = domain.AnnouncementStatusArchived ) +const ( + AnnouncementNotifyModeSilent = domain.AnnouncementNotifyModeSilent + 
AnnouncementNotifyModePopup = domain.AnnouncementNotifyModePopup +) + const ( AnnouncementConditionTypeSubscription = domain.AnnouncementConditionTypeSubscription AnnouncementConditionTypeBalance = domain.AnnouncementConditionTypeBalance diff --git a/backend/internal/service/announcement_service.go b/backend/internal/service/announcement_service.go index c2588e6c..c0a0681a 100644 --- a/backend/internal/service/announcement_service.go +++ b/backend/internal/service/announcement_service.go @@ -33,23 +33,25 @@ func NewAnnouncementService( } type CreateAnnouncementInput struct { - Title string - Content string - Status string - Targeting AnnouncementTargeting - StartsAt *time.Time - EndsAt *time.Time - ActorID *int64 // 管理员用户ID + Title string + Content string + Status string + NotifyMode string + Targeting AnnouncementTargeting + StartsAt *time.Time + EndsAt *time.Time + ActorID *int64 // 管理员用户ID } type UpdateAnnouncementInput struct { - Title *string - Content *string - Status *string - Targeting *AnnouncementTargeting - StartsAt **time.Time - EndsAt **time.Time - ActorID *int64 // 管理员用户ID + Title *string + Content *string + Status *string + NotifyMode *string + Targeting *AnnouncementTargeting + StartsAt **time.Time + EndsAt **time.Time + ActorID *int64 // 管理员用户ID } type UserAnnouncement struct { @@ -93,6 +95,14 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem return nil, err } + notifyMode := strings.TrimSpace(input.NotifyMode) + if notifyMode == "" { + notifyMode = AnnouncementNotifyModeSilent + } + if !isValidAnnouncementNotifyMode(notifyMode) { + return nil, fmt.Errorf("create announcement: invalid notify_mode") + } + if input.StartsAt != nil && input.EndsAt != nil { if !input.StartsAt.Before(*input.EndsAt) { return nil, fmt.Errorf("create announcement: starts_at must be before ends_at") @@ -100,12 +110,13 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem } a := &Announcement{ - Title: title, - 
Content: content, - Status: status, - Targeting: targeting, - StartsAt: input.StartsAt, - EndsAt: input.EndsAt, + Title: title, + Content: content, + Status: status, + NotifyMode: notifyMode, + Targeting: targeting, + StartsAt: input.StartsAt, + EndsAt: input.EndsAt, } if input.ActorID != nil && *input.ActorID > 0 { a.CreatedBy = input.ActorID @@ -150,6 +161,14 @@ func (s *AnnouncementService) Update(ctx context.Context, id int64, input *Updat a.Status = status } + if input.NotifyMode != nil { + notifyMode := strings.TrimSpace(*input.NotifyMode) + if !isValidAnnouncementNotifyMode(notifyMode) { + return nil, fmt.Errorf("update announcement: invalid notify_mode") + } + a.NotifyMode = notifyMode + } + if input.Targeting != nil { targeting, err := domain.AnnouncementTargeting(*input.Targeting).NormalizeAndValidate() if err != nil { @@ -376,3 +395,12 @@ func isValidAnnouncementStatus(status string) bool { return false } } + +func isValidAnnouncementNotifyMode(mode string) bool { + switch mode { + case AnnouncementNotifyModeSilent, AnnouncementNotifyModePopup: + return true + default: + return false + } +} diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 96ff3354..bfca7a82 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -1384,7 +1384,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, // 优先检测 thinking block 的 signature 相关错误(400)并重试一次: // Antigravity /v1internal 链路在部分场景会对 thought/thinking signature 做严格校验, // 当历史消息携带的 signature 不合法时会直接 400;去除 thinking 后可继续完成请求。 - if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) { + if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) && s.settingService.IsSignatureRectifierEnabled(ctx) { upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) upstreamMsg = 
sanitizeUpstreamErrorMessage(upstreamMsg) logBody, maxBytes := s.getLogConfig() @@ -1517,6 +1517,80 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, } } + // Budget 整流:检测 budget_tokens 约束错误并自动修正重试 + if resp.StatusCode == http.StatusBadRequest && respBody != nil && !isSignatureRelatedError(respBody) { + errMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) + if isThinkingBudgetConstraintError(errMsg) && s.settingService.IsBudgetRectifierEnabled(ctx) { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "budget_constraint_error", + Message: errMsg, + Detail: s.getUpstreamErrorDetail(respBody), + }) + + // 修正 claudeReq 的 thinking 参数(adaptive 模式不修正) + if claudeReq.Thinking == nil || claudeReq.Thinking.Type != "adaptive" { + retryClaudeReq := claudeReq + retryClaudeReq.Messages = append([]antigravity.ClaudeMessage(nil), claudeReq.Messages...) 
+ // 创建新的 ThinkingConfig 避免修改原始 claudeReq.Thinking 指针 + retryClaudeReq.Thinking = &antigravity.ThinkingConfig{ + Type: "enabled", + BudgetTokens: BudgetRectifyBudgetTokens, + } + if retryClaudeReq.MaxTokens < BudgetRectifyMinMaxTokens { + retryClaudeReq.MaxTokens = BudgetRectifyMaxTokens + } + + logger.LegacyPrintf("service.antigravity_gateway", "Antigravity account %d: detected budget_tokens constraint error, retrying with rectified budget (budget_tokens=%d, max_tokens=%d)", account.ID, BudgetRectifyBudgetTokens, BudgetRectifyMaxTokens) + + retryGeminiBody, txErr := antigravity.TransformClaudeToGeminiWithOptions(&retryClaudeReq, projectID, mappedModel, transformOpts) + if txErr == nil { + retryResult, retryErr := s.antigravityRetryLoop(antigravityRetryLoopParams{ + ctx: ctx, + prefix: prefix, + account: account, + proxyURL: proxyURL, + accessToken: accessToken, + action: action, + body: retryGeminiBody, + c: c, + httpUpstream: s.httpUpstream, + settingService: s.settingService, + accountRepo: s.accountRepo, + handleError: s.handleUpstreamError, + requestedModel: originalModel, + isStickySession: isStickySession, + groupID: 0, + sessionHash: "", + }) + if retryErr == nil { + retryResp := retryResult.resp + if retryResp.StatusCode < 400 { + _ = resp.Body.Close() + resp = retryResp + respBody = nil + } else { + retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20)) + _ = retryResp.Body.Close() + respBody = retryBody + resp = &http.Response{ + StatusCode: retryResp.StatusCode, + Header: retryResp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(retryBody)), + } + } + } else { + logger.LegacyPrintf("service.antigravity_gateway", "Antigravity account %d: budget rectifier retry failed: %v", account.ID, retryErr) + } + } + } + } + } + // 处理错误响应(重试后仍失败或不触发重试) if resp.StatusCode >= 400 { // 检测 prompt too long 错误,返回特殊错误类型供上层 fallback @@ -1599,7 +1673,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, var clientDisconnect bool 
if claudeReq.Stream { // 客户端要求流式,直接透传转换 - streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel) + streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel, account.ID) if err != nil { logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_error error=%v", prefix, err) return nil, err @@ -1609,7 +1683,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, clientDisconnect = streamRes.clientDisconnect } else { // 客户端要求非流式,收集流式响应后转换返回 - streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel) + streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel, account.ID) if err != nil { logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_collect_error error=%v", prefix, err) return nil, err @@ -1618,6 +1692,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, firstTokenMs = streamRes.firstTokenMs } + // Claude Max cache billing: 同步 ForwardResult.Usage 与客户端响应体一致 + applyClaudeMaxCacheBillingPolicyToUsage(usage, parsedRequestFromGinContext(c), claudeMaxGroupFromGinContext(c), originalModel, account.ID) + return &ForwardResult{ RequestID: requestID, Usage: *usage, @@ -2090,6 +2167,112 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co } } + // Gemini 原生请求中的 thoughtSignature 可能来自旧上下文/旧账号,触发上游严格校验后返回 + // "Corrupted thought signature."。检测到此类 400 时,将 thoughtSignature 清理为 dummy 值后重试一次。 + signatureCheckBody := respBody + if unwrapped, unwrapErr := s.unwrapV1InternalResponse(respBody); unwrapErr == nil && len(unwrapped) > 0 { + signatureCheckBody = unwrapped + } + if resp.StatusCode == http.StatusBadRequest && + s.settingService != nil && + s.settingService.IsSignatureRectifierEnabled(ctx) && + isSignatureRelatedError(signatureCheckBody) && + bytes.Contains(injectedBody, []byte(`"thoughtSignature"`)) { + upstreamMsg := 
sanitizeUpstreamErrorMessage(strings.TrimSpace(extractAntigravityErrorMessage(signatureCheckBody))) + upstreamDetail := s.getUpstreamErrorDetail(signatureCheckBody) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "signature_error", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + + logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: detected signature-related 400, retrying with cleaned thought signatures", account.ID) + + cleanedInjectedBody := CleanGeminiNativeThoughtSignatures(injectedBody) + retryWrappedBody, wrapErr := s.wrapV1InternalRequest(projectID, mappedModel, cleanedInjectedBody) + if wrapErr == nil { + retryResult, retryErr := s.antigravityRetryLoop(antigravityRetryLoopParams{ + ctx: ctx, + prefix: prefix, + account: account, + proxyURL: proxyURL, + accessToken: accessToken, + action: upstreamAction, + body: retryWrappedBody, + c: c, + httpUpstream: s.httpUpstream, + settingService: s.settingService, + accountRepo: s.accountRepo, + handleError: s.handleUpstreamError, + requestedModel: originalModel, + isStickySession: isStickySession, + groupID: 0, + sessionHash: "", + }) + if retryErr == nil { + retryResp := retryResult.resp + if retryResp.StatusCode < 400 { + resp = retryResp + } else { + retryRespBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20)) + _ = retryResp.Body.Close() + retryOpsBody := retryRespBody + if retryUnwrapped, unwrapErr := s.unwrapV1InternalResponse(retryRespBody); unwrapErr == nil && len(retryUnwrapped) > 0 { + retryOpsBody = retryUnwrapped + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: retryResp.StatusCode, + UpstreamRequestID: retryResp.Header.Get("x-request-id"), + Kind: 
"signature_retry", + Message: sanitizeUpstreamErrorMessage(strings.TrimSpace(extractAntigravityErrorMessage(retryOpsBody))), + Detail: s.getUpstreamErrorDetail(retryOpsBody), + }) + respBody = retryRespBody + resp = &http.Response{ + StatusCode: retryResp.StatusCode, + Header: retryResp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(retryRespBody)), + } + contentType = resp.Header.Get("Content-Type") + } + } else { + if switchErr, ok := IsAntigravityAccountSwitchError(retryErr); ok { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: http.StatusServiceUnavailable, + Kind: "failover", + Message: sanitizeUpstreamErrorMessage(retryErr.Error()), + }) + return nil, &UpstreamFailoverError{ + StatusCode: http.StatusServiceUnavailable, + ForceCacheBilling: switchErr.IsStickySession, + } + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "signature_retry_request_error", + Message: sanitizeUpstreamErrorMessage(retryErr.Error()), + }) + logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: signature retry request failed: %v", account.ID, retryErr) + } + } else { + logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: signature retry wrap failed: %v", account.ID, wrapErr) + } + } + // fallback 成功:继续按正常响应处理 if resp.StatusCode < 400 { goto handleSuccess @@ -3415,7 +3598,7 @@ func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int, // handleClaudeStreamToNonStreaming 收集上游流式响应,转换为 Claude 非流式格式返回 // 用于处理客户端非流式请求但上游只支持流式的情况 -func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*antigravityStreamResult, error) { +func (s *AntigravityGatewayService) 
handleClaudeStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string, accountID int64) (*antigravityStreamResult, error) { scanner := bufio.NewScanner(resp.Body) maxLineSize := defaultMaxLineSize if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 { @@ -3573,6 +3756,9 @@ returnResponse: return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response") } + // Claude Max cache billing simulation (non-streaming) + claudeResp = applyClaudeMaxNonStreamingRewrite(c, claudeResp, agUsage, originalModel, accountID) + c.Data(http.StatusOK, "application/json", claudeResp) // 转换为 service.ClaudeUsage @@ -3587,7 +3773,7 @@ returnResponse: } // handleClaudeStreamingResponse 处理 Claude 流式响应(Gemini SSE → Claude SSE 转换) -func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*antigravityStreamResult, error) { +func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string, accountID int64) (*antigravityStreamResult, error) { c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") c.Header("Connection", "keep-alive") @@ -3600,6 +3786,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context } processor := antigravity.NewStreamingProcessor(originalModel) + setupClaudeMaxStreamingHook(c, processor, originalModel, accountID) + var firstTokenMs *int // 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM scanner := bufio.NewScanner(resp.Body) @@ -3696,6 +3884,15 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context finalEvents, agUsage := processor.Finish() if len(finalEvents) > 0 { cw.Write(finalEvents) + } else if !processor.MessageStartSent() && !cw.Disconnected() { + // 整个流未收到任何可解析的上游数据(全部 SSE 行均无法被 JSON 解析), + // 触发 failover 
在同账号重试,避免向客户端发出缺少 message_start 的残缺流 + logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Claude-Stream] empty stream response (no valid events parsed), triggering failover") + return nil, &UpstreamFailoverError{ + StatusCode: http.StatusBadGateway, + ResponseBody: []byte(`{"error":"empty stream response from upstream"}`), + RetryableOnSameAccount: true, + } } return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs, clientDisconnect: cw.Disconnected()}, nil } diff --git a/backend/internal/service/antigravity_gateway_service_test.go b/backend/internal/service/antigravity_gateway_service_test.go index 84b65adc..b2e2fc38 100644 --- a/backend/internal/service/antigravity_gateway_service_test.go +++ b/backend/internal/service/antigravity_gateway_service_test.go @@ -134,6 +134,47 @@ func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int, return s.resp, s.err } +type queuedHTTPUpstreamStub struct { + responses []*http.Response + errors []error + requestBodies [][]byte + callCount int + onCall func(*http.Request, *queuedHTTPUpstreamStub) +} + +func (s *queuedHTTPUpstreamStub) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) { + if req != nil && req.Body != nil { + body, _ := io.ReadAll(req.Body) + s.requestBodies = append(s.requestBodies, body) + req.Body = io.NopCloser(bytes.NewReader(body)) + } else { + s.requestBodies = append(s.requestBodies, nil) + } + + idx := s.callCount + s.callCount++ + if s.onCall != nil { + s.onCall(req, s) + } + + var resp *http.Response + if idx < len(s.responses) { + resp = s.responses[idx] + } + var err error + if idx < len(s.errors) { + err = s.errors[idx] + } + if resp == nil && err == nil { + return nil, errors.New("unexpected upstream call") + } + return resp, err +} + +func (s *queuedHTTPUpstreamStub) DoWithTLS(req *http.Request, proxyURL string, accountID int64, concurrency int, _ bool) (*http.Response, error) { + return s.Do(req, proxyURL, 
accountID, concurrency) +} + type antigravitySettingRepoStub struct{} func (s *antigravitySettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) { @@ -556,6 +597,177 @@ func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing require.Equal(t, mappedModel, result.Model) } +func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignature(t *testing.T) { + gin.SetMode(gin.TestMode) + writer := httptest.NewRecorder() + c, _ := gin.CreateTestContext(writer) + + body, err := json.Marshal(map[string]any{ + "contents": []map[string]any{ + {"role": "user", "parts": []map[string]any{{"text": "hello"}}}, + {"role": "model", "parts": []map[string]any{{"text": "thinking", "thought": true, "thoughtSignature": "sig_bad_1"}}}, + {"role": "model", "parts": []map[string]any{{"functionCall": map[string]any{"name": "toolA", "args": map[string]any{"x": 1}}, "thoughtSignature": "sig_bad_2"}}}, + }, + }) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/antigravity/v1beta/models/gemini-3.1-pro-preview:streamGenerateContent", bytes.NewReader(body)) + c.Request = req + + firstRespBody := []byte(`{"response":{"error":{"code":400,"message":"Corrupted thought signature.","status":"INVALID_ARGUMENT"}}}`) + secondRespBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n") + + upstream := &queuedHTTPUpstreamStub{ + responses: []*http.Response{ + { + StatusCode: http.StatusBadRequest, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "X-Request-Id": []string{"req-sig-1"}, + }, + Body: io.NopCloser(bytes.NewReader(firstRespBody)), + }, + { + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + "X-Request-Id": []string{"req-sig-2"}, + }, + Body: io.NopCloser(bytes.NewReader(secondRespBody)), + }, + }, 
+ } + + svc := &AntigravityGatewayService{ + settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}), + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: upstream, + } + + const originalModel = "gemini-3.1-pro-preview" + const mappedModel = "gemini-3.1-pro-high" + account := &Account{ + ID: 7, + Name: "acc-gemini-signature", + Platform: PlatformAntigravity, + Type: AccountTypeOAuth, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "token", + "model_mapping": map[string]any{ + originalModel: mappedModel, + }, + }, + } + + result, err := svc.ForwardGemini(context.Background(), c, account, originalModel, "streamGenerateContent", true, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, mappedModel, result.Model) + require.Len(t, upstream.requestBodies, 2, "signature error should trigger exactly one retry") + + firstReq := string(upstream.requestBodies[0]) + secondReq := string(upstream.requestBodies[1]) + require.Contains(t, firstReq, `"thoughtSignature":"sig_bad_1"`) + require.Contains(t, firstReq, `"thoughtSignature":"sig_bad_2"`) + require.Contains(t, secondReq, `"thoughtSignature":"skip_thought_signature_validator"`) + require.NotContains(t, secondReq, `"thoughtSignature":"sig_bad_1"`) + require.NotContains(t, secondReq, `"thoughtSignature":"sig_bad_2"`) + + raw, ok := c.Get(OpsUpstreamErrorsKey) + require.True(t, ok) + events, ok := raw.([]*OpsUpstreamErrorEvent) + require.True(t, ok) + require.NotEmpty(t, events) + require.Equal(t, "signature_error", events[0].Kind) +} + +func TestAntigravityGatewayService_ForwardGemini_SignatureRetryPropagatesFailover(t *testing.T) { + gin.SetMode(gin.TestMode) + writer := httptest.NewRecorder() + c, _ := gin.CreateTestContext(writer) + + body, err := json.Marshal(map[string]any{ + "contents": []map[string]any{ + {"role": "user", "parts": 
[]map[string]any{{"text": "hello"}}}, + {"role": "model", "parts": []map[string]any{{"text": "thinking", "thought": true, "thoughtSignature": "sig_bad_1"}}}, + }, + }) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/antigravity/v1beta/models/gemini-3.1-pro-preview:streamGenerateContent", bytes.NewReader(body)) + c.Request = req + + firstRespBody := []byte(`{"response":{"error":{"code":400,"message":"Corrupted thought signature.","status":"INVALID_ARGUMENT"}}}`) + + const originalModel = "gemini-3.1-pro-preview" + const mappedModel = "gemini-3.1-pro-high" + account := &Account{ + ID: 8, + Name: "acc-gemini-signature-failover", + Platform: PlatformAntigravity, + Type: AccountTypeOAuth, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "token", + "model_mapping": map[string]any{ + originalModel: mappedModel, + }, + }, + } + + upstream := &queuedHTTPUpstreamStub{ + responses: []*http.Response{ + { + StatusCode: http.StatusBadRequest, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "X-Request-Id": []string{"req-sig-failover-1"}, + }, + Body: io.NopCloser(bytes.NewReader(firstRespBody)), + }, + }, + onCall: func(_ *http.Request, stub *queuedHTTPUpstreamStub) { + if stub.callCount != 1 { + return + } + futureResetAt := time.Now().Add(30 * time.Second).Format(time.RFC3339) + account.Extra = map[string]any{ + modelRateLimitsKey: map[string]any{ + mappedModel: map[string]any{ + "rate_limit_reset_at": futureResetAt, + }, + }, + } + }, + } + + svc := &AntigravityGatewayService{ + settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}), + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: upstream, + } + + result, err := svc.ForwardGemini(context.Background(), c, account, originalModel, "streamGenerateContent", true, body, true) + require.Nil(t, result) + + var failoverErr 
*UpstreamFailoverError + require.ErrorAs(t, err, &failoverErr, "signature retry should propagate failover instead of falling back to the original 400") + require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode) + require.True(t, failoverErr.ForceCacheBilling) + require.Len(t, upstream.requestBodies, 1, "retry should stop at preflight failover and not issue a second upstream request") + + raw, ok := c.Get(OpsUpstreamErrorsKey) + require.True(t, ok) + events, ok := raw.([]*OpsUpstreamErrorEvent) + require.True(t, ok) + require.Len(t, events, 2) + require.Equal(t, "signature_error", events[0].Kind) + require.Equal(t, "failover", events[1].Kind) +} + // TestStreamUpstreamResponse_UsageAndFirstToken // 验证:usage 字段可被累积/覆盖更新,并且能记录首 token 时间 func TestStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) { @@ -710,7 +922,7 @@ func TestHandleClaudeStreamingResponse_NormalComplete(t *testing.T) { fmt.Fprintln(pw, "") }() - result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5") + result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0) _ = pr.Close() require.NoError(t, err) @@ -787,7 +999,7 @@ func TestHandleClaudeStreamingResponse_ThoughtsTokenCount(t *testing.T) { fmt.Fprintln(pw, "") }() - result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "gemini-2.5-pro") + result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "gemini-2.5-pro", 0) _ = pr.Close() require.NoError(t, err) @@ -990,7 +1202,7 @@ func TestHandleClaudeStreamingResponse_ClientDisconnect(t *testing.T) { fmt.Fprintln(pw, "") }() - result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5") + result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0) _ = pr.Close() require.NoError(t, err) @@ -998,6 +1210,46 @@ func TestHandleClaudeStreamingResponse_ClientDisconnect(t *testing.T) { require.True(t, result.clientDisconnect) } +// 
TestHandleClaudeStreamingResponse_EmptyStream +// 验证:上游只返回无法解析的 SSE 行时,触发 UpstreamFailoverError 而不是向客户端发出残缺流 +func TestHandleClaudeStreamingResponse_EmptyStream(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + defer func() { _ = pw.Close() }() + // 所有行均为无法 JSON 解析的内容,ProcessLine 全部返回 nil + fmt.Fprintln(pw, "data: not-valid-json") + fmt.Fprintln(pw, "") + fmt.Fprintln(pw, "data: also-invalid") + fmt.Fprintln(pw, "") + }() + + _, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0) + _ = pr.Close() + + // 应当返回 UpstreamFailoverError 而非 nil,以便上层触发 failover + require.Error(t, err) + var failoverErr *UpstreamFailoverError + require.ErrorAs(t, err, &failoverErr) + require.True(t, failoverErr.RetryableOnSameAccount) + + // 客户端不应收到任何 SSE 事件(既无 message_start 也无 message_stop) + body := rec.Body.String() + require.NotContains(t, body, "event: message_start") + require.NotContains(t, body, "event: message_stop") + require.NotContains(t, body, "event: message_delta") +} + // TestHandleClaudeStreamingResponse_ContextCanceled // 验证:context 取消时不注入错误事件 func TestHandleClaudeStreamingResponse_ContextCanceled(t *testing.T) { @@ -1014,7 +1266,7 @@ func TestHandleClaudeStreamingResponse_ContextCanceled(t *testing.T) { resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}} - result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5") + result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0) require.NoError(t, err) require.NotNil(t, result) diff --git 
a/backend/internal/service/api_key.go b/backend/internal/service/api_key.go index 4c565495..eb9f2b15 100644 --- a/backend/internal/service/api_key.go +++ b/backend/internal/service/api_key.go @@ -14,6 +14,18 @@ const ( StatusAPIKeyExpired = "expired" ) +// Rate limit window durations +const ( + RateLimitWindow5h = 5 * time.Hour + RateLimitWindow1d = 24 * time.Hour + RateLimitWindow7d = 7 * 24 * time.Hour +) + +// IsWindowExpired returns true if the window starting at windowStart has exceeded the given duration. +func IsWindowExpired(windowStart *time.Time, duration time.Duration) bool { + return windowStart != nil && time.Since(*windowStart) >= duration +} + type APIKey struct { ID int64 UserID int64 @@ -98,6 +110,30 @@ func (k *APIKey) GetDaysUntilExpiry() int { return int(duration.Hours() / 24) } +// EffectiveUsage5h returns the 5h window usage, or 0 if the window has expired. +func (k *APIKey) EffectiveUsage5h() float64 { + if IsWindowExpired(k.Window5hStart, RateLimitWindow5h) { + return 0 + } + return k.Usage5h +} + +// EffectiveUsage1d returns the 1d window usage, or 0 if the window has expired. +func (k *APIKey) EffectiveUsage1d() float64 { + if IsWindowExpired(k.Window1dStart, RateLimitWindow1d) { + return 0 + } + return k.Usage1d +} + +// EffectiveUsage7d returns the 7d window usage, or 0 if the window has expired. +func (k *APIKey) EffectiveUsage7d() float64 { + if IsWindowExpired(k.Window7dStart, RateLimitWindow7d) { + return 0 + } + return k.Usage7d +} + // APIKeyListFilters holds optional filtering parameters for listing API keys. 
type APIKeyListFilters struct { Search string diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index 83933f42..258b842b 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -59,12 +59,17 @@ type APIKeyAuthGroupSnapshot struct { // Model routing is used by gateway account selection, so it must be part of auth cache snapshot. // Only anthropic groups use these fields; others may leave them empty. - ModelRouting map[string][]int64 `json:"model_routing,omitempty"` - ModelRoutingEnabled bool `json:"model_routing_enabled"` - MCPXMLInject bool `json:"mcp_xml_inject"` + ModelRouting map[string][]int64 `json:"model_routing,omitempty"` + ModelRoutingEnabled bool `json:"model_routing_enabled"` + MCPXMLInject bool `json:"mcp_xml_inject"` + SimulateClaudeMaxEnabled bool `json:"simulate_claude_max_enabled"` // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes []string `json:"supported_model_scopes,omitempty"` + + // OpenAI Messages 调度配置(仅 openai 平台使用) + AllowMessagesDispatch bool `json:"allow_messages_dispatch"` + DefaultMappedModel string `json:"default_mapped_model,omitempty"` } // APIKeyAuthCacheEntry 缓存条目,支持负缓存 diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index 0ca694af..d874ccf2 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -244,7 +244,10 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { ModelRouting: apiKey.Group.ModelRouting, ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled, MCPXMLInject: apiKey.Group.MCPXMLInject, + SimulateClaudeMaxEnabled: apiKey.Group.SimulateClaudeMaxEnabled, SupportedModelScopes: apiKey.Group.SupportedModelScopes, + AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch, + DefaultMappedModel: apiKey.Group.DefaultMappedModel, } } return 
snapshot @@ -301,7 +304,10 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho ModelRouting: snapshot.Group.ModelRouting, ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled, MCPXMLInject: snapshot.Group.MCPXMLInject, + SimulateClaudeMaxEnabled: snapshot.Group.SimulateClaudeMaxEnabled, SupportedModelScopes: snapshot.Group.SupportedModelScopes, + AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch, + DefaultMappedModel: snapshot.Group.DefaultMappedModel, } } s.compileAPIKeyIPRules(apiKey) diff --git a/backend/internal/service/api_key_rate_limit_test.go b/backend/internal/service/api_key_rate_limit_test.go new file mode 100644 index 00000000..7fadf270 --- /dev/null +++ b/backend/internal/service/api_key_rate_limit_test.go @@ -0,0 +1,245 @@ +package service + +import ( + "testing" + "time" +) + +func TestIsWindowExpired(t *testing.T) { + now := time.Now() + + tests := []struct { + name string + start *time.Time + duration time.Duration + want bool + }{ + { + name: "nil window start", + start: nil, + duration: RateLimitWindow5h, + want: false, + }, + { + name: "active window (started 1h ago, 5h window)", + start: rateLimitTimePtr(now.Add(-1 * time.Hour)), + duration: RateLimitWindow5h, + want: false, + }, + { + name: "expired window (started 6h ago, 5h window)", + start: rateLimitTimePtr(now.Add(-6 * time.Hour)), + duration: RateLimitWindow5h, + want: true, + }, + { + name: "exactly at boundary (started 5h ago, 5h window)", + start: rateLimitTimePtr(now.Add(-5 * time.Hour)), + duration: RateLimitWindow5h, + want: true, + }, + { + name: "active 1d window (started 12h ago)", + start: rateLimitTimePtr(now.Add(-12 * time.Hour)), + duration: RateLimitWindow1d, + want: false, + }, + { + name: "expired 1d window (started 25h ago)", + start: rateLimitTimePtr(now.Add(-25 * time.Hour)), + duration: RateLimitWindow1d, + want: true, + }, + { + name: "active 7d window (started 3d ago)", + start: rateLimitTimePtr(now.Add(-3 * 24 * 
time.Hour)), + duration: RateLimitWindow7d, + want: false, + }, + { + name: "expired 7d window (started 8d ago)", + start: rateLimitTimePtr(now.Add(-8 * 24 * time.Hour)), + duration: RateLimitWindow7d, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsWindowExpired(tt.start, tt.duration) + if got != tt.want { + t.Errorf("IsWindowExpired() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAPIKey_EffectiveUsage(t *testing.T) { + now := time.Now() + + tests := []struct { + name string + key APIKey + want5h float64 + want1d float64 + want7d float64 + }{ + { + name: "all windows active", + key: APIKey{ + Usage5h: 5.0, + Usage1d: 10.0, + Usage7d: 50.0, + Window5hStart: rateLimitTimePtr(now.Add(-1 * time.Hour)), + Window1dStart: rateLimitTimePtr(now.Add(-12 * time.Hour)), + Window7dStart: rateLimitTimePtr(now.Add(-3 * 24 * time.Hour)), + }, + want5h: 5.0, + want1d: 10.0, + want7d: 50.0, + }, + { + name: "all windows expired", + key: APIKey{ + Usage5h: 5.0, + Usage1d: 10.0, + Usage7d: 50.0, + Window5hStart: rateLimitTimePtr(now.Add(-6 * time.Hour)), + Window1dStart: rateLimitTimePtr(now.Add(-25 * time.Hour)), + Window7dStart: rateLimitTimePtr(now.Add(-8 * 24 * time.Hour)), + }, + want5h: 0, + want1d: 0, + want7d: 0, + }, + { + name: "nil window starts return raw usage", + key: APIKey{ + Usage5h: 5.0, + Usage1d: 10.0, + Usage7d: 50.0, + Window5hStart: nil, + Window1dStart: nil, + Window7dStart: nil, + }, + want5h: 5.0, + want1d: 10.0, + want7d: 50.0, + }, + { + name: "mixed: 5h expired, 1d active, 7d nil", + key: APIKey{ + Usage5h: 5.0, + Usage1d: 10.0, + Usage7d: 50.0, + Window5hStart: rateLimitTimePtr(now.Add(-6 * time.Hour)), + Window1dStart: rateLimitTimePtr(now.Add(-12 * time.Hour)), + Window7dStart: nil, + }, + want5h: 0, + want1d: 10.0, + want7d: 50.0, + }, + { + name: "zero usage with active windows", + key: APIKey{ + Usage5h: 0, + Usage1d: 0, + Usage7d: 0, + Window5hStart: rateLimitTimePtr(now.Add(-1 * 
time.Hour)), + Window1dStart: rateLimitTimePtr(now.Add(-1 * time.Hour)), + Window7dStart: rateLimitTimePtr(now.Add(-1 * time.Hour)), + }, + want5h: 0, + want1d: 0, + want7d: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.key.EffectiveUsage5h(); got != tt.want5h { + t.Errorf("EffectiveUsage5h() = %v, want %v", got, tt.want5h) + } + if got := tt.key.EffectiveUsage1d(); got != tt.want1d { + t.Errorf("EffectiveUsage1d() = %v, want %v", got, tt.want1d) + } + if got := tt.key.EffectiveUsage7d(); got != tt.want7d { + t.Errorf("EffectiveUsage7d() = %v, want %v", got, tt.want7d) + } + }) + } +} + +func TestAPIKeyRateLimitData_EffectiveUsage(t *testing.T) { + now := time.Now() + + tests := []struct { + name string + data APIKeyRateLimitData + want5h float64 + want1d float64 + want7d float64 + }{ + { + name: "all windows active", + data: APIKeyRateLimitData{ + Usage5h: 3.0, + Usage1d: 8.0, + Usage7d: 40.0, + Window5hStart: rateLimitTimePtr(now.Add(-2 * time.Hour)), + Window1dStart: rateLimitTimePtr(now.Add(-10 * time.Hour)), + Window7dStart: rateLimitTimePtr(now.Add(-2 * 24 * time.Hour)), + }, + want5h: 3.0, + want1d: 8.0, + want7d: 40.0, + }, + { + name: "all windows expired", + data: APIKeyRateLimitData{ + Usage5h: 3.0, + Usage1d: 8.0, + Usage7d: 40.0, + Window5hStart: rateLimitTimePtr(now.Add(-10 * time.Hour)), + Window1dStart: rateLimitTimePtr(now.Add(-48 * time.Hour)), + Window7dStart: rateLimitTimePtr(now.Add(-10 * 24 * time.Hour)), + }, + want5h: 0, + want1d: 0, + want7d: 0, + }, + { + name: "nil window starts return raw usage", + data: APIKeyRateLimitData{ + Usage5h: 3.0, + Usage1d: 8.0, + Usage7d: 40.0, + Window5hStart: nil, + Window1dStart: nil, + Window7dStart: nil, + }, + want5h: 3.0, + want1d: 8.0, + want7d: 40.0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.data.EffectiveUsage5h(); got != tt.want5h { + t.Errorf("EffectiveUsage5h() = %v, want %v", got, tt.want5h) + } 
+ if got := tt.data.EffectiveUsage1d(); got != tt.want1d { + t.Errorf("EffectiveUsage1d() = %v, want %v", got, tt.want1d) + } + if got := tt.data.EffectiveUsage7d(); got != tt.want7d { + t.Errorf("EffectiveUsage7d() = %v, want %v", got, tt.want7d) + } + }) + } +} + +func rateLimitTimePtr(t time.Time) *time.Time { + return &t +} diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index b32a1d67..18e9ff7a 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "fmt" "strconv" + "strings" "sync" "time" @@ -86,6 +87,39 @@ type APIKeyRateLimitData struct { Window7dStart *time.Time } +// EffectiveUsage5h returns the 5h window usage, or 0 if the window has expired. +func (d *APIKeyRateLimitData) EffectiveUsage5h() float64 { + if IsWindowExpired(d.Window5hStart, RateLimitWindow5h) { + return 0 + } + return d.Usage5h +} + +// EffectiveUsage1d returns the 1d window usage, or 0 if the window has expired. +func (d *APIKeyRateLimitData) EffectiveUsage1d() float64 { + if IsWindowExpired(d.Window1dStart, RateLimitWindow1d) { + return 0 + } + return d.Usage1d +} + +// EffectiveUsage7d returns the 7d window usage, or 0 if the window has expired. +func (d *APIKeyRateLimitData) EffectiveUsage7d() float64 { + if IsWindowExpired(d.Window7dStart, RateLimitWindow7d) { + return 0 + } + return d.Usage7d +} + +// APIKeyQuotaUsageState captures the latest quota fields after an atomic quota update. +// It is intentionally small so repositories can return it from a single SQL statement. 
+type APIKeyQuotaUsageState struct { + QuotaUsed float64 + Quota float64 + Key string + Status string +} + // APIKeyCache defines cache operations for API key service type APIKeyCache interface { GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) @@ -793,6 +827,21 @@ func (s *APIKeyService) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cos return nil } + type quotaStateReader interface { + IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*APIKeyQuotaUsageState, error) + } + + if repo, ok := s.apiKeyRepo.(quotaStateReader); ok { + state, err := repo.IncrementQuotaUsedAndGetState(ctx, apiKeyID, cost) + if err != nil { + return fmt.Errorf("increment quota used: %w", err) + } + if state != nil && state.Status == StatusAPIKeyQuotaExhausted && strings.TrimSpace(state.Key) != "" { + s.InvalidateAuthCacheByKey(ctx, state.Key) + } + return nil + } + // Use repository to atomically increment quota_used newQuotaUsed, err := s.apiKeyRepo.IncrementQuotaUsed(ctx, apiKeyID, cost) if err != nil { diff --git a/backend/internal/service/api_key_service_quota_test.go b/backend/internal/service/api_key_service_quota_test.go new file mode 100644 index 00000000..2e2f6f78 --- /dev/null +++ b/backend/internal/service/api_key_service_quota_test.go @@ -0,0 +1,170 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +type quotaStateRepoStub struct { + quotaBaseAPIKeyRepoStub + stateCalls int + state *APIKeyQuotaUsageState + stateErr error +} + +func (s *quotaStateRepoStub) IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*APIKeyQuotaUsageState, error) { + s.stateCalls++ + if s.stateErr != nil { + return nil, s.stateErr + } + if s.state == nil { + return nil, nil + } + out := *s.state + return &out, nil +} + +type quotaStateCacheStub struct { + deleteAuthKeys []string +} + +func 
(s *quotaStateCacheStub) GetCreateAttemptCount(context.Context, int64) (int, error) { + return 0, nil +} + +func (s *quotaStateCacheStub) IncrementCreateAttemptCount(context.Context, int64) error { + return nil +} + +func (s *quotaStateCacheStub) DeleteCreateAttemptCount(context.Context, int64) error { + return nil +} + +func (s *quotaStateCacheStub) IncrementDailyUsage(context.Context, string) error { + return nil +} + +func (s *quotaStateCacheStub) SetDailyUsageExpiry(context.Context, string, time.Duration) error { + return nil +} + +func (s *quotaStateCacheStub) GetAuthCache(context.Context, string) (*APIKeyAuthCacheEntry, error) { + return nil, nil +} + +func (s *quotaStateCacheStub) SetAuthCache(context.Context, string, *APIKeyAuthCacheEntry, time.Duration) error { + return nil +} + +func (s *quotaStateCacheStub) DeleteAuthCache(_ context.Context, key string) error { + s.deleteAuthKeys = append(s.deleteAuthKeys, key) + return nil +} + +func (s *quotaStateCacheStub) PublishAuthCacheInvalidation(context.Context, string) error { + return nil +} + +func (s *quotaStateCacheStub) SubscribeAuthCacheInvalidation(context.Context, func(string)) error { + return nil +} + +type quotaBaseAPIKeyRepoStub struct { + getByIDCalls int +} + +func (s *quotaBaseAPIKeyRepoStub) Create(context.Context, *APIKey) error { + panic("unexpected Create call") +} +func (s *quotaBaseAPIKeyRepoStub) GetByID(context.Context, int64) (*APIKey, error) { + s.getByIDCalls++ + return nil, nil +} +func (s *quotaBaseAPIKeyRepoStub) GetKeyAndOwnerID(context.Context, int64) (string, int64, error) { + panic("unexpected GetKeyAndOwnerID call") +} +func (s *quotaBaseAPIKeyRepoStub) GetByKey(context.Context, string) (*APIKey, error) { + panic("unexpected GetByKey call") +} +func (s *quotaBaseAPIKeyRepoStub) GetByKeyForAuth(context.Context, string) (*APIKey, error) { + panic("unexpected GetByKeyForAuth call") +} +func (s *quotaBaseAPIKeyRepoStub) Update(context.Context, *APIKey) error { + panic("unexpected 
Update call") +} +func (s *quotaBaseAPIKeyRepoStub) Delete(context.Context, int64) error { + panic("unexpected Delete call") +} +func (s *quotaBaseAPIKeyRepoStub) ListByUserID(context.Context, int64, pagination.PaginationParams, APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) { + panic("unexpected ListByUserID call") +} +func (s *quotaBaseAPIKeyRepoStub) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) { + panic("unexpected VerifyOwnership call") +} +func (s *quotaBaseAPIKeyRepoStub) CountByUserID(context.Context, int64) (int64, error) { + panic("unexpected CountByUserID call") +} +func (s *quotaBaseAPIKeyRepoStub) ExistsByKey(context.Context, string) (bool, error) { + panic("unexpected ExistsByKey call") +} +func (s *quotaBaseAPIKeyRepoStub) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) { + panic("unexpected ListByGroupID call") +} +func (s *quotaBaseAPIKeyRepoStub) SearchAPIKeys(context.Context, int64, string, int) ([]APIKey, error) { + panic("unexpected SearchAPIKeys call") +} +func (s *quotaBaseAPIKeyRepoStub) ClearGroupIDByGroupID(context.Context, int64) (int64, error) { + panic("unexpected ClearGroupIDByGroupID call") +} +func (s *quotaBaseAPIKeyRepoStub) CountByGroupID(context.Context, int64) (int64, error) { + panic("unexpected CountByGroupID call") +} +func (s *quotaBaseAPIKeyRepoStub) ListKeysByUserID(context.Context, int64) ([]string, error) { + panic("unexpected ListKeysByUserID call") +} +func (s *quotaBaseAPIKeyRepoStub) ListKeysByGroupID(context.Context, int64) ([]string, error) { + panic("unexpected ListKeysByGroupID call") +} +func (s *quotaBaseAPIKeyRepoStub) IncrementQuotaUsed(context.Context, int64, float64) (float64, error) { + panic("unexpected IncrementQuotaUsed call") +} +func (s *quotaBaseAPIKeyRepoStub) UpdateLastUsed(context.Context, int64, time.Time) error { + panic("unexpected UpdateLastUsed call") +} +func (s 
*quotaBaseAPIKeyRepoStub) IncrementRateLimitUsage(context.Context, int64, float64) error { + panic("unexpected IncrementRateLimitUsage call") +} +func (s *quotaBaseAPIKeyRepoStub) ResetRateLimitWindows(context.Context, int64) error { + panic("unexpected ResetRateLimitWindows call") +} +func (s *quotaBaseAPIKeyRepoStub) GetRateLimitData(context.Context, int64) (*APIKeyRateLimitData, error) { + panic("unexpected GetRateLimitData call") +} + +func TestAPIKeyService_UpdateQuotaUsed_UsesAtomicStatePath(t *testing.T) { + repo := &quotaStateRepoStub{ + state: &APIKeyQuotaUsageState{ + QuotaUsed: 12, + Quota: 10, + Key: "sk-test-quota", + Status: StatusAPIKeyQuotaExhausted, + }, + } + cache := &quotaStateCacheStub{} + svc := &APIKeyService{ + apiKeyRepo: repo, + cache: cache, + } + + err := svc.UpdateQuotaUsed(context.Background(), 101, 2) + require.NoError(t, err) + require.Equal(t, 1, repo.stateCalls) + require.Equal(t, 0, repo.getByIDCalls, "fast path should not re-read API key by id") + require.Equal(t, []string{svc.authCacheKey("sk-test-quota")}, cache.deleteAuthKeys) +} diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index fe3a0f25..28607e9f 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -8,9 +8,11 @@ import ( "errors" "fmt" "net/mail" + "strconv" "strings" "time" + dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" @@ -20,23 +22,25 @@ import ( ) var ( - ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password") - ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active") - ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists") - ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved") - ErrInvalidToken = 
infraerrors.Unauthorized("INVALID_TOKEN", "invalid token") - ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired") - ErrAccessTokenExpired = infraerrors.Unauthorized("ACCESS_TOKEN_EXPIRED", "access token has expired") - ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large") - ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked") - ErrRefreshTokenInvalid = infraerrors.Unauthorized("REFRESH_TOKEN_INVALID", "invalid refresh token") - ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired") - ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused") - ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required") - ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") - ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable") - ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required") - ErrInvitationCodeInvalid = infraerrors.BadRequest("INVITATION_CODE_INVALID", "invalid or used invitation code") + ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password") + ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active") + ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists") + ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved") + ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token") + ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired") + ErrAccessTokenExpired = infraerrors.Unauthorized("ACCESS_TOKEN_EXPIRED", "access token has expired") + ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large") + ErrTokenRevoked 
= infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked") + ErrRefreshTokenInvalid = infraerrors.Unauthorized("REFRESH_TOKEN_INVALID", "invalid refresh token") + ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired") + ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused") + ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required") + ErrEmailSuffixNotAllowed = infraerrors.BadRequest("EMAIL_SUFFIX_NOT_ALLOWED", "email suffix is not allowed") + ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") + ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable") + ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required") + ErrInvitationCodeInvalid = infraerrors.BadRequest("INVITATION_CODE_INVALID", "invalid or used invitation code") + ErrOAuthInvitationRequired = infraerrors.Forbidden("OAUTH_INVITATION_REQUIRED", "invitation code required to complete oauth registration") ) // maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。 @@ -56,6 +60,7 @@ type JWTClaims struct { // AuthService 认证服务 type AuthService struct { + entClient *dbent.Client userRepo UserRepository redeemRepo RedeemCodeRepository refreshTokenCache RefreshTokenCache @@ -74,6 +79,7 @@ type DefaultSubscriptionAssigner interface { // NewAuthService 创建认证服务实例 func NewAuthService( + entClient *dbent.Client, userRepo UserRepository, redeemRepo RedeemCodeRepository, refreshTokenCache RefreshTokenCache, @@ -86,6 +92,7 @@ func NewAuthService( defaultSubAssigner DefaultSubscriptionAssigner, ) *AuthService { return &AuthService{ + entClient: entClient, userRepo: userRepo, redeemRepo: redeemRepo, refreshTokenCache: refreshTokenCache, @@ -115,6 +122,9 @@ func (s *AuthService) RegisterWithVerification(ctx 
context.Context, email, passw if isReservedEmail(email) { return "", nil, ErrEmailReserved } + if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil { + return "", nil, err + } // 检查是否需要邀请码 var invitationRedeemCode *RedeemCode @@ -241,6 +251,9 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error { if isReservedEmail(email) { return ErrEmailReserved } + if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil { + return err + } // 检查邮箱是否已存在 existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) @@ -279,6 +292,9 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S if isReservedEmail(email) { return nil, ErrEmailReserved } + if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil { + return nil, err + } // 检查邮箱是否已存在 existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) @@ -512,9 +528,10 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username return token, user, nil } -// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair -// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token -func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username string) (*TokenPair, *User, error) { +// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair。 +// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token。 +// invitationCode 仅在邀请码注册模式下新用户注册时使用;已有账号登录时忽略。 +func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode string) (*TokenPair, *User, error) { // 检查 refreshTokenCache 是否可用 if s.refreshTokenCache == nil { return nil, nil, errors.New("refresh token cache not configured") @@ -541,6 +558,22 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema return nil, nil, ErrRegDisabled } + // 检查是否需要邀请码 + var invitationRedeemCode *RedeemCode + if s.settingService != nil && s.settingService.IsInvitationCodeEnabled(ctx) { + if 
invitationCode == "" { + return nil, nil, ErrOAuthInvitationRequired + } + redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode) + if err != nil { + return nil, nil, ErrInvitationCodeInvalid + } + if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused { + return nil, nil, ErrInvitationCodeInvalid + } + invitationRedeemCode = redeemCode + } + randomPassword, err := randomHexString(32) if err != nil { logger.LegacyPrintf("service.auth", "[Auth] Failed to generate random password for oauth signup: %v", err) @@ -568,20 +601,58 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema Status: StatusActive, } - if err := s.userRepo.Create(ctx, newUser); err != nil { - if errors.Is(err, ErrEmailExists) { - user, err = s.userRepo.GetByEmail(ctx, email) - if err != nil { - logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err) + if s.entClient != nil && invitationRedeemCode != nil { + tx, err := s.entClient.Tx(ctx) + if err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to begin transaction for oauth registration: %v", err) + return nil, nil, ErrServiceUnavailable + } + defer func() { _ = tx.Rollback() }() + txCtx := dbent.NewTxContext(ctx, tx) + + if err := s.userRepo.Create(txCtx, newUser); err != nil { + if errors.Is(err, ErrEmailExists) { + user, err = s.userRepo.GetByEmail(ctx, email) + if err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err) + return nil, nil, ErrServiceUnavailable + } + } else { + logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err) return nil, nil, ErrServiceUnavailable } } else { - logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err) - return nil, nil, ErrServiceUnavailable + if err := s.redeemRepo.Use(txCtx, invitationRedeemCode.ID, newUser.ID); err != nil { + return nil, nil, ErrInvitationCodeInvalid 
+ } + if err := tx.Commit(); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to commit oauth registration transaction: %v", err) + return nil, nil, ErrServiceUnavailable + } + user = newUser + s.assignDefaultSubscriptions(ctx, user.ID) } } else { - user = newUser - s.assignDefaultSubscriptions(ctx, user.ID) + if err := s.userRepo.Create(ctx, newUser); err != nil { + if errors.Is(err, ErrEmailExists) { + user, err = s.userRepo.GetByEmail(ctx, email) + if err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err) + return nil, nil, ErrServiceUnavailable + } + } else { + logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err) + return nil, nil, ErrServiceUnavailable + } + } else { + user = newUser + s.assignDefaultSubscriptions(ctx, user.ID) + if invitationRedeemCode != nil { + if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil { + return nil, nil, ErrInvitationCodeInvalid + } + } + } } } else { logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err) @@ -607,6 +678,63 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema return tokenPair, user, nil } +// pendingOAuthTokenTTL is the validity period for pending OAuth tokens. +const pendingOAuthTokenTTL = 10 * time.Minute + +// pendingOAuthPurpose is the purpose claim value for pending OAuth registration tokens. +const pendingOAuthPurpose = "pending_oauth_registration" + +type pendingOAuthClaims struct { + Email string `json:"email"` + Username string `json:"username"` + Purpose string `json:"purpose"` + jwt.RegisteredClaims +} + +// CreatePendingOAuthToken generates a short-lived JWT that carries the OAuth identity +// while waiting for the user to supply an invitation code. 
+func (s *AuthService) CreatePendingOAuthToken(email, username string) (string, error) { + now := time.Now() + claims := &pendingOAuthClaims{ + Email: email, + Username: username, + Purpose: pendingOAuthPurpose, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(now.Add(pendingOAuthTokenTTL)), + IssuedAt: jwt.NewNumericDate(now), + NotBefore: jwt.NewNumericDate(now), + }, + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return token.SignedString([]byte(s.cfg.JWT.Secret)) +} + +// VerifyPendingOAuthToken validates a pending OAuth token and returns the embedded identity. +// Returns ErrInvalidToken when the token is invalid or expired. +func (s *AuthService) VerifyPendingOAuthToken(tokenStr string) (email, username string, err error) { + if len(tokenStr) > maxTokenLength { + return "", "", ErrInvalidToken + } + parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + token, parseErr := parser.ParseWithClaims(tokenStr, &pendingOAuthClaims{}, func(t *jwt.Token) (any, error) { + if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) + } + return []byte(s.cfg.JWT.Secret), nil + }) + if parseErr != nil { + return "", "", ErrInvalidToken + } + claims, ok := token.Claims.(*pendingOAuthClaims) + if !ok || !token.Valid { + return "", "", ErrInvalidToken + } + if claims.Purpose != pendingOAuthPurpose { + return "", "", ErrInvalidToken + } + return claims.Email, claims.Username, nil +} + func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) { if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 { return @@ -624,6 +752,32 @@ func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int } } +func (s *AuthService) validateRegistrationEmailPolicy(ctx context.Context, email string) error { + if s.settingService == nil { + return nil + } + whitelist := 
s.settingService.GetRegistrationEmailSuffixWhitelist(ctx) + if !IsRegistrationEmailSuffixAllowed(email, whitelist) { + return buildEmailSuffixNotAllowedError(whitelist) + } + return nil +} + +func buildEmailSuffixNotAllowedError(whitelist []string) error { + if len(whitelist) == 0 { + return ErrEmailSuffixNotAllowed + } + + allowed := strings.Join(whitelist, ", ") + return infraerrors.BadRequest( + "EMAIL_SUFFIX_NOT_ALLOWED", + fmt.Sprintf("email suffix is not allowed, allowed suffixes: %s", allowed), + ).WithMetadata(map[string]string{ + "allowed_suffixes": strings.Join(whitelist, ","), + "allowed_suffix_count": strconv.Itoa(len(whitelist)), + }) +} + // ValidateToken 验证JWT token并返回用户声明 func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { // 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。 diff --git a/backend/internal/service/auth_service_pending_oauth_test.go b/backend/internal/service/auth_service_pending_oauth_test.go new file mode 100644 index 00000000..0472e06c --- /dev/null +++ b/backend/internal/service/auth_service_pending_oauth_test.go @@ -0,0 +1,146 @@ +//go:build unit + +package service + +import ( + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" +) + +func newAuthServiceForPendingOAuthTest() *AuthService { + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret-pending-oauth", + ExpireHour: 1, + }, + } + return NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil) +} + +// TestVerifyPendingOAuthToken_ValidToken 验证正常签发的 pending token 可以被成功解析。 +func TestVerifyPendingOAuthToken_ValidToken(t *testing.T) { + svc := newAuthServiceForPendingOAuthTest() + + token, err := svc.CreatePendingOAuthToken("user@example.com", "alice") + require.NoError(t, err) + require.NotEmpty(t, token) + + email, username, err := svc.VerifyPendingOAuthToken(token) + require.NoError(t, err) + require.Equal(t, "user@example.com", email) + 
require.Equal(t, "alice", username) +} + +// TestVerifyPendingOAuthToken_RegularJWTRejected 用普通 access token 尝试验证,应返回 ErrInvalidToken。 +func TestVerifyPendingOAuthToken_RegularJWTRejected(t *testing.T) { + svc := newAuthServiceForPendingOAuthTest() + + // 签发一个普通 access token(JWTClaims,无 Purpose 字段) + accessToken, err := svc.GenerateToken(&User{ + ID: 1, + Email: "user@example.com", + Role: RoleUser, + }) + require.NoError(t, err) + + _, _, err = svc.VerifyPendingOAuthToken(accessToken) + require.ErrorIs(t, err, ErrInvalidToken) +} + +// TestVerifyPendingOAuthToken_WrongPurpose 手动构造 purpose 字段不匹配的 JWT,应返回 ErrInvalidToken。 +func TestVerifyPendingOAuthToken_WrongPurpose(t *testing.T) { + svc := newAuthServiceForPendingOAuthTest() + + now := time.Now() + claims := &pendingOAuthClaims{ + Email: "user@example.com", + Username: "alice", + Purpose: "some_other_purpose", + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)), + IssuedAt: jwt.NewNumericDate(now), + NotBefore: jwt.NewNumericDate(now), + }, + } + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret)) + require.NoError(t, err) + + _, _, err = svc.VerifyPendingOAuthToken(tokenStr) + require.ErrorIs(t, err, ErrInvalidToken) +} + +// TestVerifyPendingOAuthToken_MissingPurpose 手动构造无 purpose 字段的 JWT(模拟旧 token),应返回 ErrInvalidToken。 +func TestVerifyPendingOAuthToken_MissingPurpose(t *testing.T) { + svc := newAuthServiceForPendingOAuthTest() + + now := time.Now() + claims := &pendingOAuthClaims{ + Email: "user@example.com", + Username: "alice", + Purpose: "", // 旧 token 无此字段,反序列化后为零值 + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)), + IssuedAt: jwt.NewNumericDate(now), + NotBefore: jwt.NewNumericDate(now), + }, + } + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret)) + require.NoError(t, 
err) + + _, _, err = svc.VerifyPendingOAuthToken(tokenStr) + require.ErrorIs(t, err, ErrInvalidToken) +} + +// TestVerifyPendingOAuthToken_ExpiredToken 过期 token 应返回 ErrInvalidToken。 +func TestVerifyPendingOAuthToken_ExpiredToken(t *testing.T) { + svc := newAuthServiceForPendingOAuthTest() + + past := time.Now().Add(-1 * time.Hour) + claims := &pendingOAuthClaims{ + Email: "user@example.com", + Username: "alice", + Purpose: pendingOAuthPurpose, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(past), + IssuedAt: jwt.NewNumericDate(past.Add(-10 * time.Minute)), + NotBefore: jwt.NewNumericDate(past.Add(-10 * time.Minute)), + }, + } + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret)) + require.NoError(t, err) + + _, _, err = svc.VerifyPendingOAuthToken(tokenStr) + require.ErrorIs(t, err, ErrInvalidToken) +} + +// TestVerifyPendingOAuthToken_WrongSecret 不同密钥签发的 token 应返回 ErrInvalidToken。 +func TestVerifyPendingOAuthToken_WrongSecret(t *testing.T) { + other := NewAuthService(nil, nil, nil, nil, &config.Config{ + JWT: config.JWTConfig{Secret: "other-secret"}, + }, nil, nil, nil, nil, nil, nil) + + token, err := other.CreatePendingOAuthToken("user@example.com", "alice") + require.NoError(t, err) + + svc := newAuthServiceForPendingOAuthTest() + _, _, err = svc.VerifyPendingOAuthToken(token) + require.ErrorIs(t, err, ErrInvalidToken) +} + +// TestVerifyPendingOAuthToken_TooLong 超长 token 应返回 ErrInvalidToken。 +func TestVerifyPendingOAuthToken_TooLong(t *testing.T) { + svc := newAuthServiceForPendingOAuthTest() + giant := make([]byte, maxTokenLength+1) + for i := range giant { + giant[i] = 'a' + } + _, _, err := svc.VerifyPendingOAuthToken(string(giant)) + require.ErrorIs(t, err, ErrInvalidToken) +} diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index 1999e759..7b50e90d 100644 --- 
a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/stretchr/testify/require" ) @@ -129,6 +130,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E } return NewAuthService( + nil, // entClient repo, nil, // redeemRepo nil, // refreshTokenCache @@ -231,6 +233,51 @@ func TestAuthService_Register_ReservedEmail(t *testing.T) { require.ErrorIs(t, err, ErrEmailReserved) } +func TestAuthService_Register_EmailSuffixNotAllowed(t *testing.T) { + repo := &userRepoStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyRegistrationEmailSuffixWhitelist: `["@example.com","@company.com"]`, + }, nil) + + _, _, err := service.Register(context.Background(), "user@other.com", "password") + require.ErrorIs(t, err, ErrEmailSuffixNotAllowed) + appErr := infraerrors.FromError(err) + require.Contains(t, appErr.Message, "@example.com") + require.Contains(t, appErr.Message, "@company.com") + require.Equal(t, "EMAIL_SUFFIX_NOT_ALLOWED", appErr.Reason) + require.Equal(t, "2", appErr.Metadata["allowed_suffix_count"]) + require.Equal(t, "@example.com,@company.com", appErr.Metadata["allowed_suffixes"]) +} + +func TestAuthService_Register_EmailSuffixAllowed(t *testing.T) { + repo := &userRepoStub{nextID: 8} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyRegistrationEmailSuffixWhitelist: `["example.com"]`, + }, nil) + + _, user, err := service.Register(context.Background(), "user@example.com", "password") + require.NoError(t, err) + require.NotNil(t, user) + require.Equal(t, int64(8), user.ID) +} + +func TestAuthService_SendVerifyCode_EmailSuffixNotAllowed(t *testing.T) { + repo := &userRepoStub{} + service := newAuthService(repo, 
map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyRegistrationEmailSuffixWhitelist: `["@example.com","@company.com"]`, + }, nil) + + err := service.SendVerifyCode(context.Background(), "user@other.com") + require.ErrorIs(t, err, ErrEmailSuffixNotAllowed) + appErr := infraerrors.FromError(err) + require.Contains(t, appErr.Message, "@example.com") + require.Contains(t, appErr.Message, "@company.com") + require.Equal(t, "2", appErr.Metadata["allowed_suffix_count"]) +} + func TestAuthService_Register_CreateError(t *testing.T) { repo := &userRepoStub{createErr: errors.New("create failed")} service := newAuthService(repo, map[string]string{ @@ -402,7 +449,7 @@ func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) { repo := &userRepoStub{nextID: 42} assigner := &defaultSubscriptionAssignerStub{} service := newAuthService(repo, map[string]string{ - SettingKeyRegistrationEnabled: "true", + SettingKeyRegistrationEnabled: "true", SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`, }, nil) service.defaultSubAssigner = assigner diff --git a/backend/internal/service/auth_service_turnstile_register_test.go b/backend/internal/service/auth_service_turnstile_register_test.go index 36cb1e06..477ba1b2 100644 --- a/backend/internal/service/auth_service_turnstile_register_test.go +++ b/backend/internal/service/auth_service_turnstile_register_test.go @@ -43,6 +43,7 @@ func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier turnstileService := NewTurnstileService(settingService, verifier) return NewAuthService( + nil, // entClient &userRepoStub{}, nil, // redeemRepo nil, // refreshTokenCache diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go index e055c0f7..f2ad0a3d 100644 --- a/backend/internal/service/billing_cache_service.go +++ b/backend/internal/service/billing_cache_service.go @@ -565,15 +565,15 @@ func 
(s *BillingCacheService) evaluateRateLimits(ctx context.Context, apiKey *AP needsReset := false // Reset expired windows in-memory for check purposes - if w5h != nil && time.Since(*w5h) >= 5*time.Hour { + if IsWindowExpired(w5h, RateLimitWindow5h) { usage5h = 0 needsReset = true } - if w1d != nil && time.Since(*w1d) >= 24*time.Hour { + if IsWindowExpired(w1d, RateLimitWindow1d) { usage1d = 0 needsReset = true } - if w7d != nil && time.Since(*w7d) >= 7*24*time.Hour { + if IsWindowExpired(w7d, RateLimitWindow7d) { usage7d = 0 needsReset = true } @@ -589,12 +589,16 @@ func (s *BillingCacheService) evaluateRateLimits(ctx context.Context, apiKey *AP if loader, ok := s.apiKeyRateLimitLoader.(interface { ResetRateLimitWindows(ctx context.Context, id int64) error }); ok { - _ = loader.ResetRateLimitWindows(resetCtx, keyID) + if err := loader.ResetRateLimitWindows(resetCtx, keyID); err != nil { + logger.LegacyPrintf("service.billing_cache", "Warning: reset rate limit windows failed for api key %d: %v", keyID, err) + } } } // Invalidate cache so next request loads fresh data if s.cache != nil { - _ = s.cache.InvalidateAPIKeyRateLimit(resetCtx, keyID) + if err := s.cache.InvalidateAPIKeyRateLimit(resetCtx, keyID); err != nil { + logger.LegacyPrintf("service.billing_cache", "Warning: invalidate rate limit cache failed for api key %d: %v", keyID, err) + } } }() } diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index 5d67c808..68d7a8f9 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -43,13 +43,47 @@ type BillingCache interface { // ModelPricing 模型价格配置(per-token价格,与LiteLLM格式一致) type ModelPricing struct { - InputPricePerToken float64 // 每token输入价格 (USD) - OutputPricePerToken float64 // 每token输出价格 (USD) - CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD) - CacheReadPricePerToken float64 // 缓存读取每token价格 (USD) - CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 
(USD) - CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD) - SupportsCacheBreakdown bool // 是否支持详细的缓存分类 + InputPricePerToken float64 // 每token输入价格 (USD) + InputPricePerTokenPriority float64 // priority service tier 下每token输入价格 (USD) + OutputPricePerToken float64 // 每token输出价格 (USD) + OutputPricePerTokenPriority float64 // priority service tier 下每token输出价格 (USD) + CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD) + CacheReadPricePerToken float64 // 缓存读取每token价格 (USD) + CacheReadPricePerTokenPriority float64 // priority service tier 下缓存读取每token价格 (USD) + CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD) + CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD) + SupportsCacheBreakdown bool // 是否支持详细的缓存分类 + LongContextInputThreshold int // 超过阈值后按整次会话提升输入价格 + LongContextInputMultiplier float64 // 长上下文整次会话输入倍率 + LongContextOutputMultiplier float64 // 长上下文整次会话输出倍率 +} + +const ( + openAIGPT54LongContextInputThreshold = 272000 + openAIGPT54LongContextInputMultiplier = 2.0 + openAIGPT54LongContextOutputMultiplier = 1.5 +) + +func normalizeBillingServiceTier(serviceTier string) string { + return strings.ToLower(strings.TrimSpace(serviceTier)) +} + +func usePriorityServiceTierPricing(serviceTier string, pricing *ModelPricing) bool { + if pricing == nil || normalizeBillingServiceTier(serviceTier) != "priority" { + return false + } + return pricing.InputPricePerTokenPriority > 0 || pricing.OutputPricePerTokenPriority > 0 || pricing.CacheReadPricePerTokenPriority > 0 +} + +func serviceTierCostMultiplier(serviceTier string) float64 { + switch normalizeBillingServiceTier(serviceTier) { + case "priority": + return 2.0 + case "flex": + return 0.5 + default: + return 1.0 + } } // UsageTokens 使用的token数量 @@ -161,6 +195,65 @@ func (s *BillingService) initFallbackPricing() { CacheReadPricePerToken: 0.2e-6, // $0.20 per MTok SupportsCacheBreakdown: false, } + + // OpenAI GPT-5.1(本地兜底,防止动态定价不可用时拒绝计费) + s.fallbackPrices["gpt-5.1"] = &ModelPricing{ + InputPricePerToken: 1.25e-6, 
// $1.25 per MTok + InputPricePerTokenPriority: 2.5e-6, // $2.5 per MTok + OutputPricePerToken: 10e-6, // $10 per MTok + OutputPricePerTokenPriority: 20e-6, // $20 per MTok + CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok + CacheReadPricePerToken: 0.125e-6, + CacheReadPricePerTokenPriority: 0.25e-6, + SupportsCacheBreakdown: false, + } + // OpenAI GPT-5.4(业务指定价格) + s.fallbackPrices["gpt-5.4"] = &ModelPricing{ + InputPricePerToken: 2.5e-6, // $2.5 per MTok + InputPricePerTokenPriority: 5e-6, // $5 per MTok + OutputPricePerToken: 15e-6, // $15 per MTok + OutputPricePerTokenPriority: 30e-6, // $30 per MTok + CacheCreationPricePerToken: 2.5e-6, // $2.5 per MTok + CacheReadPricePerToken: 0.25e-6, // $0.25 per MTok + CacheReadPricePerTokenPriority: 0.5e-6, // $0.5 per MTok + SupportsCacheBreakdown: false, + LongContextInputThreshold: openAIGPT54LongContextInputThreshold, + LongContextInputMultiplier: openAIGPT54LongContextInputMultiplier, + LongContextOutputMultiplier: openAIGPT54LongContextOutputMultiplier, + } + // OpenAI GPT-5.2(本地兜底) + s.fallbackPrices["gpt-5.2"] = &ModelPricing{ + InputPricePerToken: 1.75e-6, + InputPricePerTokenPriority: 3.5e-6, + OutputPricePerToken: 14e-6, + OutputPricePerTokenPriority: 28e-6, + CacheCreationPricePerToken: 1.75e-6, + CacheReadPricePerToken: 0.175e-6, + CacheReadPricePerTokenPriority: 0.35e-6, + SupportsCacheBreakdown: false, + } + // Codex 族兜底统一按 GPT-5.1 Codex 价格计费 + s.fallbackPrices["gpt-5.1-codex"] = &ModelPricing{ + InputPricePerToken: 1.5e-6, // $1.5 per MTok + InputPricePerTokenPriority: 3e-6, // $3 per MTok + OutputPricePerToken: 12e-6, // $12 per MTok + OutputPricePerTokenPriority: 24e-6, // $24 per MTok + CacheCreationPricePerToken: 1.5e-6, // $1.5 per MTok + CacheReadPricePerToken: 0.15e-6, + CacheReadPricePerTokenPriority: 0.3e-6, + SupportsCacheBreakdown: false, + } + s.fallbackPrices["gpt-5.2-codex"] = &ModelPricing{ + InputPricePerToken: 1.75e-6, + InputPricePerTokenPriority: 3.5e-6, + OutputPricePerToken: 
14e-6, + OutputPricePerTokenPriority: 28e-6, + CacheCreationPricePerToken: 1.75e-6, + CacheReadPricePerToken: 0.175e-6, + CacheReadPricePerTokenPriority: 0.35e-6, + SupportsCacheBreakdown: false, + } + s.fallbackPrices["gpt-5.3-codex"] = s.fallbackPrices["gpt-5.1-codex"] } // getFallbackPricing 根据模型系列获取回退价格 @@ -189,12 +282,34 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing { } return s.fallbackPrices["claude-3-haiku"] } + // Claude 未知型号统一回退到 Sonnet,避免计费中断。 + if strings.Contains(modelLower, "claude") { + return s.fallbackPrices["claude-sonnet-4"] + } if strings.Contains(modelLower, "gemini-3.1-pro") || strings.Contains(modelLower, "gemini-3-1-pro") { return s.fallbackPrices["gemini-3.1-pro"] } - // 默认使用Sonnet价格 - return s.fallbackPrices["claude-sonnet-4"] + // OpenAI 仅匹配已知 GPT-5/Codex 族,避免未知 OpenAI 型号误计价。 + if strings.Contains(modelLower, "gpt-5") || strings.Contains(modelLower, "codex") { + normalized := normalizeCodexModel(modelLower) + switch normalized { + case "gpt-5.4": + return s.fallbackPrices["gpt-5.4"] + case "gpt-5.2": + return s.fallbackPrices["gpt-5.2"] + case "gpt-5.2-codex": + return s.fallbackPrices["gpt-5.2-codex"] + case "gpt-5.3-codex": + return s.fallbackPrices["gpt-5.3-codex"] + case "gpt-5.1-codex", "gpt-5.1-codex-max", "gpt-5.1-codex-mini", "codex-mini-latest": + return s.fallbackPrices["gpt-5.1-codex"] + case "gpt-5.1": + return s.fallbackPrices["gpt-5.1"] + } + } + + return nil } // GetModelPricing 获取模型价格配置 @@ -212,15 +327,21 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) { price5m := litellmPricing.CacheCreationInputTokenCost price1h := litellmPricing.CacheCreationInputTokenCostAbove1hr enableBreakdown := price1h > 0 && price1h > price5m - return &ModelPricing{ - InputPricePerToken: litellmPricing.InputCostPerToken, - OutputPricePerToken: litellmPricing.OutputCostPerToken, - CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost, - CacheReadPricePerToken: 
litellmPricing.CacheReadInputTokenCost, - CacheCreation5mPrice: price5m, - CacheCreation1hPrice: price1h, - SupportsCacheBreakdown: enableBreakdown, - }, nil + return s.applyModelSpecificPricingPolicy(model, &ModelPricing{ + InputPricePerToken: litellmPricing.InputCostPerToken, + InputPricePerTokenPriority: litellmPricing.InputCostPerTokenPriority, + OutputPricePerToken: litellmPricing.OutputCostPerToken, + OutputPricePerTokenPriority: litellmPricing.OutputCostPerTokenPriority, + CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost, + CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost, + CacheReadPricePerTokenPriority: litellmPricing.CacheReadInputTokenCostPriority, + CacheCreation5mPrice: price5m, + CacheCreation1hPrice: price1h, + SupportsCacheBreakdown: enableBreakdown, + LongContextInputThreshold: litellmPricing.LongContextInputTokenThreshold, + LongContextInputMultiplier: litellmPricing.LongContextInputCostMultiplier, + LongContextOutputMultiplier: litellmPricing.LongContextOutputCostMultiplier, + }), nil } } @@ -228,7 +349,7 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) { fallback := s.getFallbackPricing(model) if fallback != nil { log.Printf("[Billing] Using fallback pricing for model: %s", model) - return fallback, nil + return s.applyModelSpecificPricingPolicy(model, fallback), nil } return nil, fmt.Errorf("pricing not found for model: %s", model) @@ -236,18 +357,43 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) { // CalculateCost 计算使用费用 func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMultiplier float64) (*CostBreakdown, error) { + return s.CalculateCostWithServiceTier(model, tokens, rateMultiplier, "") +} + +func (s *BillingService) CalculateCostWithServiceTier(model string, tokens UsageTokens, rateMultiplier float64, serviceTier string) (*CostBreakdown, error) { pricing, err := s.GetModelPricing(model) if err != nil { return 
nil, err } breakdown := &CostBreakdown{} + inputPricePerToken := pricing.InputPricePerToken + outputPricePerToken := pricing.OutputPricePerToken + cacheReadPricePerToken := pricing.CacheReadPricePerToken + tierMultiplier := 1.0 + if usePriorityServiceTierPricing(serviceTier, pricing) { + if pricing.InputPricePerTokenPriority > 0 { + inputPricePerToken = pricing.InputPricePerTokenPriority + } + if pricing.OutputPricePerTokenPriority > 0 { + outputPricePerToken = pricing.OutputPricePerTokenPriority + } + if pricing.CacheReadPricePerTokenPriority > 0 { + cacheReadPricePerToken = pricing.CacheReadPricePerTokenPriority + } + } else { + tierMultiplier = serviceTierCostMultiplier(serviceTier) + } + if s.shouldApplySessionLongContextPricing(tokens, pricing) { + inputPricePerToken *= pricing.LongContextInputMultiplier + outputPricePerToken *= pricing.LongContextOutputMultiplier + } // 计算输入token费用(使用per-token价格) - breakdown.InputCost = float64(tokens.InputTokens) * pricing.InputPricePerToken + breakdown.InputCost = float64(tokens.InputTokens) * inputPricePerToken // 计算输出token费用 - breakdown.OutputCost = float64(tokens.OutputTokens) * pricing.OutputPricePerToken + breakdown.OutputCost = float64(tokens.OutputTokens) * outputPricePerToken // 计算缓存费用 if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) { @@ -264,7 +410,14 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken } - breakdown.CacheReadCost = float64(tokens.CacheReadTokens) * pricing.CacheReadPricePerToken + breakdown.CacheReadCost = float64(tokens.CacheReadTokens) * cacheReadPricePerToken + + if tierMultiplier != 1.0 { + breakdown.InputCost *= tierMultiplier + breakdown.OutputCost *= tierMultiplier + breakdown.CacheCreationCost *= tierMultiplier + breakdown.CacheReadCost *= tierMultiplier + } // 计算总费用 breakdown.TotalCost = 
breakdown.InputCost + breakdown.OutputCost + @@ -279,6 +432,45 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul return breakdown, nil } +func (s *BillingService) applyModelSpecificPricingPolicy(model string, pricing *ModelPricing) *ModelPricing { + if pricing == nil { + return nil + } + if !isOpenAIGPT54Model(model) { + return pricing + } + if pricing.LongContextInputThreshold > 0 && pricing.LongContextInputMultiplier > 0 && pricing.LongContextOutputMultiplier > 0 { + return pricing + } + cloned := *pricing + if cloned.LongContextInputThreshold <= 0 { + cloned.LongContextInputThreshold = openAIGPT54LongContextInputThreshold + } + if cloned.LongContextInputMultiplier <= 0 { + cloned.LongContextInputMultiplier = openAIGPT54LongContextInputMultiplier + } + if cloned.LongContextOutputMultiplier <= 0 { + cloned.LongContextOutputMultiplier = openAIGPT54LongContextOutputMultiplier + } + return &cloned +} + +func (s *BillingService) shouldApplySessionLongContextPricing(tokens UsageTokens, pricing *ModelPricing) bool { + if pricing == nil || pricing.LongContextInputThreshold <= 0 { + return false + } + if pricing.LongContextInputMultiplier <= 1 && pricing.LongContextOutputMultiplier <= 1 { + return false + } + totalInputTokens := tokens.InputTokens + tokens.CacheReadTokens + return totalInputTokens > pricing.LongContextInputThreshold +} + +func isOpenAIGPT54Model(model string) bool { + normalized := normalizeCodexModel(strings.TrimSpace(strings.ToLower(model))) + return normalized == "gpt-5.4" +} + // CalculateCostWithConfig 使用配置中的默认倍率计算费用 func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageTokens) (*CostBreakdown, error) { multiplier := s.cfg.Default.RateMultiplier diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go index 5eb278f6..45bbdcee 100644 --- a/backend/internal/service/billing_service_test.go +++ b/backend/internal/service/billing_service_test.go 
@@ -133,7 +133,7 @@ func TestGetModelPricing_CaseInsensitive(t *testing.T) { require.Equal(t, p1.InputPricePerToken, p2.InputPricePerToken) } -func TestGetModelPricing_UnknownModelFallsBackToSonnet(t *testing.T) { +func TestGetModelPricing_UnknownClaudeModelFallsBackToSonnet(t *testing.T) { svc := newTestBillingService() // 不包含 opus/sonnet/haiku 关键词的 Claude 模型会走默认 Sonnet 价格 @@ -142,6 +142,93 @@ func TestGetModelPricing_UnknownModelFallsBackToSonnet(t *testing.T) { require.InDelta(t, 3e-6, pricing.InputPricePerToken, 1e-12) } +func TestGetModelPricing_UnknownOpenAIModelReturnsError(t *testing.T) { + svc := newTestBillingService() + + pricing, err := svc.GetModelPricing("gpt-unknown-model") + require.Error(t, err) + require.Nil(t, pricing) + require.Contains(t, err.Error(), "pricing not found") +} + +func TestGetModelPricing_OpenAIGPT51Fallback(t *testing.T) { + svc := newTestBillingService() + + pricing, err := svc.GetModelPricing("gpt-5.1") + require.NoError(t, err) + require.NotNil(t, pricing) + require.InDelta(t, 1.25e-6, pricing.InputPricePerToken, 1e-12) +} + +func TestGetModelPricing_OpenAIGPT54Fallback(t *testing.T) { + svc := newTestBillingService() + + pricing, err := svc.GetModelPricing("gpt-5.4") + require.NoError(t, err) + require.NotNil(t, pricing) + require.InDelta(t, 2.5e-6, pricing.InputPricePerToken, 1e-12) + require.InDelta(t, 15e-6, pricing.OutputPricePerToken, 1e-12) + require.InDelta(t, 0.25e-6, pricing.CacheReadPricePerToken, 1e-12) + require.Equal(t, 272000, pricing.LongContextInputThreshold) + require.InDelta(t, 2.0, pricing.LongContextInputMultiplier, 1e-12) + require.InDelta(t, 1.5, pricing.LongContextOutputMultiplier, 1e-12) +} + +func TestCalculateCost_OpenAIGPT54LongContextAppliesWholeSessionMultipliers(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{ + InputTokens: 300000, + OutputTokens: 4000, + } + + cost, err := svc.CalculateCost("gpt-5.4-2026-03-05", tokens, 1.0) + require.NoError(t, err) + + expectedInput 
:= float64(tokens.InputTokens) * 2.5e-6 * 2.0 + expectedOutput := float64(tokens.OutputTokens) * 15e-6 * 1.5 + require.InDelta(t, expectedInput, cost.InputCost, 1e-10) + require.InDelta(t, expectedOutput, cost.OutputCost, 1e-10) + require.InDelta(t, expectedInput+expectedOutput, cost.TotalCost, 1e-10) + require.InDelta(t, expectedInput+expectedOutput, cost.ActualCost, 1e-10) +} + +func TestGetFallbackPricing_FamilyMatching(t *testing.T) { + svc := newTestBillingService() + + tests := []struct { + name string + model string + expectedInput float64 + expectNilPricing bool + }{ + {name: "empty model", model: " ", expectNilPricing: true}, + {name: "claude opus 4.6", model: "claude-opus-4.6-20260201", expectedInput: 5e-6}, + {name: "claude opus 4.5 alt separator", model: "claude-opus-4-5-20260101", expectedInput: 5e-6}, + {name: "claude generic model fallback sonnet", model: "claude-foo-bar", expectedInput: 3e-6}, + {name: "gemini explicit fallback", model: "gemini-3-1-pro", expectedInput: 2e-6}, + {name: "gemini unknown no fallback", model: "gemini-2.0-pro", expectNilPricing: true}, + {name: "openai gpt5.1", model: "gpt-5.1", expectedInput: 1.25e-6}, + {name: "openai gpt5.4", model: "gpt-5.4", expectedInput: 2.5e-6}, + {name: "openai gpt5.3 codex", model: "gpt-5.3-codex", expectedInput: 1.5e-6}, + {name: "openai gpt5.1 codex max alias", model: "gpt-5.1-codex-max", expectedInput: 1.5e-6}, + {name: "openai codex mini latest alias", model: "codex-mini-latest", expectedInput: 1.5e-6}, + {name: "openai unknown no fallback", model: "gpt-unknown-model", expectNilPricing: true}, + {name: "non supported family", model: "qwen-max", expectNilPricing: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pricing := svc.getFallbackPricing(tt.model) + if tt.expectNilPricing { + require.Nil(t, pricing) + return + } + require.NotNil(t, pricing) + require.InDelta(t, tt.expectedInput, pricing.InputPricePerToken, 1e-12) + }) + } +} func 
TestCalculateCostWithLongContext_BelowThreshold(t *testing.T) { svc := newTestBillingService() @@ -435,3 +522,189 @@ func TestCalculateCost_LargeTokenCount(t *testing.T) { require.False(t, math.IsNaN(cost.TotalCost)) require.False(t, math.IsInf(cost.TotalCost, 0)) } + +func TestServiceTierCostMultiplier(t *testing.T) { + require.InDelta(t, 2.0, serviceTierCostMultiplier("priority"), 1e-12) + require.InDelta(t, 2.0, serviceTierCostMultiplier(" Priority "), 1e-12) + require.InDelta(t, 0.5, serviceTierCostMultiplier("flex"), 1e-12) + require.InDelta(t, 1.0, serviceTierCostMultiplier(""), 1e-12) + require.InDelta(t, 1.0, serviceTierCostMultiplier("default"), 1e-12) +} + +func TestCalculateCostWithServiceTier_OpenAIPriorityUsesPriorityPricing(t *testing.T) { + svc := newTestBillingService() + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50, CacheReadTokens: 20} + + baseCost, err := svc.CalculateCost("gpt-5.1-codex", tokens, 1.0) + require.NoError(t, err) + + priorityCost, err := svc.CalculateCostWithServiceTier("gpt-5.1-codex", tokens, 1.0, "priority") + require.NoError(t, err) + + require.InDelta(t, baseCost.InputCost*2, priorityCost.InputCost, 1e-10) + require.InDelta(t, baseCost.OutputCost*2, priorityCost.OutputCost, 1e-10) + require.InDelta(t, baseCost.CacheReadCost*2, priorityCost.CacheReadCost, 1e-10) + require.InDelta(t, baseCost.TotalCost*2, priorityCost.TotalCost, 1e-10) +} + +func TestCalculateCostWithServiceTier_FlexAppliesHalfMultiplier(t *testing.T) { + svc := newTestBillingService() + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50, CacheCreationTokens: 40, CacheReadTokens: 20} + + baseCost, err := svc.CalculateCost("gpt-5.4", tokens, 1.0) + require.NoError(t, err) + + flexCost, err := svc.CalculateCostWithServiceTier("gpt-5.4", tokens, 1.0, "flex") + require.NoError(t, err) + + require.InDelta(t, baseCost.InputCost*0.5, flexCost.InputCost, 1e-10) + require.InDelta(t, baseCost.OutputCost*0.5, flexCost.OutputCost, 1e-10) + 
require.InDelta(t, baseCost.CacheCreationCost*0.5, flexCost.CacheCreationCost, 1e-10) + require.InDelta(t, baseCost.CacheReadCost*0.5, flexCost.CacheReadCost, 1e-10) + require.InDelta(t, baseCost.TotalCost*0.5, flexCost.TotalCost, 1e-10) +} + +func TestCalculateCostWithServiceTier_PriorityFallsBackToTierMultiplierWithoutExplicitPriorityPrice(t *testing.T) { + svc := newTestBillingService() + tokens := UsageTokens{InputTokens: 120, OutputTokens: 30, CacheCreationTokens: 12, CacheReadTokens: 8} + + baseCost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + priorityCost, err := svc.CalculateCostWithServiceTier("claude-sonnet-4", tokens, 1.0, "priority") + require.NoError(t, err) + + require.InDelta(t, baseCost.InputCost*2, priorityCost.InputCost, 1e-10) + require.InDelta(t, baseCost.OutputCost*2, priorityCost.OutputCost, 1e-10) + require.InDelta(t, baseCost.CacheCreationCost*2, priorityCost.CacheCreationCost, 1e-10) + require.InDelta(t, baseCost.CacheReadCost*2, priorityCost.CacheReadCost, 1e-10) + require.InDelta(t, baseCost.TotalCost*2, priorityCost.TotalCost, 1e-10) +} + +func TestBillingServiceGetModelPricing_UsesDynamicPriorityFields(t *testing.T) { + pricingSvc := &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "gpt-5.4": { + InputCostPerToken: 2.5e-6, + InputCostPerTokenPriority: 5e-6, + OutputCostPerToken: 15e-6, + OutputCostPerTokenPriority: 30e-6, + CacheCreationInputTokenCost: 2.5e-6, + CacheReadInputTokenCost: 0.25e-6, + CacheReadInputTokenCostPriority: 0.5e-6, + LongContextInputTokenThreshold: 272000, + LongContextInputCostMultiplier: 2.0, + LongContextOutputCostMultiplier: 1.5, + }, + }, + } + svc := NewBillingService(&config.Config{}, pricingSvc) + + pricing, err := svc.GetModelPricing("gpt-5.4") + require.NoError(t, err) + require.InDelta(t, 2.5e-6, pricing.InputPricePerToken, 1e-12) + require.InDelta(t, 5e-6, pricing.InputPricePerTokenPriority, 1e-12) + require.InDelta(t, 15e-6, 
pricing.OutputPricePerToken, 1e-12) + require.InDelta(t, 30e-6, pricing.OutputPricePerTokenPriority, 1e-12) + require.InDelta(t, 0.25e-6, pricing.CacheReadPricePerToken, 1e-12) + require.InDelta(t, 0.5e-6, pricing.CacheReadPricePerTokenPriority, 1e-12) + require.Equal(t, 272000, pricing.LongContextInputThreshold) + require.InDelta(t, 2.0, pricing.LongContextInputMultiplier, 1e-12) + require.InDelta(t, 1.5, pricing.LongContextOutputMultiplier, 1e-12) +} + +func TestBillingServiceGetModelPricing_OpenAIFallbackGpt52Variants(t *testing.T) { + svc := newTestBillingService() + + gpt52, err := svc.GetModelPricing("gpt-5.2") + require.NoError(t, err) + require.NotNil(t, gpt52) + require.InDelta(t, 1.75e-6, gpt52.InputPricePerToken, 1e-12) + require.InDelta(t, 3.5e-6, gpt52.InputPricePerTokenPriority, 1e-12) + + gpt52Codex, err := svc.GetModelPricing("gpt-5.2-codex") + require.NoError(t, err) + require.NotNil(t, gpt52Codex) + require.InDelta(t, 1.75e-6, gpt52Codex.InputPricePerToken, 1e-12) + require.InDelta(t, 3.5e-6, gpt52Codex.InputPricePerTokenPriority, 1e-12) + require.InDelta(t, 28e-6, gpt52Codex.OutputPricePerTokenPriority, 1e-12) +} + +func TestCalculateCostWithServiceTier_PriorityFallsBackToTierMultiplierWhenExplicitPriceMissing(t *testing.T) { + svc := NewBillingService(&config.Config{}, &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "custom-no-priority": { + InputCostPerToken: 1e-6, + OutputCostPerToken: 2e-6, + CacheCreationInputTokenCost: 0.5e-6, + CacheReadInputTokenCost: 0.25e-6, + }, + }, + }) + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50, CacheCreationTokens: 40, CacheReadTokens: 20} + + baseCost, err := svc.CalculateCost("custom-no-priority", tokens, 1.0) + require.NoError(t, err) + + priorityCost, err := svc.CalculateCostWithServiceTier("custom-no-priority", tokens, 1.0, "priority") + require.NoError(t, err) + + require.InDelta(t, baseCost.InputCost*2, priorityCost.InputCost, 1e-10) + require.InDelta(t, 
baseCost.OutputCost*2, priorityCost.OutputCost, 1e-10) + require.InDelta(t, baseCost.CacheCreationCost*2, priorityCost.CacheCreationCost, 1e-10) + require.InDelta(t, baseCost.CacheReadCost*2, priorityCost.CacheReadCost, 1e-10) + require.InDelta(t, baseCost.TotalCost*2, priorityCost.TotalCost, 1e-10) +} + +func TestGetModelPricing_OpenAIGpt52FallbacksExposePriorityPrices(t *testing.T) { + svc := newTestBillingService() + + gpt52, err := svc.GetModelPricing("gpt-5.2") + require.NoError(t, err) + require.InDelta(t, 1.75e-6, gpt52.InputPricePerToken, 1e-12) + require.InDelta(t, 3.5e-6, gpt52.InputPricePerTokenPriority, 1e-12) + require.InDelta(t, 14e-6, gpt52.OutputPricePerToken, 1e-12) + require.InDelta(t, 28e-6, gpt52.OutputPricePerTokenPriority, 1e-12) + + gpt52Codex, err := svc.GetModelPricing("gpt-5.2-codex") + require.NoError(t, err) + require.InDelta(t, 1.75e-6, gpt52Codex.InputPricePerToken, 1e-12) + require.InDelta(t, 3.5e-6, gpt52Codex.InputPricePerTokenPriority, 1e-12) + require.InDelta(t, 14e-6, gpt52Codex.OutputPricePerToken, 1e-12) + require.InDelta(t, 28e-6, gpt52Codex.OutputPricePerTokenPriority, 1e-12) +} + +func TestGetModelPricing_MapsDynamicPriorityFieldsIntoBillingPricing(t *testing.T) { + svc := NewBillingService(&config.Config{}, &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "dynamic-tier-model": { + InputCostPerToken: 1e-6, + InputCostPerTokenPriority: 2e-6, + OutputCostPerToken: 3e-6, + OutputCostPerTokenPriority: 6e-6, + CacheCreationInputTokenCost: 4e-6, + CacheCreationInputTokenCostAbove1hr: 5e-6, + CacheReadInputTokenCost: 7e-7, + CacheReadInputTokenCostPriority: 8e-7, + LongContextInputTokenThreshold: 999, + LongContextInputCostMultiplier: 1.5, + LongContextOutputCostMultiplier: 1.25, + }, + }, + }) + + pricing, err := svc.GetModelPricing("dynamic-tier-model") + require.NoError(t, err) + require.InDelta(t, 1e-6, pricing.InputPricePerToken, 1e-12) + require.InDelta(t, 2e-6, pricing.InputPricePerTokenPriority, 1e-12) + 
require.InDelta(t, 3e-6, pricing.OutputPricePerToken, 1e-12) + require.InDelta(t, 6e-6, pricing.OutputPricePerTokenPriority, 1e-12) + require.InDelta(t, 4e-6, pricing.CacheCreation5mPrice, 1e-12) + require.InDelta(t, 5e-6, pricing.CacheCreation1hPrice, 1e-12) + require.True(t, pricing.SupportsCacheBreakdown) + require.InDelta(t, 7e-7, pricing.CacheReadPricePerToken, 1e-12) + require.InDelta(t, 8e-7, pricing.CacheReadPricePerTokenPriority, 1e-12) + require.Equal(t, 999, pricing.LongContextInputThreshold) + require.InDelta(t, 1.5, pricing.LongContextInputMultiplier, 1e-12) + require.InDelta(t, 1.25, pricing.LongContextOutputMultiplier, 1e-12) +} diff --git a/backend/internal/service/claude_max_cache_billing_policy.go b/backend/internal/service/claude_max_cache_billing_policy.go new file mode 100644 index 00000000..2381915e --- /dev/null +++ b/backend/internal/service/claude_max_cache_billing_policy.go @@ -0,0 +1,450 @@ +package service + +import ( + "encoding/json" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/tidwall/gjson" +) + +type claudeMaxCacheBillingOutcome struct { + Simulated bool +} + +func applyClaudeMaxCacheBillingPolicyToUsage(usage *ClaudeUsage, parsed *ParsedRequest, group *Group, model string, accountID int64) claudeMaxCacheBillingOutcome { + var out claudeMaxCacheBillingOutcome + if usage == nil || !shouldApplyClaudeMaxBillingRulesForUsage(group, model, parsed) { + return out + } + + resolvedModel := strings.TrimSpace(model) + if resolvedModel == "" && parsed != nil { + resolvedModel = strings.TrimSpace(parsed.Model) + } + + if hasCacheCreationTokens(*usage) { + // Upstream already returned cache creation usage; keep original usage. 
+ return out + } + + if !shouldSimulateClaudeMaxUsageForUsage(*usage, parsed) { + return out + } + beforeInputTokens := usage.InputTokens + out.Simulated = safelyProjectUsageToClaudeMax1H(usage, parsed) + if out.Simulated { + logger.LegacyPrintf("service.gateway", "simulate_claude_max_usage: model=%s account=%d input_tokens:%d->%d cache_creation_1h=%d", + resolvedModel, + accountID, + beforeInputTokens, + usage.InputTokens, + usage.CacheCreation1hTokens, + ) + } + return out +} + +func isClaudeFamilyModel(model string) bool { + normalized := strings.ToLower(strings.TrimSpace(claude.NormalizeModelID(model))) + if normalized == "" { + return false + } + return strings.Contains(normalized, "claude-") +} + +func shouldApplyClaudeMaxBillingRules(input *RecordUsageInput) bool { + if input == nil || input.Result == nil || input.APIKey == nil || input.APIKey.Group == nil { + return false + } + return shouldApplyClaudeMaxBillingRulesForUsage(input.APIKey.Group, input.Result.Model, input.ParsedRequest) +} + +func shouldApplyClaudeMaxBillingRulesForUsage(group *Group, model string, parsed *ParsedRequest) bool { + if group == nil { + return false + } + if !group.SimulateClaudeMaxEnabled || group.Platform != PlatformAnthropic { + return false + } + + resolvedModel := model + if resolvedModel == "" && parsed != nil { + resolvedModel = parsed.Model + } + if !isClaudeFamilyModel(resolvedModel) { + return false + } + return true +} + +func hasCacheCreationTokens(usage ClaudeUsage) bool { + return usage.CacheCreationInputTokens > 0 || usage.CacheCreation5mTokens > 0 || usage.CacheCreation1hTokens > 0 +} + +func shouldSimulateClaudeMaxUsage(input *RecordUsageInput) bool { + if input == nil || input.Result == nil { + return false + } + if !shouldApplyClaudeMaxBillingRules(input) { + return false + } + return shouldSimulateClaudeMaxUsageForUsage(input.Result.Usage, input.ParsedRequest) +} + +func shouldSimulateClaudeMaxUsageForUsage(usage ClaudeUsage, parsed *ParsedRequest) bool { + if 
usage.InputTokens <= 0 { + return false + } + if hasCacheCreationTokens(usage) { + return false + } + if !hasClaudeCacheSignals(parsed) { + return false + } + return true +} + +func safelyProjectUsageToClaudeMax1H(usage *ClaudeUsage, parsed *ParsedRequest) (changed bool) { + defer func() { + if r := recover(); r != nil { + logger.LegacyPrintf("service.gateway", "simulate_claude_max_usage skipped: panic=%v", r) + changed = false + } + }() + return projectUsageToClaudeMax1H(usage, parsed) +} + +func projectUsageToClaudeMax1H(usage *ClaudeUsage, parsed *ParsedRequest) bool { + if usage == nil { + return false + } + totalWindowTokens := usage.InputTokens + usage.CacheCreation5mTokens + usage.CacheCreation1hTokens + if totalWindowTokens <= 1 { + return false + } + + simulatedInputTokens := computeClaudeMaxProjectedInputTokens(totalWindowTokens, parsed) + if simulatedInputTokens <= 0 { + simulatedInputTokens = 1 + } + if simulatedInputTokens >= totalWindowTokens { + simulatedInputTokens = totalWindowTokens - 1 + } + + cacheCreation1hTokens := totalWindowTokens - simulatedInputTokens + if usage.InputTokens == simulatedInputTokens && + usage.CacheCreation5mTokens == 0 && + usage.CacheCreation1hTokens == cacheCreation1hTokens && + usage.CacheCreationInputTokens == cacheCreation1hTokens { + return false + } + + usage.InputTokens = simulatedInputTokens + usage.CacheCreation5mTokens = 0 + usage.CacheCreation1hTokens = cacheCreation1hTokens + usage.CacheCreationInputTokens = cacheCreation1hTokens + return true +} + +type claudeCacheProjection struct { + HasBreakpoint bool + BreakpointCount int + TotalEstimatedTokens int + TailEstimatedTokens int +} + +func computeClaudeMaxProjectedInputTokens(totalWindowTokens int, parsed *ParsedRequest) int { + if totalWindowTokens <= 1 { + return totalWindowTokens + } + + projection := analyzeClaudeCacheProjection(parsed) + if !projection.HasBreakpoint || projection.TotalEstimatedTokens <= 0 || projection.TailEstimatedTokens <= 0 { + return 
totalWindowTokens + } + + totalEstimate := int64(projection.TotalEstimatedTokens) + tailEstimate := int64(projection.TailEstimatedTokens) + if tailEstimate > totalEstimate { + tailEstimate = totalEstimate + } + + scaled := (int64(totalWindowTokens)*tailEstimate + totalEstimate/2) / totalEstimate + if scaled <= 0 { + scaled = 1 + } + if scaled >= int64(totalWindowTokens) { + scaled = int64(totalWindowTokens - 1) + } + return int(scaled) +} + +func hasClaudeCacheSignals(parsed *ParsedRequest) bool { + if parsed == nil { + return false + } + if hasTopLevelEphemeralCacheControl(parsed) { + return true + } + return countExplicitCacheBreakpoints(parsed) > 0 +} + +func hasTopLevelEphemeralCacheControl(parsed *ParsedRequest) bool { + if parsed == nil || len(parsed.Body) == 0 { + return false + } + cacheType := strings.TrimSpace(gjson.GetBytes(parsed.Body, "cache_control.type").String()) + return strings.EqualFold(cacheType, "ephemeral") +} + +func analyzeClaudeCacheProjection(parsed *ParsedRequest) claudeCacheProjection { + var projection claudeCacheProjection + if parsed == nil { + return projection + } + + total := 0 + lastBreakpointAt := -1 + + switch system := parsed.System.(type) { + case string: + total += claudeMaxMessageOverheadTokens + estimateClaudeTextTokens(system) + case []any: + for _, raw := range system { + block, ok := raw.(map[string]any) + if !ok { + total += claudeMaxUnknownContentTokens + continue + } + total += estimateClaudeBlockTokens(block) + if hasEphemeralCacheControl(block) { + lastBreakpointAt = total + projection.BreakpointCount++ + projection.HasBreakpoint = true + } + } + } + + for _, rawMsg := range parsed.Messages { + total += claudeMaxMessageOverheadTokens + msg, ok := rawMsg.(map[string]any) + if !ok { + total += claudeMaxUnknownContentTokens + continue + } + content, exists := msg["content"] + if !exists { + continue + } + msgTokens, msgLastBreak, msgBreakCount := estimateClaudeContentTokens(content) + total += msgTokens + if 
msgBreakCount > 0 { + lastBreakpointAt = total - msgTokens + msgLastBreak + projection.BreakpointCount += msgBreakCount + projection.HasBreakpoint = true + } + } + + if total <= 0 { + total = 1 + } + projection.TotalEstimatedTokens = total + + if projection.HasBreakpoint && lastBreakpointAt >= 0 { + tail := total - lastBreakpointAt + if tail <= 0 { + tail = 1 + } + projection.TailEstimatedTokens = tail + return projection + } + + if hasTopLevelEphemeralCacheControl(parsed) { + tail := estimateLastUserMessageTokens(parsed) + if tail <= 0 { + tail = 1 + } + projection.HasBreakpoint = true + projection.BreakpointCount = 1 + projection.TailEstimatedTokens = tail + } + return projection +} + +func countExplicitCacheBreakpoints(parsed *ParsedRequest) int { + if parsed == nil { + return 0 + } + total := 0 + if system, ok := parsed.System.([]any); ok { + for _, raw := range system { + if block, ok := raw.(map[string]any); ok && hasEphemeralCacheControl(block) { + total++ + } + } + } + for _, rawMsg := range parsed.Messages { + msg, ok := rawMsg.(map[string]any) + if !ok { + continue + } + content, ok := msg["content"].([]any) + if !ok { + continue + } + for _, raw := range content { + if block, ok := raw.(map[string]any); ok && hasEphemeralCacheControl(block) { + total++ + } + } + } + return total +} + +func hasEphemeralCacheControl(block map[string]any) bool { + if block == nil { + return false + } + raw, ok := block["cache_control"] + if !ok || raw == nil { + return false + } + switch cc := raw.(type) { + case map[string]any: + cacheType, _ := cc["type"].(string) + return strings.EqualFold(strings.TrimSpace(cacheType), "ephemeral") + case map[string]string: + return strings.EqualFold(strings.TrimSpace(cc["type"]), "ephemeral") + default: + return false + } +} + +func estimateClaudeContentTokens(content any) (tokens int, lastBreakAt int, breakpointCount int) { + switch value := content.(type) { + case string: + return estimateClaudeTextTokens(value), -1, 0 + case []any: + 
total := 0 + lastBreak := -1 + breaks := 0 + for _, raw := range value { + block, ok := raw.(map[string]any) + if !ok { + total += claudeMaxUnknownContentTokens + continue + } + total += estimateClaudeBlockTokens(block) + if hasEphemeralCacheControl(block) { + lastBreak = total + breaks++ + } + } + return total, lastBreak, breaks + default: + return estimateStructuredTokens(value), -1, 0 + } +} + +func estimateClaudeBlockTokens(block map[string]any) int { + if block == nil { + return claudeMaxUnknownContentTokens + } + tokens := claudeMaxBlockOverheadTokens + blockType, _ := block["type"].(string) + switch blockType { + case "text": + if text, ok := block["text"].(string); ok { + tokens += estimateClaudeTextTokens(text) + } + case "tool_result": + if content, ok := block["content"]; ok { + nested, _, _ := estimateClaudeContentTokens(content) + tokens += nested + } + case "tool_use": + if name, ok := block["name"].(string); ok { + tokens += estimateClaudeTextTokens(name) + } + if input, ok := block["input"]; ok { + tokens += estimateStructuredTokens(input) + } + default: + if text, ok := block["text"].(string); ok { + tokens += estimateClaudeTextTokens(text) + } else if content, ok := block["content"]; ok { + nested, _, _ := estimateClaudeContentTokens(content) + tokens += nested + } + } + if tokens <= claudeMaxBlockOverheadTokens { + tokens += claudeMaxUnknownContentTokens + } + return tokens +} + +func estimateLastUserMessageTokens(parsed *ParsedRequest) int { + if parsed == nil || len(parsed.Messages) == 0 { + return 0 + } + for i := len(parsed.Messages) - 1; i >= 0; i-- { + msg, ok := parsed.Messages[i].(map[string]any) + if !ok { + continue + } + role, _ := msg["role"].(string) + if !strings.EqualFold(strings.TrimSpace(role), "user") { + continue + } + tokens, _, _ := estimateClaudeContentTokens(msg["content"]) + return claudeMaxMessageOverheadTokens + tokens + } + return 0 +} + +func estimateStructuredTokens(v any) int { + if v == nil { + return 0 + } + raw, 
err := json.Marshal(v) + if err != nil { + return claudeMaxUnknownContentTokens + } + return estimateClaudeTextTokens(string(raw)) +} + +func estimateClaudeTextTokens(text string) int { + if tokens, ok := estimateTokensByThirdPartyTokenizer(text); ok { + return tokens + } + return estimateClaudeTextTokensHeuristic(text) +} + +func estimateClaudeTextTokensHeuristic(text string) int { + normalized := strings.Join(strings.Fields(strings.TrimSpace(text)), " ") + if normalized == "" { + return 0 + } + asciiChars := 0 + nonASCIIChars := 0 + for _, r := range normalized { + if r <= 127 { + asciiChars++ + } else { + nonASCIIChars++ + } + } + tokens := nonASCIIChars + if asciiChars > 0 { + tokens += (asciiChars + 3) / 4 + } + if words := len(strings.Fields(normalized)); words > tokens { + tokens = words + } + if tokens <= 0 { + return 1 + } + return tokens +} diff --git a/backend/internal/service/claude_max_simulation_test.go b/backend/internal/service/claude_max_simulation_test.go new file mode 100644 index 00000000..3d2ae2e6 --- /dev/null +++ b/backend/internal/service/claude_max_simulation_test.go @@ -0,0 +1,156 @@ +package service + +import ( + "strings" + "testing" +) + +func TestProjectUsageToClaudeMax1H_Conservation(t *testing.T) { + usage := &ClaudeUsage{ + InputTokens: 1200, + CacheCreationInputTokens: 0, + CacheCreation5mTokens: 0, + CacheCreation1hTokens: 0, + } + parsed := &ParsedRequest{ + Model: "claude-sonnet-4-5", + Messages: []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{ + "type": "text", + "text": strings.Repeat("cached context ", 200), + "cache_control": map[string]any{"type": "ephemeral"}, + }, + map[string]any{ + "type": "text", + "text": "summarize quickly", + }, + }, + }, + }, + } + + changed := projectUsageToClaudeMax1H(usage, parsed) + if !changed { + t.Fatalf("expected usage to be projected") + } + + total := usage.InputTokens + usage.CacheCreation5mTokens + usage.CacheCreation1hTokens + if total != 1200 { + 
t.Fatalf("total tokens changed: got=%d want=%d", total, 1200) + } + if usage.CacheCreation5mTokens != 0 { + t.Fatalf("cache_creation_5m should be 0, got=%d", usage.CacheCreation5mTokens) + } + if usage.InputTokens <= 0 || usage.InputTokens >= 1200 { + t.Fatalf("simulated input out of range, got=%d", usage.InputTokens) + } + if usage.InputTokens > 100 { + t.Fatalf("simulated input should stay near cache breakpoint tail, got=%d", usage.InputTokens) + } + if usage.CacheCreation1hTokens <= 0 { + t.Fatalf("cache_creation_1h should be > 0, got=%d", usage.CacheCreation1hTokens) + } + if usage.CacheCreationInputTokens != usage.CacheCreation1hTokens { + t.Fatalf("cache_creation_input_tokens mismatch: got=%d want=%d", usage.CacheCreationInputTokens, usage.CacheCreation1hTokens) + } +} + +func TestComputeClaudeMaxProjectedInputTokens_Deterministic(t *testing.T) { + parsed := &ParsedRequest{ + Model: "claude-opus-4-5", + Messages: []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{ + "type": "text", + "text": "build context", + "cache_control": map[string]any{"type": "ephemeral"}, + }, + map[string]any{ + "type": "text", + "text": "what is failing now", + }, + }, + }, + }, + } + + got1 := computeClaudeMaxProjectedInputTokens(4096, parsed) + got2 := computeClaudeMaxProjectedInputTokens(4096, parsed) + if got1 != got2 { + t.Fatalf("non-deterministic input tokens: %d != %d", got1, got2) + } +} + +func TestShouldSimulateClaudeMaxUsage(t *testing.T) { + group := &Group{ + Platform: PlatformAnthropic, + SimulateClaudeMaxEnabled: true, + } + input := &RecordUsageInput{ + Result: &ForwardResult{ + Model: "claude-sonnet-4-5", + Usage: ClaudeUsage{ + InputTokens: 3000, + CacheCreationInputTokens: 0, + CacheCreation5mTokens: 0, + CacheCreation1hTokens: 0, + }, + }, + ParsedRequest: &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{ + "type": "text", + "text": "cached", + "cache_control": 
map[string]any{"type": "ephemeral"}, + }, + map[string]any{ + "type": "text", + "text": "tail", + }, + }, + }, + }, + }, + APIKey: &APIKey{Group: group}, + } + + if !shouldSimulateClaudeMaxUsage(input) { + t.Fatalf("expected simulate=true for claude group with cache signal") + } + + input.ParsedRequest = &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "no cache signal"}, + }, + } + if shouldSimulateClaudeMaxUsage(input) { + t.Fatalf("expected simulate=false when request has no cache signal") + } + + input.ParsedRequest = &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{ + "type": "text", + "text": "cached", + "cache_control": map[string]any{"type": "ephemeral"}, + }, + }, + }, + }, + } + input.Result.Usage.CacheCreationInputTokens = 100 + if shouldSimulateClaudeMaxUsage(input) { + t.Fatalf("expected simulate=false when cache creation already exists") + } +} diff --git a/backend/internal/service/claude_tokenizer.go b/backend/internal/service/claude_tokenizer.go new file mode 100644 index 00000000..61f5e961 --- /dev/null +++ b/backend/internal/service/claude_tokenizer.go @@ -0,0 +1,41 @@ +package service + +import ( + "sync" + + tiktoken "github.com/pkoukk/tiktoken-go" + tiktokenloader "github.com/pkoukk/tiktoken-go-loader" +) + +var ( + claudeTokenizerOnce sync.Once + claudeTokenizer *tiktoken.Tiktoken +) + +func getClaudeTokenizer() *tiktoken.Tiktoken { + claudeTokenizerOnce.Do(func() { + // Use offline loader to avoid runtime dictionary download. + tiktoken.SetBpeLoader(tiktokenloader.NewOfflineLoader()) + // Use a high-capacity tokenizer as the default approximation for Claude payloads. 
+ enc, err := tiktoken.GetEncoding(tiktoken.MODEL_O200K_BASE) + if err != nil { + enc, err = tiktoken.GetEncoding(tiktoken.MODEL_CL100K_BASE) + } + if err == nil { + claudeTokenizer = enc + } + }) + return claudeTokenizer +} + +func estimateTokensByThirdPartyTokenizer(text string) (int, bool) { + enc := getClaudeTokenizer() + if enc == nil { + return 0, false + } + tokens := len(enc.EncodeOrdinary(text)) + if tokens <= 0 { + return 0, false + } + return tokens, true +} diff --git a/backend/internal/service/concurrency_service.go b/backend/internal/service/concurrency_service.go index 4dcf84e0..386d5ed0 100644 --- a/backend/internal/service/concurrency_service.go +++ b/backend/internal/service/concurrency_service.go @@ -43,6 +43,9 @@ type ConcurrencyCache interface { // 清理过期槽位(后台任务) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error + + // 启动时清理旧进程遗留槽位与等待计数 + CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error } var ( @@ -59,13 +62,22 @@ func initRequestIDPrefix() string { return "r" + strconv.FormatUint(fallback, 36) } -// generateRequestID generates a unique request ID for concurrency slot tracking. 
-// Format: {process_random_prefix}-{base36_counter} +func RequestIDPrefix() string { + return requestIDPrefix +} + func generateRequestID() string { seq := requestIDCounter.Add(1) return requestIDPrefix + "-" + strconv.FormatUint(seq, 36) } +func (s *ConcurrencyService) CleanupStaleProcessSlots(ctx context.Context) error { + if s == nil || s.cache == nil { + return nil + } + return s.cache.CleanupStaleProcessSlots(ctx, RequestIDPrefix()) +} + const ( // Default extra wait slots beyond concurrency limit defaultExtraWaitSlots = 20 @@ -331,8 +343,9 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor }() } -// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts -// Returns a map of accountID -> current concurrency count +// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts. +// Uses a detached context with timeout to prevent HTTP request cancellation from +// causing the entire batch to fail (which would show all concurrency as 0). func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { if len(accountIDs) == 0 { return map[int64]int{}, nil @@ -344,5 +357,11 @@ func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, acc } return result, nil } - return s.cache.GetAccountConcurrencyBatch(ctx, accountIDs) + + // Use a detached context so that a cancelled HTTP request doesn't cause + // the Redis pipeline to fail and return all-zero concurrency counts. 
+ redisCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + return s.cache.GetAccountConcurrencyBatch(redisCtx, accountIDs) } diff --git a/backend/internal/service/concurrency_service_test.go b/backend/internal/service/concurrency_service_test.go index 9ba43d93..078ba0dc 100644 --- a/backend/internal/service/concurrency_service_test.go +++ b/backend/internal/service/concurrency_service_test.go @@ -91,6 +91,32 @@ func (c *stubConcurrencyCacheForTest) CleanupExpiredAccountSlots(_ context.Conte return c.cleanupErr } +func (c *stubConcurrencyCacheForTest) CleanupStaleProcessSlots(_ context.Context, _ string) error { + return c.cleanupErr +} + +type trackingConcurrencyCache struct { + stubConcurrencyCacheForTest + cleanupPrefix string +} + +func (c *trackingConcurrencyCache) CleanupStaleProcessSlots(_ context.Context, prefix string) error { + c.cleanupPrefix = prefix + return c.cleanupErr +} + +func TestCleanupStaleProcessSlots_NilCache(t *testing.T) { + svc := &ConcurrencyService{cache: nil} + require.NoError(t, svc.CleanupStaleProcessSlots(context.Background())) +} + +func TestCleanupStaleProcessSlots_DelegatesPrefix(t *testing.T) { + cache := &trackingConcurrencyCache{} + svc := NewConcurrencyService(cache) + require.NoError(t, svc.CleanupStaleProcessSlots(context.Background())) + require.Equal(t, RequestIDPrefix(), cache.cleanupPrefix) +} + func TestAcquireAccountSlot_Success(t *testing.T) { cache := &stubConcurrencyCacheForTest{acquireResult: true} svc := NewConcurrencyService(cache) diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 277683a0..304c09f4 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -74,11 +74,12 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid" // Setting keys const ( // 注册设置 - SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册 - 
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证 - SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能 - SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证) - SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册 + SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册 + SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证 + SettingKeyRegistrationEmailSuffixWhitelist = "registration_email_suffix_whitelist" // 注册邮箱后缀白名单(JSON 数组) + SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能 + SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证) + SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册 // 邮件服务设置 SettingKeySMTPHost = "smtp_host" // SMTP服务器地址 @@ -174,6 +175,20 @@ const ( // SettingKeyStreamTimeoutSettings stores JSON config for stream timeout handling. SettingKeyStreamTimeoutSettings = "stream_timeout_settings" + // ========================= + // Request Rectifier (请求整流器) + // ========================= + + // SettingKeyRectifierSettings stores JSON config for rectifier settings (thinking signature + budget). + SettingKeyRectifierSettings = "rectifier_settings" + + // ========================= + // Beta Policy Settings + // ========================= + + // SettingKeyBetaPolicySettings stores JSON config for beta policy rules. 
+ SettingKeyBetaPolicySettings = "beta_policy_settings" + // ========================= // Sora S3 存储配置 // ========================= diff --git a/backend/internal/service/error_passthrough_runtime_test.go b/backend/internal/service/error_passthrough_runtime_test.go index 7032d15b..2b7bbf60 100644 --- a/backend/internal/service/error_passthrough_runtime_test.go +++ b/backend/internal/service/error_passthrough_runtime_test.go @@ -220,7 +220,7 @@ func TestApplyErrorPassthroughRule_SkipMonitoringSetsContextKey(t *testing.T) { v, exists := c.Get(OpsSkipPassthroughKey) assert.True(t, exists, "OpsSkipPassthroughKey should be set when skip_monitoring=true") boolVal, ok := v.(bool) - assert.True(t, ok, "value should be bool") + assert.True(t, ok, "value should be a bool") assert.True(t, boolVal) } diff --git a/backend/internal/service/error_policy_test.go b/backend/internal/service/error_policy_test.go index 9d7d025e..dd9850bd 100644 --- a/backend/internal/service/error_policy_test.go +++ b/backend/internal/service/error_policy_test.go @@ -88,6 +88,49 @@ func TestCheckErrorPolicy(t *testing.T) { body: []byte(`overloaded service`), expected: ErrorPolicyTempUnscheduled, }, + { + name: "temp_unschedulable_401_first_hit_returns_temp_unscheduled", + account: &Account{ + ID: 14, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(401), + "keywords": []any{"unauthorized"}, + "duration_minutes": float64(10), + }, + }, + }, + }, + statusCode: 401, + body: []byte(`unauthorized`), + expected: ErrorPolicyTempUnscheduled, + }, + { + name: "temp_unschedulable_401_second_hit_upgrades_to_none", + account: &Account{ + ID: 15, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + 
"temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(401), + "keywords": []any{"unauthorized"}, + "duration_minutes": float64(10), + }, + }, + }, + }, + statusCode: 401, + body: []byte(`unauthorized`), + expected: ErrorPolicyNone, + }, { name: "temp_unschedulable_body_miss_returns_none", account: &Account{ @@ -134,6 +177,36 @@ func TestCheckErrorPolicy(t *testing.T) { body: []byte(`overloaded`), expected: ErrorPolicyMatched, // custom codes take precedence }, + { + name: "pool_mode_custom_error_codes_hit_returns_matched", + account: &Account{ + ID: 7, + Type: AccountTypeAPIKey, + Platform: PlatformOpenAI, + Credentials: map[string]any{ + "pool_mode": true, + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(401), float64(403)}, + }, + }, + statusCode: 401, + body: []byte(`unauthorized`), + expected: ErrorPolicyMatched, + }, + { + name: "pool_mode_without_custom_error_codes_returns_skipped", + account: &Account{ + ID: 8, + Type: AccountTypeAPIKey, + Platform: PlatformOpenAI, + Credentials: map[string]any{ + "pool_mode": true, + }, + }, + statusCode: 401, + body: []byte(`unauthorized`), + expected: ErrorPolicySkipped, + }, } for _, tt := range tests { @@ -147,6 +220,48 @@ func TestCheckErrorPolicy(t *testing.T) { } } +func TestHandleUpstreamError_PoolModeCustomErrorCodesOverride(t *testing.T) { + t.Run("pool_mode_without_custom_error_codes_still_skips", func(t *testing.T) { + repo := &errorPolicyRepoStub{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + account := &Account{ + ID: 30, + Type: AccountTypeAPIKey, + Platform: PlatformOpenAI, + Credentials: map[string]any{ + "pool_mode": true, + }, + } + + shouldDisable := svc.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) + + require.False(t, shouldDisable) + require.Equal(t, 0, repo.setErrCalls) + require.Equal(t, 0, repo.tempCalls) + }) + + 
t.Run("pool_mode_with_custom_error_codes_uses_local_error_policy", func(t *testing.T) { + repo := &errorPolicyRepoStub{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + account := &Account{ + ID: 31, + Type: AccountTypeAPIKey, + Platform: PlatformOpenAI, + Credentials: map[string]any{ + "pool_mode": true, + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(401)}, + }, + } + + shouldDisable := svc.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) + + require.True(t, shouldDisable) + require.Equal(t, 1, repo.setErrCalls) + require.Equal(t, 0, repo.tempCalls) + }) +} + // --------------------------------------------------------------------------- // TestApplyErrorPolicy — 4 table-driven cases for the wrapper method // --------------------------------------------------------------------------- diff --git a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go index f8c0ecda..5dcda1de 100644 --- a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go +++ b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go @@ -171,8 +171,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd require.NotNil(t, result) require.True(t, result.Stream) - require.Equal(t, body, upstream.lastBody, "透传模式不应改写上游请求体") - require.Equal(t, "claude-3-7-sonnet-20250219", gjson.GetBytes(upstream.lastBody, "model").String()) + require.Equal(t, "claude-3-haiku-20240307", gjson.GetBytes(upstream.lastBody, "model").String(), "透传模式应应用账号级模型映射") require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key")) require.Empty(t, upstream.lastReq.Header.Get("authorization")) @@ -190,7 +189,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd require.True(t, ok) bodyBytes, ok := rawBody.([]byte) require.True(t, ok, "应以 
[]byte 形式缓存上游请求体,避免重复 string 拷贝") - require.Equal(t, body, bodyBytes) + require.Equal(t, "claude-3-haiku-20240307", gjson.GetBytes(bodyBytes, "model").String(), "缓存的上游请求体应包含映射后的模型") } func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBody(t *testing.T) { @@ -253,8 +252,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo err := svc.ForwardCountTokens(context.Background(), c, account, parsed) require.NoError(t, err) - require.Equal(t, body, upstream.lastBody, "count_tokens 透传模式不应改写请求体") - require.Equal(t, "claude-3-5-sonnet-latest", gjson.GetBytes(upstream.lastBody, "model").String()) + require.Equal(t, "claude-3-opus-20240229", gjson.GetBytes(upstream.lastBody, "model").String(), "count_tokens 透传模式应应用账号级模型映射") require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key")) require.Empty(t, upstream.lastReq.Header.Get("authorization")) require.Empty(t, upstream.lastReq.Header.Get("cookie")) @@ -263,6 +261,273 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo require.Empty(t, rec.Header().Get("Set-Cookie")) } +// TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingEdgeCases 覆盖透传模式下模型映射的各种边界情况 +func TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingEdgeCases(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + model string + modelMapping map[string]any // nil = 不配置映射 + expectedModel string + endpoint string // "messages" or "count_tokens" + }{ + { + name: "Forward: 无映射配置时不改写模型", + model: "claude-sonnet-4-20250514", + modelMapping: nil, + expectedModel: "claude-sonnet-4-20250514", + endpoint: "messages", + }, + { + name: "Forward: 空映射配置时不改写模型", + model: "claude-sonnet-4-20250514", + modelMapping: map[string]any{}, + expectedModel: "claude-sonnet-4-20250514", + endpoint: "messages", + }, + { + name: "Forward: 模型不在映射表中时不改写", + model: "claude-sonnet-4-20250514", + modelMapping: 
map[string]any{"claude-3-haiku-20240307": "claude-3-opus-20240229"}, + expectedModel: "claude-sonnet-4-20250514", + endpoint: "messages", + }, + { + name: "Forward: 精确匹配映射应改写模型", + model: "claude-sonnet-4-20250514", + modelMapping: map[string]any{"claude-sonnet-4-20250514": "claude-sonnet-4-5-20241022"}, + expectedModel: "claude-sonnet-4-5-20241022", + endpoint: "messages", + }, + { + name: "Forward: 通配符映射应改写模型", + model: "claude-sonnet-4-20250514", + modelMapping: map[string]any{"claude-sonnet-4-*": "claude-sonnet-4-5-20241022"}, + expectedModel: "claude-sonnet-4-5-20241022", + endpoint: "messages", + }, + { + name: "CountTokens: 无映射配置时不改写模型", + model: "claude-sonnet-4-20250514", + modelMapping: nil, + expectedModel: "claude-sonnet-4-20250514", + endpoint: "count_tokens", + }, + { + name: "CountTokens: 模型不在映射表中时不改写", + model: "claude-sonnet-4-20250514", + modelMapping: map[string]any{"claude-3-haiku-20240307": "claude-3-opus-20240229"}, + expectedModel: "claude-sonnet-4-20250514", + endpoint: "count_tokens", + }, + { + name: "CountTokens: 精确匹配映射应改写模型", + model: "claude-sonnet-4-20250514", + modelMapping: map[string]any{"claude-sonnet-4-20250514": "claude-sonnet-4-5-20241022"}, + expectedModel: "claude-sonnet-4-5-20241022", + endpoint: "count_tokens", + }, + { + name: "CountTokens: 通配符映射应改写模型", + model: "claude-sonnet-4-20250514", + modelMapping: map[string]any{"claude-sonnet-4-*": "claude-sonnet-4-5-20241022"}, + expectedModel: "claude-sonnet-4-5-20241022", + endpoint: "count_tokens", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + body := []byte(`{"model":"` + tt.model + `","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`) + parsed := &ParsedRequest{ + Body: body, + Model: tt.model, + } + + credentials := map[string]any{ + "api_key": "upstream-key", + "base_url": "https://api.anthropic.com", + } + if tt.modelMapping != nil { + 
credentials["model_mapping"] = tt.modelMapping + } + + account := &Account{ + ID: 300, + Name: "edge-case-test", + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: credentials, + Extra: map[string]any{"anthropic_passthrough": true}, + Status: StatusActive, + Schedulable: true, + } + + if tt.endpoint == "messages" { + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + parsed.Stream = false + + upstreamJSON := `{"id":"msg_1","type":"message","usage":{"input_tokens":5,"output_tokens":3}}` + upstream := &anthropicHTTPUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(upstreamJSON)), + }, + } + svc := &GatewayService{ + cfg: &config.Config{}, + httpUpstream: upstream, + rateLimitService: &RateLimitService{}, + } + + result, err := svc.Forward(context.Background(), c, account, parsed) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, tt.expectedModel, gjson.GetBytes(upstream.lastBody, "model").String(), + "Forward 上游请求体中的模型应为: %s", tt.expectedModel) + } else { + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil) + + upstreamRespBody := `{"input_tokens":42}` + upstream := &anthropicHTTPUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(upstreamRespBody)), + }, + } + svc := &GatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}, + httpUpstream: upstream, + rateLimitService: &RateLimitService{}, + } + + err := svc.ForwardCountTokens(context.Background(), c, account, parsed) + require.NoError(t, err) + require.Equal(t, tt.expectedModel, gjson.GetBytes(upstream.lastBody, "model").String(), + "CountTokens 上游请求体中的模型应为: %s", tt.expectedModel) + } + }) + } +} + +// 
TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingPreservesOtherFields +// 确保模型映射只替换 model 字段,不影响请求体中的其他字段 +func TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingPreservesOtherFields(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil) + + // 包含复杂字段的请求体:system、thinking、messages + body := []byte(`{"model":"claude-sonnet-4-20250514","system":[{"type":"text","text":"You are a helpful assistant."}],"messages":[{"role":"user","content":[{"type":"text","text":"hello world"}]}],"thinking":{"type":"enabled","budget_tokens":5000},"max_tokens":1024}`) + parsed := &ParsedRequest{ + Body: body, + Model: "claude-sonnet-4-20250514", + } + + upstreamRespBody := `{"input_tokens":42}` + upstream := &anthropicHTTPUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(upstreamRespBody)), + }, + } + + svc := &GatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}, + httpUpstream: upstream, + rateLimitService: &RateLimitService{}, + } + + account := &Account{ + ID: 301, + Name: "preserve-fields-test", + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "upstream-key", + "base_url": "https://api.anthropic.com", + "model_mapping": map[string]any{"claude-sonnet-4-20250514": "claude-sonnet-4-5-20241022"}, + }, + Extra: map[string]any{"anthropic_passthrough": true}, + Status: StatusActive, + Schedulable: true, + } + + err := svc.ForwardCountTokens(context.Background(), c, account, parsed) + require.NoError(t, err) + + sentBody := upstream.lastBody + require.Equal(t, "claude-sonnet-4-5-20241022", gjson.GetBytes(sentBody, "model").String(), "model 应被映射") + require.Equal(t, "You are a helpful 
assistant.", gjson.GetBytes(sentBody, "system.0.text").String(), "system 字段不应被修改") + require.Equal(t, "hello world", gjson.GetBytes(sentBody, "messages.0.content.0.text").String(), "messages 字段不应被修改") + require.Equal(t, "enabled", gjson.GetBytes(sentBody, "thinking.type").String(), "thinking 字段不应被修改") + require.Equal(t, int64(5000), gjson.GetBytes(sentBody, "thinking.budget_tokens").Int(), "thinking.budget_tokens 不应被修改") + require.Equal(t, int64(1024), gjson.GetBytes(sentBody, "max_tokens").Int(), "max_tokens 不应被修改") +} + +// TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping +// 确保空模型名不会触发映射逻辑 +func TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil) + + body := []byte(`{"messages":[{"role":"user","content":"hello"}]}`) + parsed := &ParsedRequest{ + Body: body, + Model: "", // 空模型 + } + + upstreamRespBody := `{"input_tokens":10}` + upstream := &anthropicHTTPUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(upstreamRespBody)), + }, + } + + svc := &GatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}, + httpUpstream: upstream, + rateLimitService: &RateLimitService{}, + } + + account := &Account{ + ID: 302, + Name: "empty-model-test", + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "upstream-key", + "base_url": "https://api.anthropic.com", + "model_mapping": map[string]any{"*": "claude-3-opus-20240229"}, + }, + Extra: map[string]any{"anthropic_passthrough": true}, + Status: StatusActive, + Schedulable: true, + } + + err := svc.ForwardCountTokens(context.Background(), c, account, parsed) + 
require.NoError(t, err) + // 空模型名时,body 应原样透传,不应触发映射 + require.Equal(t, body, upstream.lastBody, "空模型名时请求体不应被修改") +} + func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokens404PassthroughNotError(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/service/gateway_beta_test.go b/backend/internal/service/gateway_beta_test.go index 21a1faa4..ecaffe21 100644 --- a/backend/internal/service/gateway_beta_test.go +++ b/backend/internal/service/gateway_beta_test.go @@ -86,10 +86,10 @@ func TestStripBetaTokens(t *testing.T) { want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", }, { - name: "DroppedBetas removes both context-1m and fast-mode", + name: "DroppedBetas is empty (filtering moved to configurable beta policy)", header: "oauth-2025-04-20,context-1m-2025-08-07,fast-mode-2026-02-01,interleaved-thinking-2025-05-14", tokens: claude.DroppedBetas, - want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + want: "oauth-2025-04-20,context-1m-2025-08-07,fast-mode-2026-02-01,interleaved-thinking-2025-05-14", }, } @@ -114,25 +114,23 @@ func TestMergeAnthropicBetaDropping_Context1M(t *testing.T) { func TestMergeAnthropicBetaDropping_DroppedBetas(t *testing.T) { required := []string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"} incoming := "context-1m-2025-08-07,fast-mode-2026-02-01,foo-beta,oauth-2025-04-20" + // DroppedBetas is now empty — filtering moved to configurable beta policy. + // Without a policy filter set, nothing gets dropped from the static set. 
drop := droppedBetaSet() got := mergeAnthropicBetaDropping(required, incoming, drop) - require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo-beta", got) - require.NotContains(t, got, "context-1m-2025-08-07") - require.NotContains(t, got, "fast-mode-2026-02-01") + require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,context-1m-2025-08-07,fast-mode-2026-02-01,foo-beta", got) + require.Contains(t, got, "context-1m-2025-08-07") + require.Contains(t, got, "fast-mode-2026-02-01") } func TestDroppedBetaSet(t *testing.T) { - // Base set contains DroppedBetas + // Base set contains DroppedBetas (now empty — filtering moved to configurable beta policy) base := droppedBetaSet() - require.Contains(t, base, claude.BetaContext1M) - require.Contains(t, base, claude.BetaFastMode) require.Len(t, base, len(claude.DroppedBetas)) // With extra tokens extended := droppedBetaSet(claude.BetaClaudeCode) - require.Contains(t, extended, claude.BetaContext1M) - require.Contains(t, extended, claude.BetaFastMode) require.Contains(t, extended, claude.BetaClaudeCode) require.Len(t, extended, len(claude.DroppedBetas)+1) } @@ -148,6 +146,32 @@ func TestBuildBetaTokenSet(t *testing.T) { require.Empty(t, empty) } +func TestContainsBetaToken(t *testing.T) { + tests := []struct { + name string + header string + token string + want bool + }{ + {"present in middle", "oauth-2025-04-20,fast-mode-2026-02-01,interleaved-thinking-2025-05-14", "fast-mode-2026-02-01", true}, + {"present at start", "fast-mode-2026-02-01,oauth-2025-04-20", "fast-mode-2026-02-01", true}, + {"present at end", "oauth-2025-04-20,fast-mode-2026-02-01", "fast-mode-2026-02-01", true}, + {"only token", "fast-mode-2026-02-01", "fast-mode-2026-02-01", true}, + {"not present", "oauth-2025-04-20,interleaved-thinking-2025-05-14", "fast-mode-2026-02-01", false}, + {"with spaces", "oauth-2025-04-20, fast-mode-2026-02-01 , interleaved-thinking-2025-05-14", "fast-mode-2026-02-01", true}, + {"empty header", "", 
"fast-mode-2026-02-01", false}, + {"empty token", "fast-mode-2026-02-01", "", false}, + {"partial match", "fast-mode-2026-02-01-extra", "fast-mode-2026-02-01", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := containsBetaToken(tt.header, tt.token) + require.Equal(t, tt.want, got) + }) + } +} + func TestStripBetaTokensWithSet_EmptyDropSet(t *testing.T) { header := "oauth-2025-04-20,interleaved-thinking-2025-05-14" got := stripBetaTokensWithSet(header, map[string]struct{}{}) diff --git a/backend/internal/service/gateway_claude_max_response_helpers.go b/backend/internal/service/gateway_claude_max_response_helpers.go new file mode 100644 index 00000000..a5f5f3d2 --- /dev/null +++ b/backend/internal/service/gateway_claude_max_response_helpers.go @@ -0,0 +1,196 @@ +package service + +import ( + "context" + "encoding/json" + + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/gin-gonic/gin" + "github.com/tidwall/sjson" +) + +type claudeMaxResponseRewriteContext struct { + Parsed *ParsedRequest + Group *Group +} + +type claudeMaxResponseRewriteContextKeyType struct{} + +var claudeMaxResponseRewriteContextKey = claudeMaxResponseRewriteContextKeyType{} + +func withClaudeMaxResponseRewriteContext(ctx context.Context, c *gin.Context, parsed *ParsedRequest) context.Context { + if ctx == nil { + ctx = context.Background() + } + value := claudeMaxResponseRewriteContext{ + Parsed: parsed, + Group: claudeMaxGroupFromGinContext(c), + } + return context.WithValue(ctx, claudeMaxResponseRewriteContextKey, value) +} + +func claudeMaxResponseRewriteContextFromContext(ctx context.Context) claudeMaxResponseRewriteContext { + if ctx == nil { + return claudeMaxResponseRewriteContext{} + } + value, _ := ctx.Value(claudeMaxResponseRewriteContextKey).(claudeMaxResponseRewriteContext) + return value +} + +func claudeMaxGroupFromGinContext(c *gin.Context) *Group { + if c == nil { + return nil + } + raw, exists := c.Get("api_key") + if 
!exists { + return nil + } + apiKey, ok := raw.(*APIKey) + if !ok || apiKey == nil { + return nil + } + return apiKey.Group +} + +func parsedRequestFromGinContext(c *gin.Context) *ParsedRequest { + if c == nil { + return nil + } + raw, exists := c.Get("parsed_request") + if !exists { + return nil + } + parsed, _ := raw.(*ParsedRequest) + return parsed +} + +func applyClaudeMaxSimulationToUsage(ctx context.Context, usage *ClaudeUsage, model string, accountID int64) claudeMaxCacheBillingOutcome { + var out claudeMaxCacheBillingOutcome + if usage == nil { + return out + } + rewriteCtx := claudeMaxResponseRewriteContextFromContext(ctx) + return applyClaudeMaxCacheBillingPolicyToUsage(usage, rewriteCtx.Parsed, rewriteCtx.Group, model, accountID) +} + +func applyClaudeMaxSimulationToUsageJSONMap(ctx context.Context, usageObj map[string]any, model string, accountID int64) claudeMaxCacheBillingOutcome { + var out claudeMaxCacheBillingOutcome + if usageObj == nil { + return out + } + usage := claudeUsageFromJSONMap(usageObj) + out = applyClaudeMaxSimulationToUsage(ctx, &usage, model, accountID) + if out.Simulated { + rewriteClaudeUsageJSONMap(usageObj, usage) + } + return out +} + +func rewriteClaudeUsageJSONBytes(body []byte, usage ClaudeUsage) []byte { + updated := body + var err error + + updated, err = sjson.SetBytes(updated, "usage.input_tokens", usage.InputTokens) + if err != nil { + return body + } + updated, err = sjson.SetBytes(updated, "usage.cache_creation_input_tokens", usage.CacheCreationInputTokens) + if err != nil { + return body + } + updated, err = sjson.SetBytes(updated, "usage.cache_creation.ephemeral_5m_input_tokens", usage.CacheCreation5mTokens) + if err != nil { + return body + } + updated, err = sjson.SetBytes(updated, "usage.cache_creation.ephemeral_1h_input_tokens", usage.CacheCreation1hTokens) + if err != nil { + return body + } + return updated +} + +func claudeUsageFromJSONMap(usageObj map[string]any) ClaudeUsage { + var usage ClaudeUsage + if 
usageObj == nil { + return usage + } + + usage.InputTokens = usageIntFromAny(usageObj["input_tokens"]) + usage.OutputTokens = usageIntFromAny(usageObj["output_tokens"]) + usage.CacheCreationInputTokens = usageIntFromAny(usageObj["cache_creation_input_tokens"]) + usage.CacheReadInputTokens = usageIntFromAny(usageObj["cache_read_input_tokens"]) + + if ccObj, ok := usageObj["cache_creation"].(map[string]any); ok { + usage.CacheCreation5mTokens = usageIntFromAny(ccObj["ephemeral_5m_input_tokens"]) + usage.CacheCreation1hTokens = usageIntFromAny(ccObj["ephemeral_1h_input_tokens"]) + } + return usage +} + +func rewriteClaudeUsageJSONMap(usageObj map[string]any, usage ClaudeUsage) { + if usageObj == nil { + return + } + usageObj["input_tokens"] = usage.InputTokens + usageObj["cache_creation_input_tokens"] = usage.CacheCreationInputTokens + + ccObj, _ := usageObj["cache_creation"].(map[string]any) + if ccObj == nil { + ccObj = make(map[string]any, 2) + usageObj["cache_creation"] = ccObj + } + ccObj["ephemeral_5m_input_tokens"] = usage.CacheCreation5mTokens + ccObj["ephemeral_1h_input_tokens"] = usage.CacheCreation1hTokens +} + +func usageIntFromAny(v any) int { + switch value := v.(type) { + case int: + return value + case int64: + return int(value) + case float64: + return int(value) + case json.Number: + if n, err := value.Int64(); err == nil { + return int(n) + } + } + return 0 +} + +// setupClaudeMaxStreamingHook 为 Antigravity 流式路径设置 SSE usage 改写 hook。 +func setupClaudeMaxStreamingHook(c *gin.Context, processor *antigravity.StreamingProcessor, originalModel string, accountID int64) { + group := claudeMaxGroupFromGinContext(c) + parsed := parsedRequestFromGinContext(c) + if !shouldApplyClaudeMaxBillingRulesForUsage(group, originalModel, parsed) { + return + } + processor.SetUsageMapHook(func(usageMap map[string]any) { + svcUsage := claudeUsageFromJSONMap(usageMap) + outcome := applyClaudeMaxCacheBillingPolicyToUsage(&svcUsage, parsed, group, originalModel, accountID) + 
if outcome.Simulated { + rewriteClaudeUsageJSONMap(usageMap, svcUsage) + } + }) +} + +// applyClaudeMaxNonStreamingRewrite 为 Antigravity 非流式路径改写响应体中的 usage。 +func applyClaudeMaxNonStreamingRewrite(c *gin.Context, claudeResp []byte, agUsage *antigravity.ClaudeUsage, originalModel string, accountID int64) []byte { + group := claudeMaxGroupFromGinContext(c) + parsed := parsedRequestFromGinContext(c) + if !shouldApplyClaudeMaxBillingRulesForUsage(group, originalModel, parsed) { + return claudeResp + } + svcUsage := &ClaudeUsage{ + InputTokens: agUsage.InputTokens, + OutputTokens: agUsage.OutputTokens, + CacheCreationInputTokens: agUsage.CacheCreationInputTokens, + CacheReadInputTokens: agUsage.CacheReadInputTokens, + } + outcome := applyClaudeMaxCacheBillingPolicyToUsage(svcUsage, parsed, group, originalModel, accountID) + if outcome.Simulated { + return rewriteClaudeUsageJSONBytes(claudeResp, *svcUsage) + } + return claudeResp +} diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 1cb3c61e..f947a8ee 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -187,6 +187,14 @@ func (m *mockAccountRepoForPlatform) BulkUpdate(ctx context.Context, ids []int64 return 0, nil } +func (m *mockAccountRepoForPlatform) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error { + return nil +} + +func (m *mockAccountRepoForPlatform) ResetQuotaUsed(ctx context.Context, id int64) error { + return nil +} + // Verify interface implementation var _ AccountRepository = (*mockAccountRepoForPlatform)(nil) @@ -1978,6 +1986,10 @@ func (m *mockConcurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, a return nil } +func (m *mockConcurrencyCache) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error { + return nil +} + func (m *mockConcurrencyCache) GetUsersLoadBatch(ctx context.Context, users 
[]UserWithConcurrency) (map[int64]*UserLoadInfo, error) { result := make(map[int64]*UserLoadInfo, len(users)) for _, user := range users { diff --git a/backend/internal/service/gateway_record_usage_claude_max_test.go b/backend/internal/service/gateway_record_usage_claude_max_test.go new file mode 100644 index 00000000..3cd86938 --- /dev/null +++ b/backend/internal/service/gateway_record_usage_claude_max_test.go @@ -0,0 +1,199 @@ +package service + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type usageLogRepoRecordUsageStub struct { + UsageLogRepository + + last *UsageLog + inserted bool + err error +} + +func (s *usageLogRepoRecordUsageStub) Create(_ context.Context, log *UsageLog) (bool, error) { + copied := *log + s.last = &copied + return s.inserted, s.err +} + +func newGatewayServiceForRecordUsageTest(repo UsageLogRepository) *GatewayService { + return &GatewayService{ + usageLogRepo: repo, + billingService: NewBillingService(&config.Config{}, nil), + cfg: &config.Config{RunMode: config.RunModeSimple}, + deferredService: &DeferredService{}, + } +} + +func TestRecordUsage_SimulateClaudeMaxEnabled_ProjectsUsageAndSkipsTTLOverride(t *testing.T) { + repo := &usageLogRepoRecordUsageStub{inserted: true} + svc := newGatewayServiceForRecordUsageTest(repo) + + groupID := int64(11) + input := &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "req-sim-1", + Model: "claude-sonnet-4", + Duration: time.Second, + Usage: ClaudeUsage{ + InputTokens: 160, + }, + }, + ParsedRequest: &ParsedRequest{ + Model: "claude-sonnet-4", + Messages: []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{ + "type": "text", + "text": "long cached context for prior turns", + "cache_control": map[string]any{"type": "ephemeral"}, + }, + map[string]any{ + "type": "text", + "text": "please summarize the logs and provide root cause analysis", + }, + }, + }, + }, + }, + 
APIKey: &APIKey{ + ID: 1, + GroupID: &groupID, + Group: &Group{ + ID: groupID, + Platform: PlatformAnthropic, + RateMultiplier: 1, + SimulateClaudeMaxEnabled: true, + }, + }, + User: &User{ID: 2}, + Account: &Account{ + ID: 3, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "cache_ttl_override_enabled": true, + "cache_ttl_override_target": "5m", + }, + }, + } + + err := svc.RecordUsage(context.Background(), input) + require.NoError(t, err) + require.NotNil(t, repo.last) + + log := repo.last + require.Equal(t, 80, log.InputTokens) + require.Equal(t, 80, log.CacheCreationTokens) + require.Equal(t, 0, log.CacheCreation5mTokens) + require.Equal(t, 80, log.CacheCreation1hTokens) + require.False(t, log.CacheTTLOverridden, "simulate outcome should skip account ttl override") +} + +func TestRecordUsage_SimulateClaudeMaxDisabled_AppliesTTLOverride(t *testing.T) { + repo := &usageLogRepoRecordUsageStub{inserted: true} + svc := newGatewayServiceForRecordUsageTest(repo) + + groupID := int64(12) + input := &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "req-sim-2", + Model: "claude-sonnet-4", + Duration: time.Second, + Usage: ClaudeUsage{ + InputTokens: 40, + CacheCreationInputTokens: 120, + CacheCreation1hTokens: 120, + }, + }, + APIKey: &APIKey{ + ID: 2, + GroupID: &groupID, + Group: &Group{ + ID: groupID, + Platform: PlatformAnthropic, + RateMultiplier: 1, + SimulateClaudeMaxEnabled: false, + }, + }, + User: &User{ID: 3}, + Account: &Account{ + ID: 4, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "cache_ttl_override_enabled": true, + "cache_ttl_override_target": "5m", + }, + }, + } + + err := svc.RecordUsage(context.Background(), input) + require.NoError(t, err) + require.NotNil(t, repo.last) + + log := repo.last + require.Equal(t, 120, log.CacheCreationTokens) + require.Equal(t, 120, log.CacheCreation5mTokens) + require.Equal(t, 0, log.CacheCreation1hTokens) + require.True(t, 
log.CacheTTLOverridden) +} + +func TestRecordUsage_SimulateClaudeMaxEnabled_ExistingCacheCreationBypassesSimulation(t *testing.T) { + repo := &usageLogRepoRecordUsageStub{inserted: true} + svc := newGatewayServiceForRecordUsageTest(repo) + + groupID := int64(13) + input := &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "req-sim-3", + Model: "claude-sonnet-4", + Duration: time.Second, + Usage: ClaudeUsage{ + InputTokens: 20, + CacheCreationInputTokens: 120, + CacheCreation5mTokens: 120, + }, + }, + APIKey: &APIKey{ + ID: 3, + GroupID: &groupID, + Group: &Group{ + ID: groupID, + Platform: PlatformAnthropic, + RateMultiplier: 1, + SimulateClaudeMaxEnabled: true, + }, + }, + User: &User{ID: 4}, + Account: &Account{ + ID: 5, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "cache_ttl_override_enabled": true, + "cache_ttl_override_target": "5m", + }, + }, + } + + err := svc.RecordUsage(context.Background(), input) + require.NoError(t, err) + require.NotNil(t, repo.last) + + log := repo.last + require.Equal(t, 20, log.InputTokens) + require.Equal(t, 120, log.CacheCreation5mTokens) + require.Equal(t, 0, log.CacheCreation1hTokens) + require.Equal(t, 120, log.CacheCreationTokens) + require.False(t, log.CacheTTLOverridden, "existing cache_creation with SimulateClaudeMax enabled should skip account ttl override") +} diff --git a/backend/internal/service/gateway_request.go b/backend/internal/service/gateway_request.go index b546fe85..f7bc57ac 100644 --- a/backend/internal/service/gateway_request.go +++ b/backend/internal/service/gateway_request.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "math" + "strings" "unsafe" "github.com/Wei-Shaw/sub2api/internal/domain" @@ -258,6 +259,7 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { if !hasEmptyContent && !containsThinkingBlocks { if topThinking := gjson.Get(jsonStr, "thinking"); topThinking.Exists() { if out, err := sjson.DeleteBytes(body, "thinking"); err == nil { + out = 
removeThinkingDependentContextStrategies(out) return out } return body @@ -395,6 +397,10 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { } else { return body } + // Removing "thinking" makes any context_management strategy that requires it invalid + // (e.g. clear_thinking_20251015). Strip those entries so the retry request does not + // receive a 400 "strategy requires thinking to be enabled or adaptive". + out = removeThinkingDependentContextStrategies(out) } if modified { msgsBytes, err := json.Marshal(messages) @@ -409,6 +415,49 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { return out } +// removeThinkingDependentContextStrategies 从 context_management.edits 中移除 +// 需要 thinking 启用的策略(如 clear_thinking_20251015)。 +// 当顶层 "thinking" 字段被禁用时必须调用,否则上游会返回 +// "strategy requires thinking to be enabled or adaptive"。 +func removeThinkingDependentContextStrategies(body []byte) []byte { + jsonStr := *(*string)(unsafe.Pointer(&body)) + editsRes := gjson.Get(jsonStr, "context_management.edits") + if !editsRes.Exists() || !editsRes.IsArray() { + return body + } + + var filtered []json.RawMessage + hasRemoved := false + editsRes.ForEach(func(_, v gjson.Result) bool { + if v.Get("type").String() == "clear_thinking_20251015" { + hasRemoved = true + return true + } + filtered = append(filtered, json.RawMessage(v.Raw)) + return true + }) + + if !hasRemoved { + return body + } + + if len(filtered) == 0 { + if b, err := sjson.DeleteBytes(body, "context_management.edits"); err == nil { + return b + } + return body + } + + filteredBytes, err := json.Marshal(filtered) + if err != nil { + return body + } + if b, err := sjson.SetRawBytes(body, "context_management.edits", filteredBytes); err == nil { + return b + } + return body +} + // FilterSignatureSensitiveBlocksForRetry is a stronger retry filter for cases where upstream errors indicate // signature/thought_signature validation issues involving tool blocks. 
// @@ -444,6 +493,28 @@ func FilterSignatureSensitiveBlocksForRetry(body []byte) []byte { if _, exists := req["thinking"]; exists { delete(req, "thinking") modified = true + // Remove context_management strategies that require thinking to be enabled + // (e.g. clear_thinking_20251015), otherwise upstream returns 400. + if cm, ok := req["context_management"].(map[string]any); ok { + if edits, ok := cm["edits"].([]any); ok { + filtered := make([]any, 0, len(edits)) + for _, edit := range edits { + if editMap, ok := edit.(map[string]any); ok { + if editMap["type"] == "clear_thinking_20251015" { + continue + } + } + filtered = append(filtered, edit) + } + if len(filtered) != len(edits) { + if len(filtered) == 0 { + delete(cm, "edits") + } else { + cm["edits"] = filtered + } + } + } + } } messages, ok := req["messages"].([]any) @@ -675,3 +746,90 @@ func filterThinkingBlocksInternal(body []byte, _ bool) []byte { } return newBody } + +// ========================= +// Thinking Budget Rectifier +// ========================= + +const ( + // BudgetRectifyBudgetTokens is the budget_tokens value to set when rectifying. + BudgetRectifyBudgetTokens = 32000 + // BudgetRectifyMaxTokens is the max_tokens value to set when rectifying. + BudgetRectifyMaxTokens = 64000 + // BudgetRectifyMinMaxTokens is the minimum max_tokens that must exceed budget_tokens. + BudgetRectifyMinMaxTokens = 32001 +) + +// isThinkingBudgetConstraintError detects whether an upstream error message indicates +// a budget_tokens constraint violation (e.g. "budget_tokens >= 1024"). +// Matches three conditions (all must be true): +// 1. Contains "budget_tokens" or "budget tokens" +// 2. Contains "thinking" +// 3. 
Contains ">= 1024" or "greater than or equal to 1024" or ("1024" + "input should be") +func isThinkingBudgetConstraintError(errMsg string) bool { + m := strings.ToLower(errMsg) + + // Condition 1: budget_tokens or budget tokens + hasBudget := strings.Contains(m, "budget_tokens") || strings.Contains(m, "budget tokens") + if !hasBudget { + return false + } + + // Condition 2: thinking + if !strings.Contains(m, "thinking") { + return false + } + + // Condition 3: constraint indicator + if strings.Contains(m, ">= 1024") || strings.Contains(m, "greater than or equal to 1024") { + return true + } + if strings.Contains(m, "1024") && strings.Contains(m, "input should be") { + return true + } + + return false +} + +// RectifyThinkingBudget modifies the request body to fix budget_tokens constraint errors. +// It sets thinking.budget_tokens = 32000, thinking.type = "enabled" (unless adaptive), +// and ensures max_tokens >= 32001. +// Returns (modified body, true) if changes were applied, or (original body, false) if not. 
+func RectifyThinkingBudget(body []byte) ([]byte, bool) { + // If thinking type is "adaptive", skip rectification entirely + thinkingType := gjson.GetBytes(body, "thinking.type").String() + if thinkingType == "adaptive" { + return body, false + } + + modified := body + changed := false + + // Set thinking.type = "enabled" + if thinkingType != "enabled" { + if result, err := sjson.SetBytes(modified, "thinking.type", "enabled"); err == nil { + modified = result + changed = true + } + } + + // Set thinking.budget_tokens = 32000 + currentBudget := gjson.GetBytes(modified, "thinking.budget_tokens").Int() + if currentBudget != BudgetRectifyBudgetTokens { + if result, err := sjson.SetBytes(modified, "thinking.budget_tokens", BudgetRectifyBudgetTokens); err == nil { + modified = result + changed = true + } + } + + // Ensure max_tokens >= BudgetRectifyMinMaxTokens + maxTokens := gjson.GetBytes(modified, "max_tokens").Int() + if maxTokens < int64(BudgetRectifyMinMaxTokens) { + if result, err := sjson.SetBytes(modified, "max_tokens", BudgetRectifyMaxTokens); err == nil { + modified = result + changed = true + } + } + + return modified, changed +} diff --git a/backend/internal/service/gateway_request_test.go b/backend/internal/service/gateway_request_test.go index 2a9b4017..42b61e3f 100644 --- a/backend/internal/service/gateway_request_test.go +++ b/backend/internal/service/gateway_request_test.go @@ -439,6 +439,210 @@ func TestFilterSignatureSensitiveBlocksForRetry_DowngradesTools(t *testing.T) { require.Contains(t, content1["text"], "tool_result") } +// ============ Group 6b: context_management.edits 清理测试 ============ + +// removeThinkingDependentContextStrategies — 边界用例 + +func TestRemoveThinkingDependentContextStrategies_NoContextManagement(t *testing.T) { + input := []byte(`{"thinking":{"type":"enabled"},"messages":[]}`) + out := removeThinkingDependentContextStrategies(input) + require.Equal(t, input, out, "无 context_management 字段时应原样返回") +} + +func 
TestRemoveThinkingDependentContextStrategies_EmptyEdits(t *testing.T) { + input := []byte(`{"context_management":{"edits":[]},"messages":[]}`) + out := removeThinkingDependentContextStrategies(input) + require.Equal(t, input, out, "edits 为空数组时应原样返回") +} + +func TestRemoveThinkingDependentContextStrategies_NoClearThinkingEntry(t *testing.T) { + input := []byte(`{"context_management":{"edits":[{"type":"other_strategy"}]},"messages":[]}`) + out := removeThinkingDependentContextStrategies(input) + require.Equal(t, input, out, "edits 中无 clear_thinking_20251015 时应原样返回") +} + +func TestRemoveThinkingDependentContextStrategies_RemovesSingleEntry(t *testing.T) { + input := []byte(`{"context_management":{"edits":[{"type":"clear_thinking_20251015"}]},"messages":[]}`) + out := removeThinkingDependentContextStrategies(input) + + var req map[string]any + require.NoError(t, json.Unmarshal(out, &req)) + cm, ok := req["context_management"].(map[string]any) + require.True(t, ok) + _, hasEdits := cm["edits"] + require.False(t, hasEdits, "所有 edits 均为 clear_thinking_20251015 时应删除 edits 键") +} + +func TestRemoveThinkingDependentContextStrategies_MixedEntries(t *testing.T) { + input := []byte(`{"context_management":{"edits":[{"type":"clear_thinking_20251015"},{"type":"other_strategy","param":1}]},"messages":[]}`) + out := removeThinkingDependentContextStrategies(input) + + var req map[string]any + require.NoError(t, json.Unmarshal(out, &req)) + cm, ok := req["context_management"].(map[string]any) + require.True(t, ok) + edits, ok := cm["edits"].([]any) + require.True(t, ok) + require.Len(t, edits, 1, "仅移除 clear_thinking_20251015,保留其他条目") + edit0, ok := edits[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "other_strategy", edit0["type"]) +} + +// FilterThinkingBlocksForRetry — 包含 context_management 的场景 + +func TestFilterThinkingBlocksForRetry_RemovesClearThinkingStrategy_FastPath(t *testing.T) { + // 快速路径:messages 中无 thinking 块,仅有顶层 thinking 字段 + // 这条路径曾因提前 return 跳过 
removeThinkingDependentContextStrategies 而存在 bug + input := []byte(`{ + "thinking":{"type":"enabled","budget_tokens":1024}, + "context_management":{"edits":[{"type":"clear_thinking_20251015"}]}, + "messages":[ + {"role":"user","content":[{"type":"text","text":"Hello"}]} + ] + }`) + + out := FilterThinkingBlocksForRetry(input) + + var req map[string]any + require.NoError(t, json.Unmarshal(out, &req)) + _, hasThinking := req["thinking"] + require.False(t, hasThinking, "顶层 thinking 应被移除") + + cm, ok := req["context_management"].(map[string]any) + require.True(t, ok) + _, hasEdits := cm["edits"] + require.False(t, hasEdits, "fast path 下 clear_thinking_20251015 应被移除,edits 键应被删除") +} + +func TestFilterThinkingBlocksForRetry_RemovesClearThinkingStrategy_WithThinkingBlocks(t *testing.T) { + // 完整路径:messages 中有 thinking 块(非 fast path) + input := []byte(`{ + "thinking":{"type":"enabled","budget_tokens":1024}, + "context_management":{"edits":[{"type":"clear_thinking_20251015"},{"type":"keep_this"}]}, + "messages":[ + {"role":"assistant","content":[ + {"type":"thinking","thinking":"some thought","signature":"sig"}, + {"type":"text","text":"Answer"} + ]} + ] + }`) + + out := FilterThinkingBlocksForRetry(input) + + var req map[string]any + require.NoError(t, json.Unmarshal(out, &req)) + _, hasThinking := req["thinking"] + require.False(t, hasThinking, "顶层 thinking 应被移除") + + cm, ok := req["context_management"].(map[string]any) + require.True(t, ok) + edits, ok := cm["edits"].([]any) + require.True(t, ok) + require.Len(t, edits, 1, "仅移除 clear_thinking_20251015,保留 keep_this") + edit0, ok := edits[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "keep_this", edit0["type"]) +} + +func TestFilterThinkingBlocksForRetry_NoContextManagement_Unaffected(t *testing.T) { + // 无 context_management 时不应报错,且 thinking 正常被移除 + input := []byte(`{ + "thinking":{"type":"enabled"}, + "messages":[{"role":"user","content":[{"type":"text","text":"Hi"}]}] + }`) + + out := 
FilterThinkingBlocksForRetry(input) + + var req map[string]any + require.NoError(t, json.Unmarshal(out, &req)) + _, hasThinking := req["thinking"] + require.False(t, hasThinking) + _, hasCM := req["context_management"] + require.False(t, hasCM) +} + +// FilterSignatureSensitiveBlocksForRetry — 包含 context_management 的场景 + +func TestFilterSignatureSensitiveBlocksForRetry_RemovesClearThinkingStrategy(t *testing.T) { + input := []byte(`{ + "thinking":{"type":"enabled","budget_tokens":1024}, + "context_management":{"edits":[{"type":"clear_thinking_20251015"}]}, + "messages":[ + {"role":"assistant","content":[ + {"type":"thinking","thinking":"thought","signature":"sig"} + ]} + ] + }`) + + out := FilterSignatureSensitiveBlocksForRetry(input) + + var req map[string]any + require.NoError(t, json.Unmarshal(out, &req)) + _, hasThinking := req["thinking"] + require.False(t, hasThinking, "顶层 thinking 应被移除") + + cm, ok := req["context_management"].(map[string]any) + require.True(t, ok) + if rawEdits, hasEdits := cm["edits"]; hasEdits { + edits, ok := rawEdits.([]any) + require.True(t, ok) + for _, e := range edits { + em, ok := e.(map[string]any) + require.True(t, ok) + require.NotEqual(t, "clear_thinking_20251015", em["type"], "clear_thinking_20251015 应被移除") + } + } +} + +func TestFilterSignatureSensitiveBlocksForRetry_PreservesNonThinkingStrategies(t *testing.T) { + input := []byte(`{ + "thinking":{"type":"enabled"}, + "context_management":{"edits":[{"type":"clear_thinking_20251015"},{"type":"other_edit"}]}, + "messages":[ + {"role":"assistant","content":[ + {"type":"thinking","thinking":"t","signature":"s"} + ]} + ] + }`) + + out := FilterSignatureSensitiveBlocksForRetry(input) + + var req map[string]any + require.NoError(t, json.Unmarshal(out, &req)) + + cm, ok := req["context_management"].(map[string]any) + require.True(t, ok) + edits, ok := cm["edits"].([]any) + require.True(t, ok) + require.Len(t, edits, 1, "仅移除 clear_thinking_20251015,保留 other_edit") + edit0, ok := 
edits[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "other_edit", edit0["type"]) +} + +func TestFilterSignatureSensitiveBlocksForRetry_NoThinkingField_ContextManagementUntouched(t *testing.T) { + // 没有顶层 thinking 字段时,context_management 不应被修改 + input := []byte(`{ + "context_management":{"edits":[{"type":"clear_thinking_20251015"}]}, + "messages":[ + {"role":"assistant","content":[ + {"type":"thinking","thinking":"t","signature":"s"} + ]} + ] + }`) + + out := FilterSignatureSensitiveBlocksForRetry(input) + + var req map[string]any + require.NoError(t, json.Unmarshal(out, &req)) + cm, ok := req["context_management"].(map[string]any) + require.True(t, ok) + edits, ok := cm["edits"].([]any) + require.True(t, ok) + require.Len(t, edits, 1, "无顶层 thinking 时 context_management 不应被修改") +} + // ============ Group 7: ParseGatewayRequest 补充单元测试 ============ // Task 7.1 — 类型校验边界测试 diff --git a/backend/internal/service/gateway_response_usage_sync_test.go b/backend/internal/service/gateway_response_usage_sync_test.go new file mode 100644 index 00000000..445ee8ad --- /dev/null +++ b/backend/internal/service/gateway_response_usage_sync_test.go @@ -0,0 +1,170 @@ +package service + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestHandleNonStreamingResponse_UsageAlignedWithClaudeMaxSimulation(t *testing.T) { + gin.SetMode(gin.TestMode) + + svc := &GatewayService{ + cfg: &config.Config{}, + rateLimitService: &RateLimitService{}, + } + + account := &Account{ + ID: 11, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "cache_ttl_override_enabled": true, + "cache_ttl_override_target": "5m", + }, + } + group := &Group{ + ID: 99, + Platform: PlatformAnthropic, + SimulateClaudeMaxEnabled: true, + } + parsed := &ParsedRequest{ + Model: 
"claude-sonnet-4", + Messages: []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{ + "type": "text", + "text": "long cached context", + "cache_control": map[string]any{"type": "ephemeral"}, + }, + map[string]any{ + "type": "text", + "text": "new user question", + }, + }, + }, + }, + } + + upstreamBody := []byte(`{"id":"msg_1","model":"claude-sonnet-4","usage":{"input_tokens":120,"output_tokens":8}}`) + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: ioNopCloserBytes(upstreamBody), + } + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(nil)) + c.Set("api_key", &APIKey{Group: group}) + requestCtx := withClaudeMaxResponseRewriteContext(context.Background(), c, parsed) + + usage, err := svc.handleNonStreamingResponse(requestCtx, resp, c, account, "claude-sonnet-4", "claude-sonnet-4") + require.NoError(t, err) + require.NotNil(t, usage) + + var rendered struct { + Usage ClaudeUsage `json:"usage"` + } + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &rendered)) + rendered.Usage.CacheCreation5mTokens = int(gjson.GetBytes(rec.Body.Bytes(), "usage.cache_creation.ephemeral_5m_input_tokens").Int()) + rendered.Usage.CacheCreation1hTokens = int(gjson.GetBytes(rec.Body.Bytes(), "usage.cache_creation.ephemeral_1h_input_tokens").Int()) + + require.Equal(t, rendered.Usage.InputTokens, usage.InputTokens) + require.Equal(t, rendered.Usage.OutputTokens, usage.OutputTokens) + require.Equal(t, rendered.Usage.CacheCreationInputTokens, usage.CacheCreationInputTokens) + require.Equal(t, rendered.Usage.CacheCreation5mTokens, usage.CacheCreation5mTokens) + require.Equal(t, rendered.Usage.CacheCreation1hTokens, usage.CacheCreation1hTokens) + require.Equal(t, rendered.Usage.CacheReadInputTokens, usage.CacheReadInputTokens) + + require.Greater(t, usage.CacheCreation1hTokens, 0) + 
require.Equal(t, 0, usage.CacheCreation5mTokens) + require.Less(t, usage.InputTokens, 120) +} + +func TestHandleNonStreamingResponse_ClaudeMaxDisabled_NoSimulationIntercept(t *testing.T) { + gin.SetMode(gin.TestMode) + + svc := &GatewayService{ + cfg: &config.Config{}, + rateLimitService: &RateLimitService{}, + } + + account := &Account{ + ID: 12, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "cache_ttl_override_enabled": true, + "cache_ttl_override_target": "5m", + }, + } + group := &Group{ + ID: 100, + Platform: PlatformAnthropic, + SimulateClaudeMaxEnabled: false, + } + parsed := &ParsedRequest{ + Model: "claude-sonnet-4", + Messages: []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{ + "type": "text", + "text": "long cached context", + "cache_control": map[string]any{"type": "ephemeral"}, + }, + map[string]any{ + "type": "text", + "text": "new user question", + }, + }, + }, + }, + } + + upstreamBody := []byte(`{"id":"msg_2","model":"claude-sonnet-4","usage":{"input_tokens":120,"output_tokens":8}}`) + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: ioNopCloserBytes(upstreamBody), + } + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(nil)) + c.Set("api_key", &APIKey{Group: group}) + requestCtx := withClaudeMaxResponseRewriteContext(context.Background(), c, parsed) + + usage, err := svc.handleNonStreamingResponse(requestCtx, resp, c, account, "claude-sonnet-4", "claude-sonnet-4") + require.NoError(t, err) + require.NotNil(t, usage) + + require.Equal(t, 120, usage.InputTokens) + require.Equal(t, 0, usage.CacheCreationInputTokens) + require.Equal(t, 0, usage.CacheCreation5mTokens) + require.Equal(t, 0, usage.CacheCreation1hTokens) +} + +func ioNopCloserBytes(b []byte) *readCloserFromBytes { + return 
&readCloserFromBytes{Reader: bytes.NewReader(b)} +} + +type readCloserFromBytes struct { + *bytes.Reader +} + +func (r *readCloserFromBytes) Close() error { + return nil +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 02f9a6a3..85a8e2b7 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -41,7 +41,7 @@ const ( claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true" claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true" stickySessionTTL = time.Hour // 粘性会话TTL - defaultMaxLineSize = 40 * 1024 * 1024 + defaultMaxLineSize = 500 * 1024 * 1024 // Canonical Claude Code banner. Keep it EXACT (no trailing whitespace/newlines) // to match real Claude CLI traffic as closely as possible. When we need a visual // separator between system blocks, we add "\n\n" at concatenation time. @@ -56,6 +56,12 @@ const ( claudeMimicDebugInfoKey = "claude_mimic_debug_info" ) +const ( + claudeMaxMessageOverheadTokens = 3 + claudeMaxBlockOverheadTokens = 1 + claudeMaxUnknownContentTokens = 4 +) + // ForceCacheBillingContextKey 强制缓存计费上下文键 // 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费 type forceCacheBillingKeyType struct{} @@ -501,33 +507,35 @@ func (s *GatewayService) TempUnscheduleRetryableError(ctx context.Context, accou // GatewayService handles API gateway operations type GatewayService struct { - accountRepo AccountRepository - groupRepo GroupRepository - usageLogRepo UsageLogRepository - userRepo UserRepository - userSubRepo UserSubscriptionRepository - userGroupRateRepo UserGroupRateRepository - cache GatewayCache - digestStore *DigestSessionStore - cfg *config.Config - schedulerSnapshot *SchedulerSnapshotService - billingService *BillingService - rateLimitService *RateLimitService - billingCacheService *BillingCacheService - identityService *IdentityService - httpUpstream HTTPUpstream - deferredService 
*DeferredService - concurrencyService *ConcurrencyService - claudeTokenProvider *ClaudeTokenProvider - sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken) - rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken) - userGroupRateCache *gocache.Cache - userGroupRateSF singleflight.Group - modelsListCache *gocache.Cache - modelsListCacheTTL time.Duration - responseHeaderFilter *responseheaders.CompiledHeaderFilter - debugModelRouting atomic.Bool - debugClaudeMimic atomic.Bool + accountRepo AccountRepository + groupRepo GroupRepository + usageLogRepo UsageLogRepository + userRepo UserRepository + userSubRepo UserSubscriptionRepository + userGroupRateRepo UserGroupRateRepository + cache GatewayCache + digestStore *DigestSessionStore + cfg *config.Config + schedulerSnapshot *SchedulerSnapshotService + billingService *BillingService + rateLimitService *RateLimitService + billingCacheService *BillingCacheService + identityService *IdentityService + httpUpstream HTTPUpstream + deferredService *DeferredService + concurrencyService *ConcurrencyService + claudeTokenProvider *ClaudeTokenProvider + sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken) + rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken) + userGroupRateResolver *userGroupRateResolver + userGroupRateCache *gocache.Cache + userGroupRateSF singleflight.Group + modelsListCache *gocache.Cache + modelsListCacheTTL time.Duration + settingService *SettingService + responseHeaderFilter *responseheaders.CompiledHeaderFilter + debugModelRouting atomic.Bool + debugClaudeMimic atomic.Bool } // NewGatewayService creates a new GatewayService @@ -552,6 +560,7 @@ func NewGatewayService( sessionLimitCache SessionLimitCache, rpmCache RPMCache, digestStore *DigestSessionStore, + settingService *SettingService, ) *GatewayService { userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg) modelsListTTL := resolveModelsListCacheTTL(cfg) @@ -578,10 +587,18 @@ func 
NewGatewayService( sessionLimitCache: sessionLimitCache, rpmCache: rpmCache, userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute), + settingService: settingService, modelsListCache: gocache.New(modelsListTTL, time.Minute), modelsListCacheTTL: modelsListTTL, responseHeaderFilter: compileResponseHeaderFilter(cfg), } + svc.userGroupRateResolver = newUserGroupRateResolver( + userGroupRateRepo, + svc.userGroupRateCache, + userGroupRateTTL, + &svc.userGroupRateSF, + "service.gateway", + ) svc.debugModelRouting.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING"))) svc.debugClaudeMimic.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_CLAUDE_MIMIC"))) return svc @@ -986,6 +1003,11 @@ func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account return fmt.Sprintf("user_%s_account__session_%s", userID, sessionID) } +// GenerateSessionUUID creates a deterministic UUID4 from a seed string. +func GenerateSessionUUID(seed string) string { + return generateSessionUUID(seed) +} + func generateSessionUUID(seed string) string { if seed == "" { return uuid.NewString() @@ -1228,6 +1250,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro modelScopeSkippedIDs = append(modelScopeSkippedIDs, account.ID) continue } + // 配额检查 + if !s.isAccountSchedulableForQuota(account) { + continue + } // 窗口费用检查(非粘性会话路径) if !s.isAccountSchedulableForWindowCost(ctx, account, false) { filteredWindowCost++ @@ -1260,6 +1286,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) && + s.isAccountSchedulableForQuota(stickyAccount) && s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) && s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 
粘性会话窗口费用+RPM 检查 @@ -1311,7 +1338,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro for _, acc := range routingCandidates { routingLoads = append(routingLoads, AccountWithConcurrency{ ID: acc.ID, - MaxConcurrency: acc.Concurrency, + MaxConcurrency: acc.EffectiveLoadFactor(), }) } routingLoadMap, _ := s.concurrencyService.GetAccountsLoadBatch(ctx, routingLoads) @@ -1416,6 +1443,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro s.isAccountAllowedForPlatform(account, platform, useMixed) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && + s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { // 粘性会话窗口费用+RPM 检查 @@ -1480,6 +1508,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } + // 配额检查 + if !s.isAccountSchedulableForQuota(acc) { + continue + } // 窗口费用检查(非粘性会话路径) if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { continue @@ -1499,7 +1531,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro for _, acc := range candidates { accountLoads = append(accountLoads, AccountWithConcurrency{ ID: acc.ID, - MaxConcurrency: acc.Concurrency, + MaxConcurrency: acc.EffectiveLoadFactor(), }) } @@ -2113,6 +2145,15 @@ func (s *GatewayService) withWindowCostPrefetch(ctx context.Context, accounts [] return context.WithValue(ctx, windowCostPrefetchContextKey, costs) } +// isAccountSchedulableForQuota 检查 API Key 账号是否在配额限制内 +// 仅适用于配置了 quota_limit 的 apikey 类型账号 +func (s *GatewayService) isAccountSchedulableForQuota(account *Account) bool { + if account.Type != AccountTypeAPIKey { + return true + } + return !account.IsQuotaExceeded() +} + // 
isAccountSchedulableForWindowCost 检查账号是否可根据窗口费用进行调度 // 仅适用于 Anthropic OAuth/SetupToken 账号 // 返回 true 表示可调度,false 表示不可调度 @@ -2590,7 +2631,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { + if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) } @@ -2644,6 +2685,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } + if !s.isAccountSchedulableForQuota(acc) { + continue + } if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { continue } @@ -2700,7 +2744,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || 
s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { + if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { return account, nil } } @@ -2743,6 +2787,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } + if !s.isAccountSchedulableForQuota(acc) { + continue + } if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { continue } @@ -2818,7 +2865,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { + if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { if account.Platform == nativePlatform || (account.Platform == 
PlatformAntigravity && account.IsMixedSchedulingEnabled()) { if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) @@ -2874,6 +2921,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } + if !s.isAccountSchedulableForQuota(acc) { + continue + } if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { continue } @@ -2930,7 +2980,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { + if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { return account, nil } @@ -2975,6 +3025,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } + if !s.isAccountSchedulableForQuota(acc) { + continue + } if !s.isAccountSchedulableForWindowCost(ctx, acc, false) 
{ continue } @@ -3889,7 +3942,30 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A } if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() { - return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, parsed.Body, parsed.Model, parsed.Stream, startTime) + passthroughBody := parsed.Body + passthroughModel := parsed.Model + if passthroughModel != "" { + if mappedModel := account.GetMappedModel(passthroughModel); mappedModel != passthroughModel { + passthroughBody = s.replaceModelInBody(passthroughBody, mappedModel) + logger.LegacyPrintf("service.gateway", "Passthrough model mapping: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name) + passthroughModel = mappedModel + } + } + return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody, passthroughModel, parsed.Stream, startTime) + } + + // Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest. + // Always overwrite the cache to prevent stale values from a previous retry with a different account. 
+ if account.Platform == PlatformAnthropic && c != nil { + policy := s.evaluateBetaPolicy(ctx, c.GetHeader("anthropic-beta"), account) + if policy.blockErr != nil { + return nil, policy.blockErr + } + filterSet := policy.filterSet + if filterSet == nil { + filterSet = map[string]struct{}{} + } + c.Set(betaPolicyFilterSetKey, filterSet) } body := parsed.Body @@ -4017,7 +4093,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A if readErr == nil { _ = resp.Body.Close() - if s.isThinkingBlockSignatureError(respBody) { + if s.isThinkingBlockSignatureError(respBody) && s.settingService.IsSignatureRectifierEnabled(ctx) { appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, @@ -4134,7 +4210,45 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A resp.Body = io.NopCloser(bytes.NewReader(respBody)) break } - // 不是thinking签名错误,恢复响应体 + // 不是签名错误(或整流器已关闭),继续检查 budget 约束 + errMsg := extractUpstreamErrorMessage(respBody) + if isThinkingBudgetConstraintError(errMsg) && s.settingService.IsBudgetRectifierEnabled(ctx) { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "budget_constraint_error", + Message: errMsg, + Detail: func() string { + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) + } + return "" + }(), + }) + + rectifiedBody, applied := RectifyThinkingBudget(body) + if applied && time.Since(retryStart) < maxRetryElapsed { + logger.LegacyPrintf("service.gateway", "Account %d: detected budget_tokens constraint error, retrying with rectified budget (budget_tokens=%d, max_tokens=%d)", account.ID, BudgetRectifyBudgetTokens, BudgetRectifyMaxTokens) + budgetRetryReq, buildErr := 
s.buildUpstreamRequest(ctx, c, account, rectifiedBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + if buildErr == nil { + budgetRetryResp, retryErr := s.httpUpstream.DoWithTLS(budgetRetryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + if retryErr == nil { + resp = budgetRetryResp + break + } + if budgetRetryResp != nil && budgetRetryResp.Body != nil { + _ = budgetRetryResp.Body.Close() + } + logger.LegacyPrintf("service.gateway", "Account %d: budget rectifier retry failed: %v", account.ID, retryErr) + } else { + logger.LegacyPrintf("service.gateway", "Account %d: budget rectifier retry build failed: %v", account.ID, buildErr) + } + } + } + resp.Body = io.NopCloser(bytes.NewReader(respBody)) } } @@ -4226,7 +4340,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A return "" }(), }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + } } return s.handleRetryExhaustedError(ctx, resp, c, account) } @@ -4256,7 +4374,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A return "" }(), }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + } } if resp.StatusCode >= 400 { // 可选:对部分 400 触发 failover(默认关闭以保持语义) @@ -4308,6 +4430,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A } // 处理正常响应 + ctx = withClaudeMaxResponseRewriteContext(ctx, c, parsed) // 触发上游接受回调(提前释放串行锁,不等流完成) if parsed.OnUpstreamAccepted != nil { @@ -4491,7 +4614,11 @@ func (s *GatewayService) 
forwardAnthropicAPIKeyPassthrough( return "" }(), }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + } } return s.handleRetryExhaustedError(ctx, resp, c, account) } @@ -4521,7 +4648,11 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( return "" }(), }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + } } if resp.StatusCode >= 400 { @@ -4574,7 +4705,7 @@ func (s *GatewayService) buildUpstreamRequestAnthropicAPIKeyPassthrough( if err != nil { return nil, err } - targetURL = validatedURL + "/v1/messages" + targetURL = validatedURL + "/v1/messages?beta=true" } req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) @@ -4954,7 +5085,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex if err != nil { return nil, err } - targetURL = validatedURL + "/v1/messages" + targetURL = validatedURL + "/v1/messages?beta=true" } } @@ -5023,6 +5154,11 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex applyClaudeOAuthHeaderDefaults(req, reqStream) } + // Build effective drop set: merge static defaults with dynamic beta policy filter rules + policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account) + effectiveDropSet := mergeDropSets(policyFilterSet) + effectiveDropWithClaudeCodeSet := mergeDropSets(policyFilterSet, claude.BetaClaudeCode) + // 处理 anthropic-beta header(OAuth 账号需要包含 oauth beta) if tokenType == "oauth" { if mimicClaudeCode { @@ -5036,17 +5172,22 @@ func (s *GatewayService) 
buildUpstreamRequest(ctx context.Context, c *gin.Contex // messages requests typically use only oauth + interleaved-thinking. // Also drop claude-code beta if a downstream client added it. requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking} - req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, droppedBetasWithClaudeCodeSet)) + req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropWithClaudeCodeSet)) } else { // Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta clientBetaHeader := req.Header.Get("anthropic-beta") - req.Header.Set("anthropic-beta", stripBetaTokensWithSet(s.getBetaHeader(modelID, clientBetaHeader), defaultDroppedBetasSet)) + req.Header.Set("anthropic-beta", stripBetaTokensWithSet(s.getBetaHeader(modelID, clientBetaHeader), effectiveDropSet)) } - } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { - // API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭) - if requestNeedsBetaFeatures(body) { - if beta := defaultAPIKeyBetaHeader(body); beta != "" { - req.Header.Set("anthropic-beta", beta) + } else { + // API-key accounts: apply beta policy filter to strip controlled tokens + if existingBeta := req.Header.Get("anthropic-beta"); existingBeta != "" { + req.Header.Set("anthropic-beta", stripBetaTokensWithSet(existingBeta, effectiveDropSet)) + } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey { + // API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭) + if requestNeedsBetaFeatures(body) { + if beta := defaultAPIKeyBetaHeader(body); beta != "" { + req.Header.Set("anthropic-beta", beta) + } } } } @@ -5224,6 +5365,104 @@ func stripBetaTokensWithSet(header string, drop map[string]struct{}) string { return strings.Join(out, ",") } +// BetaBlockedError indicates a request was blocked by a beta policy rule. 
+type BetaBlockedError struct { + Message string +} + +func (e *BetaBlockedError) Error() string { return e.Message } + +// betaPolicyResult holds the evaluated result of beta policy rules for a single request. +type betaPolicyResult struct { + blockErr *BetaBlockedError // non-nil if a block rule matched + filterSet map[string]struct{} // tokens to filter (may be nil) +} + +// evaluateBetaPolicy loads settings once and evaluates all rules against the given request. +func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader string, account *Account) betaPolicyResult { + if s.settingService == nil { + return betaPolicyResult{} + } + settings, err := s.settingService.GetBetaPolicySettings(ctx) + if err != nil || settings == nil { + return betaPolicyResult{} + } + isOAuth := account.IsOAuth() + var result betaPolicyResult + for _, rule := range settings.Rules { + if !betaPolicyScopeMatches(rule.Scope, isOAuth) { + continue + } + switch rule.Action { + case BetaPolicyActionBlock: + if result.blockErr == nil && betaHeader != "" && containsBetaToken(betaHeader, rule.BetaToken) { + msg := rule.ErrorMessage + if msg == "" { + msg = "beta feature " + rule.BetaToken + " is not allowed" + } + result.blockErr = &BetaBlockedError{Message: msg} + } + case BetaPolicyActionFilter: + if result.filterSet == nil { + result.filterSet = make(map[string]struct{}) + } + result.filterSet[rule.BetaToken] = struct{}{} + } + } + return result +} + +// mergeDropSets merges the static defaultDroppedBetasSet with dynamic policy filter tokens. +// Returns defaultDroppedBetasSet directly when policySet is empty (zero allocation). 
+func mergeDropSets(policySet map[string]struct{}, extra ...string) map[string]struct{} { + if len(policySet) == 0 && len(extra) == 0 { + return defaultDroppedBetasSet + } + m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(policySet)+len(extra)) + for t := range defaultDroppedBetasSet { + m[t] = struct{}{} + } + for t := range policySet { + m[t] = struct{}{} + } + for _, t := range extra { + m[t] = struct{}{} + } + return m +} + +// betaPolicyFilterSetKey is the gin.Context key for caching the policy filter set within a request. +const betaPolicyFilterSetKey = "betaPolicyFilterSet" + +// getBetaPolicyFilterSet returns the beta policy filter set, using the gin context cache if available. +// In the /v1/messages path, Forward() evaluates the policy first and caches the result; +// buildUpstreamRequest reuses it (zero extra DB calls). In the count_tokens path, this +// evaluates on demand (one DB call). +func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Context, account *Account) map[string]struct{} { + if c != nil { + if v, ok := c.Get(betaPolicyFilterSetKey); ok { + if fs, ok := v.(map[string]struct{}); ok { + return fs + } + } + } + return s.evaluateBetaPolicy(ctx, "", account).filterSet +} + +// betaPolicyScopeMatches checks whether a rule's scope matches the current account type. +func betaPolicyScopeMatches(scope string, isOAuth bool) bool { + switch scope { + case BetaPolicyScopeAll: + return true + case BetaPolicyScopeOAuth: + return isOAuth + case BetaPolicyScopeAPIKey: + return !isOAuth + default: + return true // unknown scope → match all (fail-open) + } +} + // droppedBetaSet returns claude.DroppedBetas as a set, with optional extra tokens. 
func droppedBetaSet(extra ...string) map[string]struct{} { m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(extra)) @@ -5236,6 +5475,19 @@ func droppedBetaSet(extra ...string) map[string]struct{} { return m } +// containsBetaToken checks if a comma-separated header value contains the given token. +func containsBetaToken(header, token string) bool { + if header == "" || token == "" { + return false + } + for _, p := range strings.Split(header, ",") { + if strings.TrimSpace(p) == token { + return true + } + } + return false +} + func buildBetaTokenSet(tokens []string) map[string]struct{} { m := make(map[string]struct{}, len(tokens)) for _, t := range tokens { @@ -5247,10 +5499,7 @@ func buildBetaTokenSet(tokens []string) map[string]struct{} { return m } -var ( - defaultDroppedBetasSet = buildBetaTokenSet(claude.DroppedBetas) - droppedBetasWithClaudeCodeSet = droppedBetaSet(claude.BetaClaudeCode) -) +var defaultDroppedBetasSet = buildBetaTokenSet(claude.DroppedBetas) // applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers. 
// This mirrors opencode-anthropic-auth behavior: do not trust downstream @@ -5377,6 +5626,11 @@ func extractUpstreamErrorMessage(body []byte) string { return m } + // ChatGPT 内部 API 风格:{"detail":"..."} + if d := gjson.GetBytes(body, "detail").String(); strings.TrimSpace(d) != "" { + return d + } + // 兜底:尝试顶层 message return gjson.GetBytes(body, "message").String() } @@ -5751,6 +6005,22 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http intervalCh = intervalTicker.C } + // 下游 keepalive:防止代理/Cloudflare Tunnel 因连接空闲而断开 + keepaliveInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 { + keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second + } + var keepaliveTicker *time.Ticker + if keepaliveInterval > 0 { + keepaliveTicker = time.NewTicker(keepaliveInterval) + defer keepaliveTicker.Stop() + } + var keepaliveCh <-chan time.Time + if keepaliveTicker != nil { + keepaliveCh = keepaliveTicker.C + } + lastDataAt := time.Now() + // 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端) errorEventSent := false sendErrorEvent := func(reason string) { @@ -5764,6 +6034,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http needModelReplace := originalModel != mappedModel clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage + skipAccountTTLOverride := false pendingEventLines := make([]string, 0, 4) @@ -5824,17 +6095,25 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http if msg, ok := event["message"].(map[string]any); ok { if u, ok := msg["usage"].(map[string]any); ok { eventChanged = reconcileCachedTokens(u) || eventChanged + claudeMaxOutcome := applyClaudeMaxSimulationToUsageJSONMap(ctx, u, originalModel, account.ID) + if claudeMaxOutcome.Simulated { + skipAccountTTLOverride = true + } } } } if eventType == "message_delta" { if u, ok := event["usage"].(map[string]any); ok { eventChanged = reconcileCachedTokens(u) || 
eventChanged + claudeMaxOutcome := applyClaudeMaxSimulationToUsageJSONMap(ctx, u, originalModel, account.ID) + if claudeMaxOutcome.Simulated { + skipAccountTTLOverride = true + } } } // Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类 - if account.IsCacheTTLOverrideEnabled() { + if account.IsCacheTTLOverrideEnabled() && !skipAccountTTLOverride { overrideTarget := account.GetCacheTTLOverrideTarget() if eventType == "message_start" { if msg, ok := event["message"].(map[string]any); ok { @@ -5940,6 +6219,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http break } flusher.Flush() + lastDataAt = time.Now() } if data != "" { if firstTokenMs == nil && data != "[DONE]" { @@ -5973,6 +6253,22 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } sendErrorEvent("stream_timeout") return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") + + case <-keepaliveCh: + if clientDisconnected { + continue + } + if time.Since(lastDataAt) < keepaliveInterval { + continue + } + // SSE ping 事件:Anthropic 原生格式,客户端会正确处理, + // 同时保持连接活跃防止 Cloudflare Tunnel 等代理断开 + if _, werr := fmt.Fprint(w, "event: ping\ndata: {\"type\": \"ping\"}\n\n"); werr != nil { + clientDisconnected = true + logger.LegacyPrintf("service.gateway", "Client disconnected during keepalive ping, continuing to drain upstream for billing") + continue + } + flusher.Flush() } } @@ -6244,8 +6540,13 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h } } + claudeMaxOutcome := applyClaudeMaxSimulationToUsage(ctx, &response.Usage, originalModel, account.ID) + if claudeMaxOutcome.Simulated { + body = rewriteClaudeUsageJSONBytes(body, response.Usage) + } + // Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类 - if account.IsCacheTTLOverrideEnabled() { + if account.IsCacheTTLOverrideEnabled() && !claudeMaxOutcome.Simulated { overrideTarget := account.GetCacheTTLOverrideTarget() if 
applyCacheTTLOverride(&response.Usage, overrideTarget) { // 同步更新 body JSON 中的嵌套 cache_creation 对象 @@ -6292,68 +6593,26 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo } func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID, groupID int64, groupDefaultMultiplier float64) float64 { - if s == nil || userID <= 0 || groupID <= 0 { + if s == nil { return groupDefaultMultiplier } - - key := fmt.Sprintf("%d:%d", userID, groupID) - if s.userGroupRateCache != nil { - if cached, ok := s.userGroupRateCache.Get(key); ok { - if multiplier, castOK := cached.(float64); castOK { - userGroupRateCacheHitTotal.Add(1) - return multiplier - } - } + resolver := s.userGroupRateResolver + if resolver == nil { + resolver = newUserGroupRateResolver( + s.userGroupRateRepo, + s.userGroupRateCache, + resolveUserGroupRateCacheTTL(s.cfg), + &s.userGroupRateSF, + "service.gateway", + ) } - if s.userGroupRateRepo == nil { - return groupDefaultMultiplier - } - userGroupRateCacheMissTotal.Add(1) - - value, err, shared := s.userGroupRateSF.Do(key, func() (any, error) { - if s.userGroupRateCache != nil { - if cached, ok := s.userGroupRateCache.Get(key); ok { - if multiplier, castOK := cached.(float64); castOK { - userGroupRateCacheHitTotal.Add(1) - return multiplier, nil - } - } - } - - userGroupRateCacheLoadTotal.Add(1) - userRate, repoErr := s.userGroupRateRepo.GetByUserAndGroup(ctx, userID, groupID) - if repoErr != nil { - return nil, repoErr - } - multiplier := groupDefaultMultiplier - if userRate != nil { - multiplier = *userRate - } - if s.userGroupRateCache != nil { - s.userGroupRateCache.Set(key, multiplier, resolveUserGroupRateCacheTTL(s.cfg)) - } - return multiplier, nil - }) - if shared { - userGroupRateCacheSFSharedTotal.Add(1) - } - if err != nil { - userGroupRateCacheFallbackTotal.Add(1) - logger.LegacyPrintf("service.gateway", "get user group rate failed, fallback to group default: user=%d group=%d err=%v", userID, groupID, 
err) - return groupDefaultMultiplier - } - - multiplier, ok := value.(float64) - if !ok { - userGroupRateCacheFallbackTotal.Add(1) - return groupDefaultMultiplier - } - return multiplier + return resolver.Resolve(ctx, userID, groupID, groupDefaultMultiplier) } // RecordUsageInput 记录使用量的输入参数 type RecordUsageInput struct { Result *ForwardResult + ParsedRequest *ParsedRequest APIKey *APIKey User *User Account *Account @@ -6370,6 +6629,89 @@ type APIKeyQuotaUpdater interface { UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error } +// postUsageBillingParams 统一扣费所需的参数 +type postUsageBillingParams struct { + Cost *CostBreakdown + User *User + APIKey *APIKey + Account *Account + Subscription *UserSubscription + IsSubscriptionBill bool + AccountRateMultiplier float64 + APIKeyService APIKeyQuotaUpdater +} + +// postUsageBilling 统一处理使用量记录后的扣费逻辑: +// - 订阅/余额扣费 +// - API Key 配额更新 +// - API Key 限速用量更新 +// - 账号配额用量更新(账号口径:TotalCost × 账号计费倍率) +func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) { + cost := p.Cost + + // 1. 订阅 / 余额扣费 + if p.IsSubscriptionBill { + if cost.TotalCost > 0 { + if err := deps.userSubRepo.IncrementUsage(ctx, p.Subscription.ID, cost.TotalCost); err != nil { + slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err) + } + deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, cost.TotalCost) + } + } else { + if cost.ActualCost > 0 { + if err := deps.userRepo.DeductBalance(ctx, p.User.ID, cost.ActualCost); err != nil { + slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err) + } + deps.billingCacheService.QueueDeductBalance(p.User.ID, cost.ActualCost) + } + } + + // 2. 
API Key 配额 + if cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil { + if err := p.APIKeyService.UpdateQuotaUsed(ctx, p.APIKey.ID, cost.ActualCost); err != nil { + slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err) + } + } + + // 3. API Key 限速用量 + if cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil { + if err := p.APIKeyService.UpdateRateLimitUsage(ctx, p.APIKey.ID, cost.ActualCost); err != nil { + slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err) + } + deps.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(p.APIKey.ID, cost.ActualCost) + } + + // 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率) + if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() { + accountCost := cost.TotalCost * p.AccountRateMultiplier + if err := deps.accountRepo.IncrementQuotaUsed(ctx, p.Account.ID, accountCost); err != nil { + slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err) + } + } + + // 5. 
更新账号最近使用时间 + deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID) +} + +// billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供) +type billingDeps struct { + accountRepo AccountRepository + userRepo UserRepository + userSubRepo UserSubscriptionRepository + billingCacheService *BillingCacheService + deferredService *DeferredService +} + +func (s *GatewayService) billingDeps() *billingDeps { + return &billingDeps{ + accountRepo: s.accountRepo, + userRepo: s.userRepo, + userSubRepo: s.userSubRepo, + billingCacheService: s.billingCacheService, + deferredService: s.deferredService, + } +} + // RecordUsage 记录使用量并扣费(或更新订阅用量) func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error { result := input.Result @@ -6387,9 +6729,19 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu result.Usage.InputTokens = 0 } + // Claude Max cache billing policy (group-level): + // - GatewayService 路径: Forward 已改写 usage(含 cache tokens)→ apply 见到 cache tokens 跳过 → simulatedClaudeMax=true(通过第二条件) + // - Antigravity 路径: Forward 中 hook 改写了客户端 SSE,但 ForwardResult.Usage 是原始值 → apply 实际执行模拟 → simulatedClaudeMax=true + var apiKeyGroup *Group + if apiKey != nil { + apiKeyGroup = apiKey.Group + } + claudeMaxOutcome := applyClaudeMaxCacheBillingPolicyToUsage(&result.Usage, input.ParsedRequest, apiKeyGroup, result.Model, account.ID) + simulatedClaudeMax := claudeMaxOutcome.Simulated || + (shouldApplyClaudeMaxBillingRulesForUsage(apiKeyGroup, result.Model, input.ParsedRequest) && hasCacheCreationTokens(result.Usage)) // Cache TTL Override: 确保计费时 token 分类与账号设置一致 cacheTTLOverridden := false - if account.IsCacheTTLOverrideEnabled() { + if account.IsCacheTTLOverrideEnabled() && !simulatedClaudeMax { applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget()) cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0 } @@ -6533,45 +6885,21 @@ func (s *GatewayService) RecordUsage(ctx 
context.Context, input *RecordUsageInpu shouldBill := inserted || err != nil - // 根据计费类型执行扣费 - if isSubscriptionBilling { - // 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率) - if shouldBill && cost.TotalCost > 0 { - if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil { - logger.LegacyPrintf("service.gateway", "Increment subscription usage failed: %v", err) - } - // 异步更新订阅缓存 - s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) - } + if shouldBill { + postUsageBilling(ctx, &postUsageBillingParams{ + Cost: cost, + User: user, + APIKey: apiKey, + Account: account, + Subscription: subscription, + IsSubscriptionBill: isSubscriptionBilling, + AccountRateMultiplier: accountRateMultiplier, + APIKeyService: input.APIKeyService, + }, s.billingDeps()) } else { - // 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用) - if shouldBill && cost.ActualCost > 0 { - if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil { - logger.LegacyPrintf("service.gateway", "Deduct balance failed: %v", err) - } - // 异步更新余额缓存 - s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) - } + s.deferredService.ScheduleLastUsedUpdate(account.ID) } - // 更新 API Key 配额(如果设置了配额限制) - if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil { - if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil { - logger.LegacyPrintf("service.gateway", "Update API key quota failed: %v", err) - } - } - - // Update API Key rate limit usage - if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil { - if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil { - logger.LegacyPrintf("service.gateway", "Update API key rate limit usage failed: %v", err) - } - s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost) - } - - // Schedule batch update for account last_used_at - 
s.deferredService.ScheduleLastUsedUpdate(account.ID) - return nil } @@ -6731,44 +7059,21 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * shouldBill := inserted || err != nil - // 根据计费类型执行扣费 - if isSubscriptionBilling { - // 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率) - if shouldBill && cost.TotalCost > 0 { - if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil { - logger.LegacyPrintf("service.gateway", "Increment subscription usage failed: %v", err) - } - // 异步更新订阅缓存 - s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) - } + if shouldBill { + postUsageBilling(ctx, &postUsageBillingParams{ + Cost: cost, + User: user, + APIKey: apiKey, + Account: account, + Subscription: subscription, + IsSubscriptionBill: isSubscriptionBilling, + AccountRateMultiplier: accountRateMultiplier, + APIKeyService: input.APIKeyService, + }, s.billingDeps()) } else { - // 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用) - if shouldBill && cost.ActualCost > 0 { - if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil { - logger.LegacyPrintf("service.gateway", "Deduct balance failed: %v", err) - } - // 异步更新余额缓存 - s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) - // API Key 独立配额扣费 - if input.APIKeyService != nil && apiKey.Quota > 0 { - if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil { - logger.LegacyPrintf("service.gateway", "Add API key quota used failed: %v", err) - } - } - } + s.deferredService.ScheduleLastUsedUpdate(account.ID) } - // Update API Key rate limit usage - if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil { - if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil { - logger.LegacyPrintf("service.gateway", "Update API key rate limit usage failed: %v", err) - } - 
s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost) - } - - // Schedule batch update for account last_used_at - s.deferredService.ScheduleLastUsedUpdate(account.ID) - return nil } @@ -6781,7 +7086,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, } if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() { - return s.forwardCountTokensAnthropicAPIKeyPassthrough(ctx, c, account, parsed.Body) + passthroughBody := parsed.Body + if reqModel := parsed.Model; reqModel != "" { + if mappedModel := account.GetMappedModel(reqModel); mappedModel != reqModel { + passthroughBody = s.replaceModelInBody(passthroughBody, mappedModel) + logger.LegacyPrintf("service.gateway", "CountTokens passthrough model mapping: %s -> %s (account: %s)", reqModel, mappedModel, account.Name) + } + } + return s.forwardCountTokensAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody) } body := parsed.Body @@ -6871,7 +7183,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, } // 检测 thinking block 签名错误(400)并重试一次(过滤 thinking blocks) - if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) { + if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) && s.settingService.IsSignatureRectifierEnabled(ctx) { logger.LegacyPrintf("service.gateway", "Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID) filteredBody := FilterThinkingBlocksForRetry(body) @@ -7072,7 +7384,7 @@ func (s *GatewayService) buildCountTokensRequestAnthropicAPIKeyPassthrough( if err != nil { return nil, err } - targetURL = validatedURL + "/v1/messages/count_tokens" + targetURL = validatedURL + "/v1/messages/count_tokens?beta=true" } req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) @@ -7119,7 +7431,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con 
if err != nil { return nil, err } - targetURL = validatedURL + "/v1/messages/count_tokens" + targetURL = validatedURL + "/v1/messages/count_tokens?beta=true" } } @@ -7183,6 +7495,9 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con applyClaudeOAuthHeaderDefaults(req, false) } + // Build effective drop set for count_tokens: merge static defaults with dynamic beta policy filter rules + ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account)) + // OAuth 账号:处理 anthropic-beta header if tokenType == "oauth" { if mimicClaudeCode { @@ -7190,8 +7505,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con incomingBeta := req.Header.Get("anthropic-beta") requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting} - drop := droppedBetaSet() - req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop)) + req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, ctEffectiveDropSet)) } else { clientBetaHeader := req.Header.Get("anthropic-beta") if clientBetaHeader == "" { @@ -7201,14 +7515,19 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con if !strings.Contains(beta, claude.BetaTokenCounting) { beta = beta + "," + claude.BetaTokenCounting } - req.Header.Set("anthropic-beta", stripBetaTokensWithSet(beta, defaultDroppedBetasSet)) + req.Header.Set("anthropic-beta", stripBetaTokensWithSet(beta, ctEffectiveDropSet)) } } - } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { - // API-key:与 messages 同步的按需 beta 注入(默认关闭) - if requestNeedsBetaFeatures(body) { - if beta := defaultAPIKeyBetaHeader(body); beta != "" { - req.Header.Set("anthropic-beta", beta) + } else { + // API-key accounts: apply beta policy filter to strip controlled tokens + if existingBeta := 
req.Header.Get("anthropic-beta"); existingBeta != "" { + req.Header.Set("anthropic-beta", stripBetaTokensWithSet(existingBeta, ctEffectiveDropSet)) + } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey { + // API-key:与 messages 同步的按需 beta 注入(默认关闭) + if requestNeedsBetaFeatures(body) { + if beta := defaultAPIKeyBetaHeader(body); beta != "" { + req.Header.Set("anthropic-beta", beta) + } } } } diff --git a/backend/internal/service/gemini_error_policy_test.go b/backend/internal/service/gemini_error_policy_test.go index 2ce8793a..4bd1ced7 100644 --- a/backend/internal/service/gemini_error_policy_test.go +++ b/backend/internal/service/gemini_error_policy_test.go @@ -122,6 +122,28 @@ func TestCheckErrorPolicy_GeminiAccounts(t *testing.T) { body: []byte(`overloaded service`), expected: ErrorPolicyTempUnscheduled, }, + { + name: "gemini_apikey_temp_unschedulable_401_second_hit_returns_none", + account: &Account{ + ID: 105, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(401), + "keywords": []any{"unauthorized"}, + "duration_minutes": float64(10), + }, + }, + }, + }, + statusCode: 401, + body: []byte(`unauthorized`), + expected: ErrorPolicyNone, + }, { name: "gemini_custom_codes_override_temp_unschedulable", account: &Account{ diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index 9476e984..b0b804eb 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -176,6 +176,14 @@ func (m *mockAccountRepoForGemini) BulkUpdate(ctx context.Context, ids []int64, return 0, nil } +func (m *mockAccountRepoForGemini) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error { + return nil +} + +func 
(m *mockAccountRepoForGemini) ResetQuotaUsed(ctx context.Context, id int64) error { + return nil +} + // Verify interface implementation var _ AccountRepository = (*mockAccountRepoForGemini)(nil) diff --git a/backend/internal/service/gemini_native_signature_cleaner_test.go b/backend/internal/service/gemini_native_signature_cleaner_test.go new file mode 100644 index 00000000..2e184919 --- /dev/null +++ b/backend/internal/service/gemini_native_signature_cleaner_test.go @@ -0,0 +1,75 @@ +package service + +import ( + "encoding/json" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/stretchr/testify/require" +) + +func TestCleanGeminiNativeThoughtSignatures_ReplacesNestedThoughtSignatures(t *testing.T) { + input := []byte(`{ + "contents": [ + { + "role": "user", + "parts": [{"text": "hello"}] + }, + { + "role": "model", + "parts": [ + {"text": "thinking", "thought": true, "thoughtSignature": "sig_1"}, + {"functionCall": {"name": "toolA", "args": {"k": "v"}}, "thoughtSignature": "sig_2"} + ] + } + ], + "cachedContent": { + "parts": [{"text": "cached", "thoughtSignature": "sig_3"}] + }, + "signature": "keep_me" + }`) + + cleaned := CleanGeminiNativeThoughtSignatures(input) + + var got map[string]any + require.NoError(t, json.Unmarshal(cleaned, &got)) + + require.NotContains(t, string(cleaned), `"thoughtSignature":"sig_1"`) + require.NotContains(t, string(cleaned), `"thoughtSignature":"sig_2"`) + require.NotContains(t, string(cleaned), `"thoughtSignature":"sig_3"`) + require.Contains(t, string(cleaned), `"thoughtSignature":"`+antigravity.DummyThoughtSignature+`"`) + require.Contains(t, string(cleaned), `"signature":"keep_me"`) +} + +func TestCleanGeminiNativeThoughtSignatures_InvalidJSONReturnsOriginal(t *testing.T) { + input := []byte(`{"contents":[invalid-json]}`) + + cleaned := CleanGeminiNativeThoughtSignatures(input) + + require.Equal(t, input, cleaned) +} + +func TestReplaceThoughtSignaturesRecursive_OnlyReplacesTargetField(t 
*testing.T) { + input := map[string]any{ + "thoughtSignature": "sig_root", + "signature": "keep_signature", + "nested": []any{ + map[string]any{ + "thoughtSignature": "sig_nested", + "signature": "keep_nested_signature", + }, + }, + } + + got, ok := replaceThoughtSignaturesRecursive(input).(map[string]any) + require.True(t, ok) + require.Equal(t, antigravity.DummyThoughtSignature, got["thoughtSignature"]) + require.Equal(t, "keep_signature", got["signature"]) + + nested, ok := got["nested"].([]any) + require.True(t, ok) + nestedMap, ok := nested[0].(map[string]any) + require.True(t, ok) + require.Equal(t, antigravity.DummyThoughtSignature, nestedMap["thoughtSignature"]) + require.Equal(t, "keep_nested_signature", nestedMap["signature"]) +} diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index 6990caca..c9851bd8 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -50,6 +50,9 @@ type Group struct { // MCP XML 协议注入开关(仅 antigravity 平台使用) MCPXMLInject bool + // Claude usage 模拟开关:将无写缓存 usage 模拟为 claude-max 风格 + SimulateClaudeMaxEnabled bool + // 支持的模型系列(仅 antigravity 平台使用) // 可选值: claude, gemini_text, gemini_image SupportedModelScopes []string @@ -57,6 +60,10 @@ type Group struct { // 分组排序 SortOrder int + // OpenAI Messages 调度配置(仅 openai 平台使用) + AllowMessagesDispatch bool + DefaultMappedModel string + CreatedAt time.Time UpdatedAt time.Time diff --git a/backend/internal/service/identity_service.go b/backend/internal/service/identity_service.go index f3130c91..f6a94d15 100644 --- a/backend/internal/service/identity_service.go +++ b/backend/internal/service/identity_service.go @@ -19,8 +19,10 @@ import ( // 预编译正则表达式(避免每次调用重新编译) var ( - // 匹配 user_id 格式: user_{64位hex}_account__session_{uuid} - userIDRegex = regexp.MustCompile(`^user_[a-f0-9]{64}_account__session_([a-f0-9-]{36})$`) + // 匹配 user_id 格式: + // 旧格式: user_{64位hex}_account__session_{uuid} (account 后无 UUID) + // 新格式: 
user_{64位hex}_account_{uuid}_session_{uuid} (account 后有 UUID) + userIDRegex = regexp.MustCompile(`^user_[a-f0-9]{64}_account_([a-f0-9-]*)_session_([a-f0-9-]{36})$`) // 匹配 User-Agent 版本号: xxx/x.y.z userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`) ) @@ -239,13 +241,16 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI return body, nil } - // 匹配格式: user_{64位hex}_account__session_{uuid} + // 匹配格式: + // 旧格式: user_{64位hex}_account__session_{uuid} + // 新格式: user_{64位hex}_account_{uuid}_session_{uuid} matches := userIDRegex.FindStringSubmatch(userID) if matches == nil { return body, nil } - sessionTail := matches[1] // 原始session UUID + // matches[1] = account UUID (可能为空), matches[2] = session UUID + sessionTail := matches[2] // 原始session UUID // 生成新的session hash: SHA256(accountID::sessionTail) -> UUID格式 seed := fmt.Sprintf("%d::%s", accountID, sessionTail) diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go index 99013ce5..0fcf450b 100644 --- a/backend/internal/service/openai_account_scheduler.go +++ b/backend/internal/service/openai_account_scheduler.go @@ -319,7 +319,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash( _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) return nil, nil } - if shouldClearStickySession(account, req.RequestedModel) || !account.IsOpenAI() { + if shouldClearStickySession(account, req.RequestedModel) || !account.IsOpenAI() || !account.IsSchedulable() { _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) return nil, nil } @@ -342,6 +342,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash( } cfg := s.service.schedulingConfig() + // WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。 if s.service.concurrencyService != nil { return &AccountSelectionResult{ Account: account, @@ -590,7 +591,7 @@ func (s 
*defaultOpenAIAccountScheduler) selectByLoadBalance( filtered = append(filtered, account) loadReq = append(loadReq, AccountWithConcurrency{ ID: account.ID, - MaxConcurrency: account.Concurrency, + MaxConcurrency: account.EffectiveLoadFactor(), }) } if len(filtered) == 0 { @@ -686,16 +687,20 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( for i := 0; i < len(selectionOrder); i++ { candidate := selectionOrder[i] - result, acquireErr := s.service.tryAcquireAccountSlot(ctx, candidate.account.ID, candidate.account.Concurrency) + fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel) + if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) { + continue + } + result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) if acquireErr != nil { return nil, len(candidates), topK, loadSkew, acquireErr } if result != nil && result.Acquired { if req.SessionHash != "" { - _ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, candidate.account.ID) + _ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, fresh.ID) } return &AccountSelectionResult{ - Account: candidate.account, + Account: fresh, Acquired: true, ReleaseFunc: result.ReleaseFunc, }, len(candidates), topK, loadSkew, nil @@ -703,16 +708,24 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( } cfg := s.service.schedulingConfig() - candidate := selectionOrder[0] - return &AccountSelectionResult{ - Account: candidate.account, - WaitPlan: &AccountWaitPlan{ - AccountID: candidate.account.ID, - MaxConcurrency: candidate.account.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }, - }, len(candidates), topK, loadSkew, nil + // WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。 + for _, candidate := range selectionOrder { + fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, 
candidate.account, req.RequestedModel) + if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) { + continue + } + return &AccountSelectionResult{ + Account: fresh, + WaitPlan: &AccountWaitPlan{ + AccountID: fresh.ID, + MaxConcurrency: fresh.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }, + }, len(candidates), topK, loadSkew, nil + } + + return nil, len(candidates), topK, loadSkew, errors.New("no available accounts") } func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool { diff --git a/backend/internal/service/openai_account_scheduler_test.go b/backend/internal/service/openai_account_scheduler_test.go index 7f6f1b66..977c4ee8 100644 --- a/backend/internal/service/openai_account_scheduler_test.go +++ b/backend/internal/service/openai_account_scheduler_test.go @@ -12,6 +12,78 @@ import ( "github.com/stretchr/testify/require" ) +type openAISnapshotCacheStub struct { + SchedulerCache + snapshotAccounts []*Account + accountsByID map[int64]*Account +} + +func (s *openAISnapshotCacheStub) GetSnapshot(ctx context.Context, bucket SchedulerBucket) ([]*Account, bool, error) { + if len(s.snapshotAccounts) == 0 { + return nil, false, nil + } + out := make([]*Account, 0, len(s.snapshotAccounts)) + for _, account := range s.snapshotAccounts { + if account == nil { + continue + } + cloned := *account + out = append(out, &cloned) + } + return out, true, nil +} + +func (s *openAISnapshotCacheStub) GetAccount(ctx context.Context, accountID int64) (*Account, error) { + if s.accountsByID == nil { + return nil, nil + } + account := s.accountsByID[accountID] + if account == nil { + return nil, nil + } + cloned := *account + return &cloned, nil +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimitedAccountFallsBackToFreshCandidate(t *testing.T) { + ctx := context.Background() + groupID := int64(10101) + 
rateLimitedUntil := time.Now().Add(30 * time.Minute) + staleSticky := &Account{ID: 31001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0} + staleBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} + freshSticky := &Account{ID: 31001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil} + freshBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} + cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}} + snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{staleSticky, staleBackup}, accountsByID: map[int64]*Account{31001: freshSticky, 31002: freshBackup}} + snapshotService := &SchedulerSnapshotService{cache: snapshotCache} + svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}}, cache: cache, cfg: &config.Config{}, schedulerSnapshot: snapshotService, concurrencyService: NewConcurrencyService(stubConcurrencyCache{})} + + selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(31002), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) +} + +func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRateLimitedSnapshotCandidate(t *testing.T) { + ctx := context.Background() + groupID := int64(10102) + rateLimitedUntil := time.Now().Add(30 * time.Minute) + stalePrimary := &Account{ID: 32001, Platform: 
PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0} + staleSecondary := &Account{ID: 32002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} + freshPrimary := &Account{ID: 32001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil} + freshSecondary := &Account{ID: 32002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} + snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{stalePrimary, staleSecondary}, accountsByID: map[int64]*Account{32001: freshPrimary, 32002: freshSecondary}} + snapshotService := &SchedulerSnapshotService{cache: snapshotCache} + svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}}, cfg: &config.Config{}, schedulerSnapshot: snapshotService} + + account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gpt-5.1", nil) + require.NoError(t, err) + require.NotNil(t, account) + require.Equal(t, int64(32002), account.ID) +} + func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(t *testing.T) { ctx := context.Background() groupID := int64(9) diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index 16befb82..b0e4d44f 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -1,14 +1,18 @@ package service import ( - _ "embed" + "fmt" "strings" ) -//go:embed prompts/codex_cli_instructions.md -var codexCLIInstructions string - var codexModelMap = map[string]string{ + "gpt-5.4": "gpt-5.4", + "gpt-5.4-none": "gpt-5.4", + "gpt-5.4-low": "gpt-5.4", + "gpt-5.4-medium": "gpt-5.4", + "gpt-5.4-high": "gpt-5.4", + 
"gpt-5.4-xhigh": "gpt-5.4", + "gpt-5.4-chat-latest": "gpt-5.4", "gpt-5.3": "gpt-5.3-codex", "gpt-5.3-none": "gpt-5.3-codex", "gpt-5.3-low": "gpt-5.3-codex", @@ -70,7 +74,7 @@ type codexTransformResult struct { PromptCacheKey string } -func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTransformResult { +func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact bool) codexTransformResult { result := codexTransformResult{} // 工具续链需求会影响存储策略与 input 过滤逻辑。 needsToolContinuation := NeedsToolContinuation(reqBody) @@ -88,15 +92,26 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTran result.NormalizedModel = normalizedModel } - // OAuth 走 ChatGPT internal API 时,store 必须为 false;显式 true 也会强制覆盖。 - // 避免上游返回 "Store must be set to false"。 - if v, ok := reqBody["store"].(bool); !ok || v { - reqBody["store"] = false - result.Modified = true - } - if v, ok := reqBody["stream"].(bool); !ok || !v { - reqBody["stream"] = true - result.Modified = true + if isCompact { + if _, ok := reqBody["store"]; ok { + delete(reqBody, "store") + result.Modified = true + } + if _, ok := reqBody["stream"]; ok { + delete(reqBody, "stream") + result.Modified = true + } + } else { + // OAuth 走 ChatGPT internal API 时,store 必须为 false;显式 true 也会强制覆盖。 + // 避免上游返回 "Store must be set to false"。 + if v, ok := reqBody["store"].(bool); !ok || v { + reqBody["store"] = false + result.Modified = true + } + if v, ok := reqBody["stream"].(bool); !ok || !v { + reqBody["stream"] = true + result.Modified = true + } } // Strip parameters unsupported by codex models via the Responses API. @@ -132,6 +147,22 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTran input = filterCodexInput(input, needsToolContinuation) reqBody["input"] = input result.Modified = true + } else if inputStr, ok := reqBody["input"].(string); ok { + // ChatGPT codex endpoint requires input to be a list, not a string. 
+ // Convert string input to the expected message array format. + trimmed := strings.TrimSpace(inputStr) + if trimmed != "" { + reqBody["input"] = []any{ + map[string]any{ + "type": "message", + "role": "user", + "content": inputStr, + }, + } + } else { + reqBody["input"] = []any{} + } + result.Modified = true } return result @@ -154,6 +185,9 @@ func normalizeCodexModel(model string) string { normalized := strings.ToLower(modelID) + if strings.Contains(normalized, "gpt-5.4") || strings.Contains(normalized, "gpt 5.4") { + return "gpt-5.4" + } if strings.Contains(normalized, "gpt-5.2-codex") || strings.Contains(normalized, "gpt 5.2 codex") { return "gpt-5.2-codex" } @@ -193,6 +227,29 @@ func normalizeCodexModel(model string) string { return "gpt-5.1" } +func SupportsVerbosity(model string) bool { + if !strings.HasPrefix(model, "gpt-") { + return true + } + + var major, minor int + n, _ := fmt.Sscanf(model, "gpt-%d.%d", &major, &minor) + + if major > 5 { + return true + } + if major < 5 { + return false + } + + // gpt-5 + if n == 1 { + return true + } + + return minor >= 3 +} + func getNormalizedCodexModel(modelID string) string { if modelID == "" { return "" @@ -209,72 +266,13 @@ func getNormalizedCodexModel(modelID string) string { return "" } -func getOpenCodeCodexHeader() string { - // 兼容保留:历史上这里会从 opencode 仓库拉取 codex_header.txt。 - // 现在我们与 Codex CLI 一致,直接使用仓库内置的 instructions,避免读写缓存与外网依赖。 - return getCodexCLIInstructions() -} - -func getCodexCLIInstructions() string { - return codexCLIInstructions -} - -func GetOpenCodeInstructions() string { - return getOpenCodeCodexHeader() -} - -// GetCodexCLIInstructions 返回内置的 Codex CLI 指令内容。 -func GetCodexCLIInstructions() string { - return getCodexCLIInstructions() -} - -// applyInstructions 处理 instructions 字段 -// isCodexCLI=true: 仅补充缺失的 instructions(使用内置 Codex CLI 指令) -// isCodexCLI=false: 优先使用内置 Codex CLI 指令覆盖 +// applyInstructions 处理 instructions 字段:仅在 instructions 为空时填充默认值。 func applyInstructions(reqBody map[string]any, 
isCodexCLI bool) bool { - if isCodexCLI { - return applyCodexCLIInstructions(reqBody) - } - return applyOpenCodeInstructions(reqBody) -} - -// applyCodexCLIInstructions 为 Codex CLI 请求补充缺失的 instructions -// 仅在 instructions 为空时添加内置 Codex CLI 指令(不依赖 opencode 缓存/回源) -func applyCodexCLIInstructions(reqBody map[string]any) bool { if !isInstructionsEmpty(reqBody) { - return false // 已有有效 instructions,不修改 + return false } - - instructions := strings.TrimSpace(getCodexCLIInstructions()) - if instructions != "" { - reqBody["instructions"] = instructions - return true - } - - return false -} - -// applyOpenCodeInstructions 为非 Codex CLI 请求应用内置 Codex CLI 指令(兼容历史函数名) -// 优先使用内置 Codex CLI 指令覆盖 -func applyOpenCodeInstructions(reqBody map[string]any) bool { - instructions := strings.TrimSpace(getOpenCodeCodexHeader()) - existingInstructions, _ := reqBody["instructions"].(string) - existingInstructions = strings.TrimSpace(existingInstructions) - - if instructions != "" { - if existingInstructions != instructions { - reqBody["instructions"] = instructions - return true - } - } else if existingInstructions == "" { - codexInstructions := strings.TrimSpace(getCodexCLIInstructions()) - if codexInstructions != "" { - reqBody["instructions"] = codexInstructions - return true - } - } - - return false + reqBody["instructions"] = "You are a helpful coding assistant." 
+ return true } // isInstructionsEmpty 检查 instructions 字段是否为空 diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index 27093f6c..c8097aed 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -18,7 +18,7 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) { "tool_choice": "auto", } - applyCodexOAuthTransform(reqBody, false) + applyCodexOAuthTransform(reqBody, false, false) // 未显式设置 store=true,默认为 false。 store, ok := reqBody["store"].(bool) @@ -53,7 +53,7 @@ func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) { "tool_choice": "auto", } - applyCodexOAuthTransform(reqBody, false) + applyCodexOAuthTransform(reqBody, false, false) store, ok := reqBody["store"].(bool) require.True(t, ok) @@ -72,13 +72,29 @@ func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) { "tool_choice": "auto", } - applyCodexOAuthTransform(reqBody, false) + applyCodexOAuthTransform(reqBody, false, false) store, ok := reqBody["store"].(bool) require.True(t, ok) require.False(t, store) } +func TestApplyCodexOAuthTransform_CompactForcesNonStreaming(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.1-codex", + "store": true, + "stream": true, + } + + result := applyCodexOAuthTransform(reqBody, true, true) + + _, hasStore := reqBody["store"] + require.False(t, hasStore) + _, hasStream := reqBody["stream"] + require.False(t, hasStream) + require.True(t, result.Modified) +} + func TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs(t *testing.T) { // 非续链场景:未设置 store 时默认 false,并移除 input 中的 id。 @@ -89,7 +105,7 @@ func TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs( }, } - applyCodexOAuthTransform(reqBody, false) + applyCodexOAuthTransform(reqBody, false, false) store, ok := reqBody["store"].(bool) require.True(t, 
ok) @@ -138,7 +154,7 @@ func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunction }, } - applyCodexOAuthTransform(reqBody, false) + applyCodexOAuthTransform(reqBody, false, false) tools, ok := reqBody["tools"].([]any) require.True(t, ok) @@ -158,7 +174,7 @@ func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) { "input": []any{}, } - applyCodexOAuthTransform(reqBody, false) + applyCodexOAuthTransform(reqBody, false, false) input, ok := reqBody["input"].([]any) require.True(t, ok) @@ -167,6 +183,10 @@ func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) { func TestNormalizeCodexModel_Gpt53(t *testing.T) { cases := map[string]string{ + "gpt-5.4": "gpt-5.4", + "gpt-5.4-high": "gpt-5.4", + "gpt-5.4-chat-latest": "gpt-5.4", + "gpt 5.4": "gpt-5.4", "gpt-5.3": "gpt-5.3-codex", "gpt-5.3-codex": "gpt-5.3-codex", "gpt-5.3-codex-xhigh": "gpt-5.3-codex", @@ -189,7 +209,7 @@ func TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions(t *test "instructions": "existing instructions", } - result := applyCodexOAuthTransform(reqBody, true) // isCodexCLI=true + result := applyCodexOAuthTransform(reqBody, true, false) // isCodexCLI=true instructions, ok := reqBody["instructions"].(string) require.True(t, ok) @@ -206,7 +226,7 @@ func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T // 没有 instructions 字段 } - result := applyCodexOAuthTransform(reqBody, true) // isCodexCLI=true + result := applyCodexOAuthTransform(reqBody, true, false) // isCodexCLI=true instructions, ok := reqBody["instructions"].(string) require.True(t, ok) @@ -214,20 +234,63 @@ func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T require.True(t, result.Modified) } -func TestApplyCodexOAuthTransform_NonCodexCLI_OverridesInstructions(t *testing.T) { - // 非 Codex CLI 场景:使用内置 Codex CLI 指令覆盖 +func TestApplyCodexOAuthTransform_NonCodexCLI_PreservesExistingInstructions(t *testing.T) { + // 非 Codex CLI 场景:已有 instructions 
时保留客户端的值,不再覆盖 reqBody := map[string]any{ "model": "gpt-5.1", "instructions": "old instructions", } - result := applyCodexOAuthTransform(reqBody, false) // isCodexCLI=false + applyCodexOAuthTransform(reqBody, false, false) // isCodexCLI=false instructions, ok := reqBody["instructions"].(string) require.True(t, ok) - require.NotEqual(t, "old instructions", instructions) + require.Equal(t, "old instructions", instructions) +} + +func TestApplyCodexOAuthTransform_StringInputConvertedToArray(t *testing.T) { + reqBody := map[string]any{"model": "gpt-5.4", "input": "Hello, world!"} + result := applyCodexOAuthTransform(reqBody, false, false) require.True(t, result.Modified) + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 1) + msg, ok := input[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "message", msg["type"]) + require.Equal(t, "user", msg["role"]) + require.Equal(t, "Hello, world!", msg["content"]) +} + +func TestApplyCodexOAuthTransform_EmptyStringInputBecomesEmptyArray(t *testing.T) { + reqBody := map[string]any{"model": "gpt-5.4", "input": ""} + result := applyCodexOAuthTransform(reqBody, false, false) + require.True(t, result.Modified) + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 0) +} + +func TestApplyCodexOAuthTransform_WhitespaceStringInputBecomesEmptyArray(t *testing.T) { + reqBody := map[string]any{"model": "gpt-5.4", "input": " "} + result := applyCodexOAuthTransform(reqBody, false, false) + require.True(t, result.Modified) + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 0) +} + +func TestApplyCodexOAuthTransform_StringInputWithToolsField(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.4", + "input": "Run the tests", + "tools": []any{map[string]any{"type": "function", "name": "bash"}}, + } + applyCodexOAuthTransform(reqBody, false, false) + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, 
input, 1) } func TestIsInstructionsEmpty(t *testing.T) { diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go new file mode 100644 index 00000000..f893eeb9 --- /dev/null +++ b/backend/internal/service/openai_gateway_chat_completions.go @@ -0,0 +1,512 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// ForwardAsChatCompletions accepts a Chat Completions request body, converts it +// to OpenAI Responses API format, forwards to the OpenAI upstream, and converts +// the response back to Chat Completions format. All account types (OAuth and API +// Key) go through the Responses API conversion path since the upstream only +// exposes the /v1/responses endpoint. +func (s *OpenAIGatewayService) ForwardAsChatCompletions( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + promptCacheKey string, + defaultMappedModel string, +) (*OpenAIForwardResult, error) { + startTime := time.Now() + + // 1. Parse Chat Completions request + var chatReq apicompat.ChatCompletionsRequest + if err := json.Unmarshal(body, &chatReq); err != nil { + return nil, fmt.Errorf("parse chat completions request: %w", err) + } + originalModel := chatReq.Model + clientStream := chatReq.Stream + includeUsage := chatReq.StreamOptions != nil && chatReq.StreamOptions.IncludeUsage + + // 2. Convert to Responses and forward + // ChatCompletionsToResponses always sets Stream=true (upstream always streams). + responsesReq, err := apicompat.ChatCompletionsToResponses(&chatReq) + if err != nil { + return nil, fmt.Errorf("convert chat completions to responses: %w", err) + } + + // 3. 
Model mapping + mappedModel := account.GetMappedModel(originalModel) + if mappedModel == originalModel && defaultMappedModel != "" { + mappedModel = defaultMappedModel + } + responsesReq.Model = mappedModel + + logger.L().Debug("openai chat_completions: model mapping applied", + zap.Int64("account_id", account.ID), + zap.String("original_model", originalModel), + zap.String("mapped_model", mappedModel), + zap.Bool("stream", clientStream), + ) + + // 4. Marshal Responses request body, then apply OAuth codex transform + responsesBody, err := json.Marshal(responsesReq) + if err != nil { + return nil, fmt.Errorf("marshal responses request: %w", err) + } + + if account.Type == AccountTypeOAuth { + var reqBody map[string]any + if err := json.Unmarshal(responsesBody, &reqBody); err != nil { + return nil, fmt.Errorf("unmarshal for codex transform: %w", err) + } + codexResult := applyCodexOAuthTransform(reqBody, false, false) + if codexResult.PromptCacheKey != "" { + promptCacheKey = codexResult.PromptCacheKey + } else if promptCacheKey != "" { + reqBody["prompt_cache_key"] = promptCacheKey + } + responsesBody, err = json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("remarshal after codex transform: %w", err) + } + } + + // 5. Get access token + token, _, err := s.GetAccessToken(ctx, account) + if err != nil { + return nil, fmt.Errorf("get access token: %w", err) + } + + // 6. Build upstream request + upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, promptCacheKey, false) + if err != nil { + return nil, fmt.Errorf("build upstream request: %w", err) + } + + if promptCacheKey != "" { + upstreamReq.Header.Set("session_id", generateSessionUUID(promptCacheKey)) + } + + // 7. 
Send request + proxyURL := "" + if account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed") + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + defer func() { _ = resp.Body.Close() }() + + // 8. Handle error response with failover + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) { + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + if s.rateLimitService != nil { + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && 
(isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)), + } + } + return s.handleChatCompletionsErrorResponse(resp, c, account) + } + + // 9. Handle normal response + var result *OpenAIForwardResult + var handleErr error + if clientStream { + result, handleErr = s.handleChatStreamingResponse(resp, c, originalModel, mappedModel, includeUsage, startTime) + } else { + result, handleErr = s.handleChatBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime) + } + + // Propagate ServiceTier and ReasoningEffort to result for billing + if handleErr == nil && result != nil { + if responsesReq.ServiceTier != "" { + st := responsesReq.ServiceTier + result.ServiceTier = &st + } + if responsesReq.Reasoning != nil && responsesReq.Reasoning.Effort != "" { + re := responsesReq.Reasoning.Effort + result.ReasoningEffort = &re + } + } + + // Extract and save Codex usage snapshot from response headers (for OAuth accounts) + if handleErr == nil && account.Type == AccountTypeOAuth { + if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { + s.updateCodexUsageSnapshot(ctx, account.ID, snapshot) + } + } + + return result, handleErr +} + +// handleChatCompletionsErrorResponse reads an upstream error and returns it in +// OpenAI Chat Completions error format. +func (s *OpenAIGatewayService) handleChatCompletionsErrorResponse( + resp *http.Response, + c *gin.Context, + account *Account, +) (*OpenAIForwardResult, error) { + return s.handleCompatErrorResponse(resp, c, account, writeChatCompletionsError) +} + +// handleChatBufferedStreamingResponse reads all Responses SSE events from the +// upstream, finds the terminal event, converts to a Chat Completions JSON +// response, and writes it to the client. 
+func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse( + resp *http.Response, + c *gin.Context, + originalModel string, + mappedModel string, + startTime time.Time, +) (*OpenAIForwardResult, error) { + requestID := resp.Header.Get("x-request-id") + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + var finalResponse *apicompat.ResponsesResponse + var usage OpenAIUsage + + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + continue + } + payload := line[6:] + + var event apicompat.ResponsesStreamEvent + if err := json.Unmarshal([]byte(payload), &event); err != nil { + logger.L().Warn("openai chat_completions buffered: failed to parse event", + zap.Error(err), + zap.String("request_id", requestID), + ) + continue + } + + if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") && + event.Response != nil { + finalResponse = event.Response + if event.Response.Usage != nil { + usage = OpenAIUsage{ + InputTokens: event.Response.Usage.InputTokens, + OutputTokens: event.Response.Usage.OutputTokens, + } + if event.Response.Usage.InputTokensDetails != nil { + usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens + } + } + } + } + + if err := scanner.Err(); err != nil { + if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { + logger.L().Warn("openai chat_completions buffered: read error", + zap.Error(err), + zap.String("request_id", requestID), + ) + } + } + + if finalResponse == nil { + writeChatCompletionsError(c, http.StatusBadGateway, "api_error", "Upstream stream ended without a terminal response event") + return nil, fmt.Errorf("upstream stream ended without terminal event") + } + + chatResp := 
apicompat.ResponsesToChatCompletions(finalResponse, originalModel) + + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.JSON(http.StatusOK, chatResp) + + return &OpenAIForwardResult{ + RequestID: requestID, + Usage: usage, + Model: originalModel, + BillingModel: mappedModel, + Stream: false, + Duration: time.Since(startTime), + }, nil +} + +// handleChatStreamingResponse reads Responses SSE events from upstream, +// converts each to Chat Completions SSE chunks, and writes them to the client. +func (s *OpenAIGatewayService) handleChatStreamingResponse( + resp *http.Response, + c *gin.Context, + originalModel string, + mappedModel string, + includeUsage bool, + startTime time.Time, +) (*OpenAIForwardResult, error) { + requestID := resp.Header.Get("x-request-id") + + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.WriteHeader(http.StatusOK) + + state := apicompat.NewResponsesEventToChatState() + state.Model = originalModel + state.IncludeUsage = includeUsage + + var usage OpenAIUsage + var firstTokenMs *int + firstChunk := true + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + resultWithUsage := func() *OpenAIForwardResult { + return &OpenAIForwardResult{ + RequestID: requestID, + Usage: usage, + Model: originalModel, + BillingModel: mappedModel, + Stream: true, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + } + } + + processDataLine := func(payload string) bool { + if 
firstChunk { + firstChunk = false + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + + var event apicompat.ResponsesStreamEvent + if err := json.Unmarshal([]byte(payload), &event); err != nil { + logger.L().Warn("openai chat_completions stream: failed to parse event", + zap.Error(err), + zap.String("request_id", requestID), + ) + return false + } + + // Extract usage from completion events + if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") && + event.Response != nil && event.Response.Usage != nil { + usage = OpenAIUsage{ + InputTokens: event.Response.Usage.InputTokens, + OutputTokens: event.Response.Usage.OutputTokens, + } + if event.Response.Usage.InputTokensDetails != nil { + usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens + } + } + + chunks := apicompat.ResponsesEventToChatChunks(&event, state) + for _, chunk := range chunks { + sse, err := apicompat.ChatChunkToSSE(chunk) + if err != nil { + logger.L().Warn("openai chat_completions stream: failed to marshal chunk", + zap.Error(err), + zap.String("request_id", requestID), + ) + continue + } + if _, err := fmt.Fprint(c.Writer, sse); err != nil { + logger.L().Info("openai chat_completions stream: client disconnected", + zap.String("request_id", requestID), + ) + return true + } + } + if len(chunks) > 0 { + c.Writer.Flush() + } + return false + } + + finalizeStream := func() (*OpenAIForwardResult, error) { + if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 { + for _, chunk := range finalChunks { + sse, err := apicompat.ChatChunkToSSE(chunk) + if err != nil { + continue + } + fmt.Fprint(c.Writer, sse) //nolint:errcheck + } + } + // Send [DONE] sentinel + fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck + c.Writer.Flush() + return resultWithUsage(), nil + } + + handleScanErr := func(err error) { + if err != nil && !errors.Is(err, context.Canceled) && 
!errors.Is(err, context.DeadlineExceeded) { + logger.L().Warn("openai chat_completions stream: read error", + zap.Error(err), + zap.String("request_id", requestID), + ) + } + } + + // Determine keepalive interval + keepaliveInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 { + keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second + } + + // No keepalive: fast synchronous path + if keepaliveInterval <= 0 { + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + continue + } + if processDataLine(line[6:]) { + return resultWithUsage(), nil + } + } + handleScanErr(scanner.Err()) + return finalizeStream() + } + + // With keepalive: goroutine + channel + select + type scanEvent struct { + line string + err error + } + events := make(chan scanEvent, 16) + done := make(chan struct{}) + sendEvent := func(ev scanEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false + } + } + go func() { + defer close(events) + for scanner.Scan() { + if !sendEvent(scanEvent{line: scanner.Text()}) { + return + } + } + if err := scanner.Err(); err != nil { + _ = sendEvent(scanEvent{err: err}) + } + }() + defer close(done) + + keepaliveTicker := time.NewTicker(keepaliveInterval) + defer keepaliveTicker.Stop() + lastDataAt := time.Now() + + for { + select { + case ev, ok := <-events: + if !ok { + return finalizeStream() + } + if ev.err != nil { + handleScanErr(ev.err) + return finalizeStream() + } + lastDataAt = time.Now() + line := ev.line + if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + continue + } + if processDataLine(line[6:]) { + return resultWithUsage(), nil + } + + case <-keepaliveTicker.C: + if time.Since(lastDataAt) < keepaliveInterval { + continue + } + // Send SSE comment as keepalive + if _, err := fmt.Fprint(c.Writer, ":\n\n"); err != nil { + logger.L().Info("openai chat_completions stream: 
client disconnected during keepalive", + zap.String("request_id", requestID), + ) + return resultWithUsage(), nil + } + c.Writer.Flush() + } + } +} + +// writeChatCompletionsError writes an error response in OpenAI Chat Completions format. +func writeChatCompletionsError(c *gin.Context, statusCode int, errType, message string) { + c.JSON(statusCode, gin.H{ + "error": gin.H{ + "type": errType, + "message": message, + }, + }) +} diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go new file mode 100644 index 00000000..e4a3d9c0 --- /dev/null +++ b/backend/internal/service/openai_gateway_messages.go @@ -0,0 +1,541 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// ForwardAsAnthropic accepts an Anthropic Messages request body, converts it +// to OpenAI Responses API format, forwards to the OpenAI upstream, and converts +// the response back to Anthropic Messages format. This enables Claude Code +// clients to access OpenAI models through the standard /v1/messages endpoint. +func (s *OpenAIGatewayService) ForwardAsAnthropic( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + promptCacheKey string, + defaultMappedModel string, +) (*OpenAIForwardResult, error) { + startTime := time.Now() + + // 1. Parse Anthropic request + var anthropicReq apicompat.AnthropicRequest + if err := json.Unmarshal(body, &anthropicReq); err != nil { + return nil, fmt.Errorf("parse anthropic request: %w", err) + } + originalModel := anthropicReq.Model + clientStream := anthropicReq.Stream // client's original stream preference + + // 2. 
Convert Anthropic → Responses + responsesReq, err := apicompat.AnthropicToResponses(&anthropicReq) + if err != nil { + return nil, fmt.Errorf("convert anthropic to responses: %w", err) + } + + // Upstream always uses streaming (upstream may not support sync mode). + // The client's original preference determines the response format. + responsesReq.Stream = true + isStream := true + + // 2b. Handle BetaFastMode → service_tier: "priority" + if containsBetaToken(c.GetHeader("anthropic-beta"), claude.BetaFastMode) { + responsesReq.ServiceTier = "priority" + } + + // 3. Model mapping + mappedModel := account.GetMappedModel(originalModel) + // 分组级降级:账号未映射时使用分组默认映射模型 + if mappedModel == originalModel && defaultMappedModel != "" { + mappedModel = defaultMappedModel + } + responsesReq.Model = mappedModel + + logger.L().Debug("openai messages: model mapping applied", + zap.Int64("account_id", account.ID), + zap.String("original_model", originalModel), + zap.String("mapped_model", mappedModel), + zap.Bool("stream", isStream), + ) + + // 4. Marshal Responses request body, then apply OAuth codex transform + responsesBody, err := json.Marshal(responsesReq) + if err != nil { + return nil, fmt.Errorf("marshal responses request: %w", err) + } + + if account.Type == AccountTypeOAuth { + var reqBody map[string]any + if err := json.Unmarshal(responsesBody, &reqBody); err != nil { + return nil, fmt.Errorf("unmarshal for codex transform: %w", err) + } + codexResult := applyCodexOAuthTransform(reqBody, false, false) + if codexResult.PromptCacheKey != "" { + promptCacheKey = codexResult.PromptCacheKey + } else if promptCacheKey != "" { + reqBody["prompt_cache_key"] = promptCacheKey + } + // OAuth codex transform forces stream=true upstream, so always use + // the streaming response handler regardless of what the client asked. 
+ isStream = true + responsesBody, err = json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("remarshal after codex transform: %w", err) + } + } + + // 5. Get access token + token, _, err := s.GetAccessToken(ctx, account) + if err != nil { + return nil, fmt.Errorf("get access token: %w", err) + } + + // 6. Build upstream request + upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, isStream, promptCacheKey, false) + if err != nil { + return nil, fmt.Errorf("build upstream request: %w", err) + } + + // Override session_id with a deterministic UUID derived from the sticky + // session key (buildUpstreamRequest may have set it to the raw value). + if promptCacheKey != "" { + upstreamReq.Header.Set("session_id", generateSessionUUID(promptCacheKey)) + } + + // 7. Send request + proxyURL := "" + if account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + writeAnthropicError(c, http.StatusBadGateway, "api_error", "Upstream request failed") + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + defer func() { _ = resp.Body.Close() }() + + // 8. 
Handle error response with failover + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) { + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + if s.rateLimitService != nil { + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)), + } + } + // Non-failover error: return Anthropic-formatted error to client + return s.handleAnthropicErrorResponse(resp, c, account) + } + + // 9. Handle normal response + // Upstream is always streaming; choose response format based on client preference. + var result *OpenAIForwardResult + var handleErr error + if clientStream { + result, handleErr = s.handleAnthropicStreamingResponse(resp, c, originalModel, mappedModel, startTime) + } else { + // Client wants JSON: buffer the streaming response and assemble a JSON reply. 
+ result, handleErr = s.handleAnthropicBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime) + } + + // Propagate ServiceTier and ReasoningEffort to result for billing + if handleErr == nil && result != nil { + if responsesReq.ServiceTier != "" { + st := responsesReq.ServiceTier + result.ServiceTier = &st + } + if responsesReq.Reasoning != nil && responsesReq.Reasoning.Effort != "" { + re := responsesReq.Reasoning.Effort + result.ReasoningEffort = &re + } + } + + // Extract and save Codex usage snapshot from response headers (for OAuth accounts) + if handleErr == nil && account.Type == AccountTypeOAuth { + if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { + s.updateCodexUsageSnapshot(ctx, account.ID, snapshot) + } + } + + return result, handleErr +} + +// handleAnthropicErrorResponse reads an upstream error and returns it in +// Anthropic error format. +func (s *OpenAIGatewayService) handleAnthropicErrorResponse( + resp *http.Response, + c *gin.Context, + account *Account, +) (*OpenAIForwardResult, error) { + return s.handleCompatErrorResponse(resp, c, account, writeAnthropicError) +} + +// handleAnthropicBufferedStreamingResponse reads all Responses SSE events from +// the upstream streaming response, finds the terminal event (response.completed +// / response.incomplete / response.failed), converts the complete response to +// Anthropic Messages JSON format, and writes it to the client. +// This is used when the client requested stream=false but the upstream is always +// streaming. 
+func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse( + resp *http.Response, + c *gin.Context, + originalModel string, + mappedModel string, + startTime time.Time, +) (*OpenAIForwardResult, error) { + requestID := resp.Header.Get("x-request-id") + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + var finalResponse *apicompat.ResponsesResponse + var usage OpenAIUsage + + for scanner.Scan() { + line := scanner.Text() + + if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + continue + } + payload := line[6:] + + var event apicompat.ResponsesStreamEvent + if err := json.Unmarshal([]byte(payload), &event); err != nil { + logger.L().Warn("openai messages buffered: failed to parse event", + zap.Error(err), + zap.String("request_id", requestID), + ) + continue + } + + // Terminal events carry the complete ResponsesResponse with output + usage. 
+ if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") && + event.Response != nil { + finalResponse = event.Response + if event.Response.Usage != nil { + usage = OpenAIUsage{ + InputTokens: event.Response.Usage.InputTokens, + OutputTokens: event.Response.Usage.OutputTokens, + } + if event.Response.Usage.InputTokensDetails != nil { + usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens + } + } + } + } + + if err := scanner.Err(); err != nil { + if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { + logger.L().Warn("openai messages buffered: read error", + zap.Error(err), + zap.String("request_id", requestID), + ) + } + } + + if finalResponse == nil { + writeAnthropicError(c, http.StatusBadGateway, "api_error", "Upstream stream ended without a terminal response event") + return nil, fmt.Errorf("upstream stream ended without terminal event") + } + + anthropicResp := apicompat.ResponsesToAnthropic(finalResponse, originalModel) + + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.JSON(http.StatusOK, anthropicResp) + + return &OpenAIForwardResult{ + RequestID: requestID, + Usage: usage, + Model: originalModel, + BillingModel: mappedModel, + Stream: false, + Duration: time.Since(startTime), + }, nil +} + +// handleAnthropicStreamingResponse reads Responses SSE events from upstream, +// converts each to Anthropic SSE events, and writes them to the client. +// When StreamKeepaliveInterval is configured, it uses a goroutine + channel +// pattern to send Anthropic ping events during periods of upstream silence, +// preventing proxy/client timeout disconnections. 
+func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( + resp *http.Response, + c *gin.Context, + originalModel string, + mappedModel string, + startTime time.Time, +) (*OpenAIForwardResult, error) { + requestID := resp.Header.Get("x-request-id") + + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.WriteHeader(http.StatusOK) + + state := apicompat.NewResponsesEventToAnthropicState() + state.Model = originalModel + var usage OpenAIUsage + var firstTokenMs *int + firstChunk := true + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + // resultWithUsage builds the final result snapshot. + resultWithUsage := func() *OpenAIForwardResult { + return &OpenAIForwardResult{ + RequestID: requestID, + Usage: usage, + Model: originalModel, + BillingModel: mappedModel, + Stream: true, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + } + } + + // processDataLine handles a single "data: ..." SSE line from upstream. + // Returns (clientDisconnected bool). 
+ processDataLine := func(payload string) bool { + if firstChunk { + firstChunk = false + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + + var event apicompat.ResponsesStreamEvent + if err := json.Unmarshal([]byte(payload), &event); err != nil { + logger.L().Warn("openai messages stream: failed to parse event", + zap.Error(err), + zap.String("request_id", requestID), + ) + return false + } + + // Extract usage from completion events + if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") && + event.Response != nil && event.Response.Usage != nil { + usage = OpenAIUsage{ + InputTokens: event.Response.Usage.InputTokens, + OutputTokens: event.Response.Usage.OutputTokens, + } + if event.Response.Usage.InputTokensDetails != nil { + usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens + } + } + + // Convert to Anthropic events + events := apicompat.ResponsesEventToAnthropicEvents(&event, state) + for _, evt := range events { + sse, err := apicompat.ResponsesAnthropicEventToSSE(evt) + if err != nil { + logger.L().Warn("openai messages stream: failed to marshal event", + zap.Error(err), + zap.String("request_id", requestID), + ) + continue + } + if _, err := fmt.Fprint(c.Writer, sse); err != nil { + logger.L().Info("openai messages stream: client disconnected", + zap.String("request_id", requestID), + ) + return true + } + } + if len(events) > 0 { + c.Writer.Flush() + } + return false + } + + // finalizeStream sends any remaining Anthropic events and returns the result. 
+ finalizeStream := func() (*OpenAIForwardResult, error) { + if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 { + for _, evt := range finalEvents { + sse, err := apicompat.ResponsesAnthropicEventToSSE(evt) + if err != nil { + continue + } + fmt.Fprint(c.Writer, sse) //nolint:errcheck + } + c.Writer.Flush() + } + return resultWithUsage(), nil + } + + // handleScanErr logs scanner errors if meaningful. + handleScanErr := func(err error) { + if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { + logger.L().Warn("openai messages stream: read error", + zap.Error(err), + zap.String("request_id", requestID), + ) + } + } + + // ── Determine keepalive interval ── + keepaliveInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 { + keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second + } + + // ── No keepalive: fast synchronous path (no goroutine overhead) ── + if keepaliveInterval <= 0 { + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + continue + } + if processDataLine(line[6:]) { + return resultWithUsage(), nil + } + } + handleScanErr(scanner.Err()) + return finalizeStream() + } + + // ── With keepalive: goroutine + channel + select ── + type scanEvent struct { + line string + err error + } + events := make(chan scanEvent, 16) + done := make(chan struct{}) + sendEvent := func(ev scanEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false + } + } + go func() { + defer close(events) + for scanner.Scan() { + if !sendEvent(scanEvent{line: scanner.Text()}) { + return + } + } + if err := scanner.Err(); err != nil { + _ = sendEvent(scanEvent{err: err}) + } + }() + defer close(done) + + keepaliveTicker := time.NewTicker(keepaliveInterval) + defer keepaliveTicker.Stop() + lastDataAt := time.Now() + + for { + select { + case 
ev, ok := <-events: + if !ok { + // Upstream closed + return finalizeStream() + } + if ev.err != nil { + handleScanErr(ev.err) + return finalizeStream() + } + lastDataAt = time.Now() + line := ev.line + if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + continue + } + if processDataLine(line[6:]) { + return resultWithUsage(), nil + } + + case <-keepaliveTicker.C: + if time.Since(lastDataAt) < keepaliveInterval { + continue + } + // Send Anthropic-format ping event + if _, err := fmt.Fprint(c.Writer, "event: ping\ndata: {\"type\":\"ping\"}\n\n"); err != nil { + // Client disconnected + logger.L().Info("openai messages stream: client disconnected during keepalive", + zap.String("request_id", requestID), + ) + return resultWithUsage(), nil + } + c.Writer.Flush() + } + } +} + +// writeAnthropicError writes an error response in Anthropic Messages API format. +func writeAnthropicError(c *gin.Context, statusCode int, errType, message string) { + c.JSON(statusCode, gin.H{ + "type": "error", + "error": gin.H{ + "type": errType, + "message": message, + }, + }) +} diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go new file mode 100644 index 00000000..9529462e --- /dev/null +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -0,0 +1,558 @@ +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type openAIRecordUsageLogRepoStub struct { + UsageLogRepository + + inserted bool + err error + calls int + lastLog *UsageLog +} + +func (s *openAIRecordUsageLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) { + s.calls++ + s.lastLog = log + return s.inserted, s.err +} + +type openAIRecordUsageUserRepoStub struct { + UserRepository + + deductCalls int + deductErr error + lastAmount float64 +} + +func (s *openAIRecordUsageUserRepoStub) 
DeductBalance(ctx context.Context, id int64, amount float64) error { + s.deductCalls++ + s.lastAmount = amount + return s.deductErr +} + +type openAIRecordUsageSubRepoStub struct { + UserSubscriptionRepository + + incrementCalls int + incrementErr error +} + +func (s *openAIRecordUsageSubRepoStub) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { + s.incrementCalls++ + return s.incrementErr +} + +type openAIRecordUsageAPIKeyQuotaStub struct { + quotaCalls int + rateLimitCalls int + err error + lastAmount float64 +} + +func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error { + s.quotaCalls++ + s.lastAmount = cost + return s.err +} + +func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error { + s.rateLimitCalls++ + s.lastAmount = cost + return s.err +} + +type openAIUserGroupRateRepoStub struct { + UserGroupRateRepository + + rate *float64 + err error + calls int +} + +func (s *openAIUserGroupRateRepoStub) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) { + s.calls++ + if s.err != nil { + return nil, s.err + } + return s.rate, nil +} + +func i64p(v int64) *int64 { + return &v +} + +func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository, rateRepo UserGroupRateRepository) *OpenAIGatewayService { + cfg := &config.Config{} + cfg.Default.RateMultiplier = 1.1 + + return &OpenAIGatewayService{ + usageLogRepo: usageRepo, + userRepo: userRepo, + userSubRepo: subRepo, + cfg: cfg, + billingService: NewBillingService(cfg, nil), + billingCacheService: &BillingCacheService{}, + deferredService: &DeferredService{}, + userGroupRateResolver: newUserGroupRateResolver( + rateRepo, + nil, + resolveUserGroupRateCacheTTL(cfg), + nil, + "service.openai_gateway.test", + ), + } +} + +func expectedOpenAICost(t *testing.T, svc 
*OpenAIGatewayService, model string, usage OpenAIUsage, multiplier float64) *CostBreakdown { + t.Helper() + + cost, err := svc.billingService.CalculateCost(model, UsageTokens{ + InputTokens: max(usage.InputTokens-usage.CacheReadInputTokens, 0), + OutputTokens: usage.OutputTokens, + CacheCreationTokens: usage.CacheCreationInputTokens, + CacheReadTokens: usage.CacheReadInputTokens, + }, multiplier) + require.NoError(t, err) + return cost +} + +func max(a, b int) int { + if a > b { + return a + } + return b +} + +func TestOpenAIGatewayServiceRecordUsage_UsesUserSpecificGroupRate(t *testing.T) { + groupID := int64(11) + groupRate := 1.4 + userRate := 1.8 + usage := OpenAIUsage{InputTokens: 15, OutputTokens: 4, CacheReadInputTokens: 3} + + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + rateRepo := &openAIUserGroupRateRepoStub{rate: &userRate} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_user_group_rate", + Usage: usage, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1001, + GroupID: i64p(groupID), + Group: &Group{ + ID: groupID, + RateMultiplier: groupRate, + }, + }, + User: &User{ID: 2001}, + Account: &Account{ID: 3001}, + }) + + require.NoError(t, err) + require.Equal(t, 1, rateRepo.calls) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, userRate, usageRepo.lastLog.RateMultiplier) + require.Equal(t, 12, usageRepo.lastLog.InputTokens) + require.Equal(t, 3, usageRepo.lastLog.CacheReadTokens) + + expected := expectedOpenAICost(t, svc, "gpt-5.1", usage, userRate) + require.InDelta(t, expected.ActualCost, usageRepo.lastLog.ActualCost, 1e-12) + require.InDelta(t, expected.ActualCost, userRepo.lastAmount, 1e-12) + require.Equal(t, 1, userRepo.deductCalls) +} + +func 
TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateOnResolverError(t *testing.T) { + groupID := int64(12) + groupRate := 1.6 + usage := OpenAIUsage{InputTokens: 10, OutputTokens: 5, CacheReadInputTokens: 2} + + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + rateRepo := &openAIUserGroupRateRepoStub{err: errors.New("db unavailable")} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_group_default_on_error", + Usage: usage, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1002, + GroupID: i64p(groupID), + Group: &Group{ + ID: groupID, + RateMultiplier: groupRate, + }, + }, + User: &User{ID: 2002}, + Account: &Account{ID: 3002}, + }) + + require.NoError(t, err) + require.Equal(t, 1, rateRepo.calls) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, groupRate, usageRepo.lastLog.RateMultiplier) + + expected := expectedOpenAICost(t, svc, "gpt-5.1", usage, groupRate) + require.InDelta(t, expected.ActualCost, userRepo.lastAmount, 1e-12) +} + +func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateWhenResolverMissing(t *testing.T) { + groupID := int64(13) + groupRate := 1.25 + usage := OpenAIUsage{InputTokens: 9, OutputTokens: 4, CacheReadInputTokens: 1} + + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + svc.userGroupRateResolver = nil + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_group_default_nil_resolver", + Usage: usage, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 
1003, + GroupID: i64p(groupID), + Group: &Group{ + ID: groupID, + RateMultiplier: groupRate, + }, + }, + User: &User{ID: 2003}, + Account: &Account{ID: 3003}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, groupRate, usageRepo.lastLog.RateMultiplier) +} + +func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: false} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_duplicate", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 4, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 1004}, + User: &User{ID: 2004}, + Account: &Account{ID: 3004}, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 0, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_UpdatesAPIKeyQuotaWhenConfigured(t *testing.T) { + usage := OpenAIUsage{InputTokens: 10, OutputTokens: 6, CacheReadInputTokens: 2} + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_quota_update", + Usage: usage, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1005, + Quota: 100, + }, + User: &User{ID: 2005}, + Account: &Account{ID: 3005}, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, quotaSvc.quotaCalls) + 
require.Equal(t, 0, quotaSvc.rateLimitCalls) + expected := expectedOpenAICost(t, svc, "gpt-5.1", usage, 1.1) + require.InDelta(t, expected.ActualCost, quotaSvc.lastAmount, 1e-12) +} + +func TestOpenAIGatewayServiceRecordUsage_ClampsActualInputTokensToZero(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_clamp_actual_input", + Usage: OpenAIUsage{ + InputTokens: 2, + OutputTokens: 1, + CacheReadInputTokens: 5, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 1006}, + User: &User{ID: 2006}, + Account: &Account{ID: 3006}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, 0, usageRepo.lastLog.InputTokens) +} + +func TestOpenAIGatewayServiceRecordUsage_Gpt54LongContextBillsWholeSession(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_gpt54_long_context", + Usage: OpenAIUsage{ + InputTokens: 300000, + OutputTokens: 2000, + }, + Model: "gpt-5.4-2026-03-05", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 1014}, + User: &User{ID: 2014}, + Account: &Account{ID: 3014}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + + expectedInput := 300000 * 2.5e-6 * 2.0 + expectedOutput := 2000 * 15e-6 * 1.5 + require.InDelta(t, expectedInput, usageRepo.lastLog.InputCost, 1e-10) + require.InDelta(t, expectedOutput, usageRepo.lastLog.OutputCost, 1e-10) + 
require.InDelta(t, expectedInput+expectedOutput, usageRepo.lastLog.TotalCost, 1e-10) + require.InDelta(t, (expectedInput+expectedOutput)*1.1, usageRepo.lastLog.ActualCost, 1e-10) + require.Equal(t, 1, userRepo.deductCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_ServiceTierPriorityUsesFastPricing(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + serviceTier := "priority" + usage := OpenAIUsage{InputTokens: 100, OutputTokens: 50} + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_service_tier_priority", + ServiceTier: &serviceTier, + Usage: usage, + Model: "gpt-5.4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 1015}, + User: &User{ID: 2015}, + Account: &Account{ID: 3015}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.NotNil(t, usageRepo.lastLog.ServiceTier) + require.Equal(t, serviceTier, *usageRepo.lastLog.ServiceTier) + + baseCost, calcErr := svc.billingService.CalculateCost("gpt-5.4", UsageTokens{InputTokens: 100, OutputTokens: 50}, 1.0) + require.NoError(t, calcErr) + require.InDelta(t, baseCost.TotalCost*2, usageRepo.lastLog.TotalCost, 1e-10) +} + +func TestOpenAIGatewayServiceRecordUsage_ServiceTierFlexHalvesCost(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + serviceTier := "flex" + usage := OpenAIUsage{InputTokens: 100, OutputTokens: 50, CacheReadInputTokens: 20} + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_service_tier_flex", + ServiceTier: &serviceTier, + 
Usage: usage, + Model: "gpt-5.4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 1016}, + User: &User{ID: 2016}, + Account: &Account{ID: 3016}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + + baseCost, calcErr := svc.billingService.CalculateCost("gpt-5.4", UsageTokens{InputTokens: 80, OutputTokens: 50, CacheReadTokens: 20}, 1.0) + require.NoError(t, calcErr) + require.InDelta(t, baseCost.TotalCost*0.5, usageRepo.lastLog.TotalCost, 1e-10) +} + +func TestNormalizeOpenAIServiceTier(t *testing.T) { + t.Run("fast maps to priority", func(t *testing.T) { + got := normalizeOpenAIServiceTier(" fast ") + require.NotNil(t, got) + require.Equal(t, "priority", *got) + }) + + t.Run("default ignored", func(t *testing.T) { + require.Nil(t, normalizeOpenAIServiceTier("default")) + }) + + t.Run("invalid ignored", func(t *testing.T) { + require.Nil(t, normalizeOpenAIServiceTier("turbo")) + }) +} + +func TestExtractOpenAIServiceTier(t *testing.T) { + require.Equal(t, "priority", *extractOpenAIServiceTier(map[string]any{"service_tier": "fast"})) + require.Equal(t, "flex", *extractOpenAIServiceTier(map[string]any{"service_tier": "flex"})) + require.Nil(t, extractOpenAIServiceTier(map[string]any{"service_tier": 1})) + require.Nil(t, extractOpenAIServiceTier(nil)) +} + +func TestExtractOpenAIServiceTierFromBody(t *testing.T) { + require.Equal(t, "priority", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"fast"}`))) + require.Equal(t, "flex", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"flex"}`))) + require.Nil(t, extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"default"}`))) + require.Nil(t, extractOpenAIServiceTierFromBody(nil)) +} + +func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := 
newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + serviceTier := "priority" + reasoning := "high" + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_billing_model_override", + BillingModel: "gpt-5.1-codex", + Model: "gpt-5.1", + ServiceTier: &serviceTier, + ReasoningEffort: &reasoning, + Usage: OpenAIUsage{ + InputTokens: 20, + OutputTokens: 10, + }, + Duration: 2 * time.Second, + FirstTokenMs: func() *int { v := 120; return &v }(), + }, + APIKey: &APIKey{ID: 10, GroupID: i64p(11), Group: &Group{ID: 11, RateMultiplier: 1.2}}, + User: &User{ID: 20}, + Account: &Account{ID: 30}, + UserAgent: "codex-cli/1.0", + IPAddress: "127.0.0.1", + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, "gpt-5.1-codex", usageRepo.lastLog.Model) + require.NotNil(t, usageRepo.lastLog.ServiceTier) + require.Equal(t, serviceTier, *usageRepo.lastLog.ServiceTier) + require.NotNil(t, usageRepo.lastLog.ReasoningEffort) + require.Equal(t, reasoning, *usageRepo.lastLog.ReasoningEffort) + require.NotNil(t, usageRepo.lastLog.UserAgent) + require.Equal(t, "codex-cli/1.0", *usageRepo.lastLog.UserAgent) + require.NotNil(t, usageRepo.lastLog.IPAddress) + require.Equal(t, "127.0.0.1", *usageRepo.lastLog.IPAddress) + require.NotNil(t, usageRepo.lastLog.GroupID) + require.Equal(t, int64(11), *usageRepo.lastLog.GroupID) + require.Equal(t, 1, userRepo.deductCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFields(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + subscription := &UserSubscription{ID: 99} + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: 
"resp_subscription_billing", + Usage: OpenAIUsage{InputTokens: 10, OutputTokens: 5}, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 100, GroupID: i64p(88), Group: &Group{ID: 88, SubscriptionType: SubscriptionTypeSubscription}}, + User: &User{ID: 200}, + Account: &Account{ID: 300}, + Subscription: subscription, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, BillingTypeSubscription, usageRepo.lastLog.BillingType) + require.NotNil(t, usageRepo.lastLog.SubscriptionID) + require.Equal(t, subscription.ID, *usageRepo.lastLog.SubscriptionID) + require.Equal(t, 1, subRepo.incrementCalls) + require.Equal(t, 0, userRepo.deductCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_SimpleModeSkipsBillingAfterPersist(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + svc.cfg.RunMode = config.RunModeSimple + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_simple_mode", + Usage: OpenAIUsage{InputTokens: 10, OutputTokens: 5}, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 1000}, + User: &User{ID: 2000}, + Account: &Account{ID: 3000}, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 0, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 8606708f..023e4ed4 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -25,6 +25,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/gin-gonic/gin" 
+ "github.com/google/uuid" "github.com/tidwall/gjson" "github.com/tidwall/sjson" "go.uber.org/zap" @@ -49,6 +50,10 @@ const ( openAIWSRetryBackoffInitialDefault = 120 * time.Millisecond openAIWSRetryBackoffMaxDefault = 2 * time.Second openAIWSRetryJitterRatioDefault = 0.2 + openAICompactSessionSeedKey = "openai_compact_session_seed" + codexCLIVersion = "0.104.0" + // Codex 限额快照仅用于后台展示/诊断,不需要每个成功请求都立即落库。 + openAICodexSnapshotPersistMinInterval = 30 * time.Second ) // OpenAI allowed headers whitelist (for non-passthrough). @@ -204,12 +209,21 @@ type OpenAIUsage struct { type OpenAIForwardResult struct { RequestID string Usage OpenAIUsage - Model string + Model string // 原始模型(用于响应和日志显示) + // BillingModel is the model used for cost calculation. + // When non-empty, CalculateCost uses this instead of Model. + // This is set by the Anthropic Messages conversion path where + // the mapped upstream model differs from the client-facing model. + BillingModel string + // ServiceTier records the OpenAI Responses API service tier, e.g. "priority" / "flex". + // Nil means the request did not specify a recognized tier. + ServiceTier *string // ReasoningEffort is extracted from request body (reasoning.effort) or derived from model suffix. // Stored for usage records display; nil means not provided / not applicable. 
ReasoningEffort *string Stream bool OpenAIWSMode bool + ResponseHeaders http.Header Duration time.Duration FirstTokenMs *int } @@ -243,37 +257,81 @@ type openAIWSRetryMetrics struct { nonRetryableFastFallback atomic.Int64 } +type accountWriteThrottle struct { + minInterval time.Duration + mu sync.Mutex + lastByID map[int64]time.Time +} + +func newAccountWriteThrottle(minInterval time.Duration) *accountWriteThrottle { + return &accountWriteThrottle{ + minInterval: minInterval, + lastByID: make(map[int64]time.Time), + } +} + +func (t *accountWriteThrottle) Allow(id int64, now time.Time) bool { + if t == nil || id <= 0 || t.minInterval <= 0 { + return true + } + + t.mu.Lock() + defer t.mu.Unlock() + + if last, ok := t.lastByID[id]; ok && now.Sub(last) < t.minInterval { + return false + } + t.lastByID[id] = now + + if len(t.lastByID) > 4096 { + cutoff := now.Add(-4 * t.minInterval) + for accountID, writtenAt := range t.lastByID { + if writtenAt.Before(cutoff) { + delete(t.lastByID, accountID) + } + } + } + + return true +} + +var defaultOpenAICodexSnapshotPersistThrottle = newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval) + // OpenAIGatewayService handles OpenAI API gateway operations type OpenAIGatewayService struct { - accountRepo AccountRepository - usageLogRepo UsageLogRepository - userRepo UserRepository - userSubRepo UserSubscriptionRepository - cache GatewayCache - cfg *config.Config - codexDetector CodexClientRestrictionDetector - schedulerSnapshot *SchedulerSnapshotService - concurrencyService *ConcurrencyService - billingService *BillingService - rateLimitService *RateLimitService - billingCacheService *BillingCacheService - httpUpstream HTTPUpstream - deferredService *DeferredService - openAITokenProvider *OpenAITokenProvider - toolCorrector *CodexToolCorrector - openaiWSResolver OpenAIWSProtocolResolver + accountRepo AccountRepository + usageLogRepo UsageLogRepository + userRepo UserRepository + userSubRepo UserSubscriptionRepository + cache 
GatewayCache + cfg *config.Config + codexDetector CodexClientRestrictionDetector + schedulerSnapshot *SchedulerSnapshotService + concurrencyService *ConcurrencyService + billingService *BillingService + rateLimitService *RateLimitService + billingCacheService *BillingCacheService + userGroupRateResolver *userGroupRateResolver + httpUpstream HTTPUpstream + deferredService *DeferredService + openAITokenProvider *OpenAITokenProvider + toolCorrector *CodexToolCorrector + openaiWSResolver OpenAIWSProtocolResolver - openaiWSPoolOnce sync.Once - openaiWSStateStoreOnce sync.Once - openaiSchedulerOnce sync.Once - openaiWSPool *openAIWSConnPool - openaiWSStateStore OpenAIWSStateStore - openaiScheduler OpenAIAccountScheduler - openaiAccountStats *openAIAccountRuntimeStats + openaiWSPoolOnce sync.Once + openaiWSStateStoreOnce sync.Once + openaiSchedulerOnce sync.Once + openaiWSPassthroughDialerOnce sync.Once + openaiWSPool *openAIWSConnPool + openaiWSStateStore OpenAIWSStateStore + openaiScheduler OpenAIAccountScheduler + openaiWSPassthroughDialer openAIWSClientDialer + openaiAccountStats *openAIAccountRuntimeStats openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time openaiWSRetryMetrics openAIWSRetryMetrics responseHeaderFilter *responseheaders.CompiledHeaderFilter + codexSnapshotThrottle *accountWriteThrottle } // NewOpenAIGatewayService creates a new OpenAIGatewayService @@ -282,6 +340,7 @@ func NewOpenAIGatewayService( usageLogRepo UsageLogRepository, userRepo UserRepository, userSubRepo UserSubscriptionRepository, + userGroupRateRepo UserGroupRateRepository, cache GatewayCache, cfg *config.Config, schedulerSnapshot *SchedulerSnapshotService, @@ -294,29 +353,54 @@ func NewOpenAIGatewayService( openAITokenProvider *OpenAITokenProvider, ) *OpenAIGatewayService { svc := &OpenAIGatewayService{ - accountRepo: accountRepo, - usageLogRepo: usageLogRepo, - userRepo: userRepo, - userSubRepo: userSubRepo, - cache: cache, - cfg: cfg, - codexDetector: 
NewOpenAICodexClientRestrictionDetector(cfg), - schedulerSnapshot: schedulerSnapshot, - concurrencyService: concurrencyService, - billingService: billingService, - rateLimitService: rateLimitService, - billingCacheService: billingCacheService, - httpUpstream: httpUpstream, - deferredService: deferredService, - openAITokenProvider: openAITokenProvider, - toolCorrector: NewCodexToolCorrector(), - openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), - responseHeaderFilter: compileResponseHeaderFilter(cfg), + accountRepo: accountRepo, + usageLogRepo: usageLogRepo, + userRepo: userRepo, + userSubRepo: userSubRepo, + cache: cache, + cfg: cfg, + codexDetector: NewOpenAICodexClientRestrictionDetector(cfg), + schedulerSnapshot: schedulerSnapshot, + concurrencyService: concurrencyService, + billingService: billingService, + rateLimitService: rateLimitService, + billingCacheService: billingCacheService, + userGroupRateResolver: newUserGroupRateResolver( + userGroupRateRepo, + nil, + resolveUserGroupRateCacheTTL(cfg), + nil, + "service.openai_gateway", + ), + httpUpstream: httpUpstream, + deferredService: deferredService, + openAITokenProvider: openAITokenProvider, + toolCorrector: NewCodexToolCorrector(), + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + responseHeaderFilter: compileResponseHeaderFilter(cfg), + codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval), } svc.logOpenAIWSModeBootstrap() return svc } +func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle { + if s != nil && s.codexSnapshotThrottle != nil { + return s.codexSnapshotThrottle + } + return defaultOpenAICodexSnapshotPersistThrottle +} + +func (s *OpenAIGatewayService) billingDeps() *billingDeps { + return &billingDeps{ + accountRepo: s.accountRepo, + userRepo: s.userRepo, + userSubRepo: s.userSubRepo, + billingCacheService: s.billingCacheService, + deferredService: s.deferredService, + } +} + // CloseOpenAIWSPool 关闭 OpenAI WebSocket 连接池的后台 
worker 和空闲连接。 // 应在应用优雅关闭时调用。 func (s *OpenAIGatewayService) CloseOpenAIWSPool() { @@ -804,8 +888,10 @@ func logOpenAIInstructionsRequiredDebug( } userAgent := "" + originator := "" if c != nil { userAgent = strings.TrimSpace(c.GetHeader("User-Agent")) + originator = strings.TrimSpace(c.GetHeader("originator")) } fields := []zap.Field{ @@ -815,7 +901,7 @@ func logOpenAIInstructionsRequiredDebug( zap.Int("upstream_status_code", upstreamStatusCode), zap.String("upstream_error_message", msg), zap.String("request_user_agent", userAgent), - zap.Bool("codex_official_client_match", openai.IsCodexCLIRequest(userAgent)), + zap.Bool("codex_official_client_match", openai.IsCodexOfficialClientByHeaders(userAgent, originator)), } fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, requestBody) @@ -876,6 +962,52 @@ func isOpenAIInstructionsRequiredError(upstreamStatusCode int, upstreamMsg strin return false } +func isOpenAITransientProcessingError(upstreamStatusCode int, upstreamMsg string, upstreamBody []byte) bool { + if upstreamStatusCode != http.StatusBadRequest { + return false + } + + match := func(text string) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + if lower == "" { + return false + } + if strings.Contains(lower, "an error occurred while processing your request") { + return true + } + return strings.Contains(lower, "you can retry your request") && + strings.Contains(lower, "help.openai.com") && + strings.Contains(lower, "request id") + } + + if match(upstreamMsg) { + return true + } + if len(upstreamBody) == 0 { + return false + } + if match(gjson.GetBytes(upstreamBody, "error.message").String()) { + return true + } + return match(string(upstreamBody)) +} + +// ExtractSessionID extracts the raw session ID from headers or body without hashing. +// Used by ForwardAsAnthropic to pass as prompt_cache_key for upstream cache. 
+func (s *OpenAIGatewayService) ExtractSessionID(c *gin.Context, body []byte) string { + if c == nil { + return "" + } + sessionID := strings.TrimSpace(c.GetHeader("session_id")) + if sessionID == "" { + sessionID = strings.TrimSpace(c.GetHeader("conversation_id")) + } + if sessionID == "" && len(body) > 0 { + sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) + } + return sessionID +} + // GenerateSessionHash generates a sticky-session hash for OpenAI requests. // // Priority: @@ -922,6 +1054,18 @@ func (s *OpenAIGatewayService) GenerateSessionHashWithFallback(c *gin.Context, b return currentHash } +func resolveOpenAIUpstreamOriginator(c *gin.Context, isOfficialClient bool) string { + if c != nil { + if originator := strings.TrimSpace(c.GetHeader("originator")); originator != "" { + return originator + } + } + if isOfficialClient { + return "codex_cli_rs" + } + return "opencode" +} + // BindStickySession sets session -> account binding with standard TTL. func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error { if sessionHash == "" || accountID <= 0 { @@ -966,7 +1110,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C // 3. 按优先级 + LRU 选择最佳账号 // Select by priority + LRU - selected := s.selectBestAccount(accounts, requestedModel, excludedIDs) + selected := s.selectBestAccount(ctx, accounts, requestedModel, excludedIDs) if selected == nil { if requestedModel != "" { @@ -1039,7 +1183,7 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID // // selectBestAccount selects the best account from candidates (priority + LRU). // Returns nil if no available account. 
-func (s *OpenAIGatewayService) selectBestAccount(accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account { +func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account { var selected *Account for i := range accounts { @@ -1051,27 +1195,20 @@ func (s *OpenAIGatewayService) selectBestAccount(accounts []Account, requestedMo continue } - // 调度器快照可能暂时过时,这里重新检查可调度性和平台 - // Scheduler snapshots can be temporarily stale; re-check schedulability and platform - if !acc.IsSchedulable() || !acc.IsOpenAI() { - continue - } - - // 检查模型支持 - // Check model support - if requestedModel != "" && !acc.IsModelSupported(requestedModel) { + fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel) + if fresh == nil { continue } // 选择优先级最高且最久未使用的账号 // Select highest priority and least recently used if selected == nil { - selected = acc + selected = fresh continue } - if s.isBetterAccount(acc, selected) { - selected = acc + if s.isBetterAccount(fresh, selected) { + selected = fresh } } @@ -1240,7 +1377,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex for _, acc := range candidates { accountLoads = append(accountLoads, AccountWithConcurrency{ ID: acc.ID, - MaxConcurrency: acc.Concurrency, + MaxConcurrency: acc.EffectiveLoadFactor(), }) } @@ -1249,13 +1386,17 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex ordered := append([]*Account(nil), candidates...) 
sortAccountsByPriorityAndLastUsed(ordered, false) for _, acc := range ordered { - result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) + fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel) + if fresh == nil { + continue + } + result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) if err == nil && result.Acquired { if sessionHash != "" { - _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, acc.ID, openaiStickySessionTTL) + _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL) } return &AccountSelectionResult{ - Account: acc, + Account: fresh, Acquired: true, ReleaseFunc: result.ReleaseFunc, }, nil @@ -1299,13 +1440,17 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex shuffleWithinSortGroups(available) for _, item := range available { - result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) + fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel) + if fresh == nil { + continue + } + result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) if err == nil && result.Acquired { if sessionHash != "" { - _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, item.account.ID, openaiStickySessionTTL) + _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL) } return &AccountSelectionResult{ - Account: item.account, + Account: fresh, Acquired: true, ReleaseFunc: result.ReleaseFunc, }, nil @@ -1317,11 +1462,15 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex // ============ Layer 3: Fallback wait ============ sortAccountsByPriorityAndLastUsed(candidates, false) for _, acc := range candidates { + fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel) + if fresh == nil { + continue + } return &AccountSelectionResult{ - Account: acc, + Account: fresh, WaitPlan: &AccountWaitPlan{ 
- AccountID: acc.ID, - MaxConcurrency: acc.Concurrency, + AccountID: fresh.ID, + MaxConcurrency: fresh.Concurrency, Timeout: cfg.FallbackWaitTimeout, MaxWaiting: cfg.FallbackMaxWaiting, }, @@ -1358,11 +1507,44 @@ func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accoun return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) } -func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) { - if s.schedulerSnapshot != nil { - return s.schedulerSnapshot.GetAccount(ctx, accountID) +func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.Context, account *Account, requestedModel string) *Account { + if account == nil { + return nil } - return s.accountRepo.GetByID(ctx, accountID) + + fresh := account + if s.schedulerSnapshot != nil { + current, err := s.getSchedulableAccount(ctx, account.ID) + if err != nil || current == nil { + return nil + } + fresh = current + } + + if !fresh.IsSchedulable() || !fresh.IsOpenAI() { + return nil + } + if requestedModel != "" && !fresh.IsModelSupported(requestedModel) { + return nil + } + return fresh +} + +func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) { + var ( + account *Account + err error + ) + if s.schedulerSnapshot != nil { + account, err = s.schedulerSnapshot.GetAccount(ctx, accountID) + } else { + account, err = s.accountRepo.GetByID(ctx, accountID) + } + if err != nil || account == nil { + return account, err + } + syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, account, time.Now()) + return account, nil } func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig { @@ -1417,6 +1599,13 @@ func (s *OpenAIGatewayService) shouldFailoverUpstreamError(statusCode int) bool } } +func (s *OpenAIGatewayService) shouldFailoverOpenAIUpstreamResponse(statusCode int, upstreamMsg string, upstreamBody []byte) bool { + if 
s.shouldFailoverUpstreamError(statusCode) { + return true + } + return isOpenAITransientProcessingError(statusCode, upstreamMsg, upstreamBody) +} + func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) { body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) @@ -1443,7 +1632,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body) originalModel := reqModel - isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) + isCodexCLI := openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account) clientTransport := GetOpenAIClientTransport(c) // 仅允许 WS 入站请求走 WS 上游,避免出现 HTTP -> WS 协议混用。 @@ -1551,13 +1740,11 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco patchDisabled = true } - // 非透传模式下,保持历史行为:非 Codex CLI 请求在 instructions 为空时注入默认指令。 - if !isCodexCLI && isInstructionsEmpty(reqBody) { - if instructions := strings.TrimSpace(GetOpenCodeInstructions()); instructions != "" { - reqBody["instructions"] = instructions - bodyModified = true - markPatchSet("instructions", instructions) - } + // 非透传模式下,instructions 为空时注入默认指令。 + if isInstructionsEmpty(reqBody) { + reqBody["instructions"] = "You are a helpful coding assistant." 
+ bodyModified = true + markPatchSet("instructions", "You are a helpful coding assistant.") } // 对所有请求执行模型映射(包含 Codex CLI)。 @@ -1580,6 +1767,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco bodyModified = true markPatchSet("model", normalizedModel) } + + // 移除 gpt-5.2-codex 以下的版本 verbosity 参数 + // 确保高版本模型向低版本模型映射不报错 + if !SupportsVerbosity(normalizedModel) { + if text, ok := reqBody["text"].(map[string]any); ok { + delete(text, "verbosity") + } + } } // 规范化 reasoning.effort 参数(minimal -> none),与上游允许值对齐。 @@ -1593,7 +1788,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } if account.Type == AccountTypeOAuth { - codexResult := applyCodexOAuthTransform(reqBody, isCodexCLI) + codexResult := applyCodexOAuthTransform(reqBody, isCodexCLI, isOpenAIResponsesCompactPath(c)) if codexResult.Modified { bodyModified = true disablePatch() @@ -1917,13 +2112,13 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco // Handle error response if resp.StatusCode >= 400 { - if s.shouldFailoverUpstreamError(resp.StatusCode) { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - resp.Body = io.NopCloser(bytes.NewReader(respBody)) + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) - upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) { upstreamDetail := "" if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes @@ -1944,7 +2139,11 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco }) 
s.handleFailoverSideEffects(ctx, resp, account) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)), + } } return s.handleErrorResponse(ctx, resp, c, account, body) } @@ -1978,11 +2177,13 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel) + serviceTier := extractOpenAIServiceTier(reqBody) return &OpenAIForwardResult{ RequestID: resp.Header.Get("x-request-id"), Usage: *usage, Model: originalModel, + ServiceTier: serviceTier, ReasoningEffort: reasoningEffort, Stream: reqStream, OpenAIWSMode: false, @@ -2025,14 +2226,14 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( return nil, fmt.Errorf("openai passthrough rejected before upstream: %s", rejectReason) } - normalizedBody, normalized, err := normalizeOpenAIPassthroughOAuthBody(body) + normalizedBody, normalized, err := normalizeOpenAIPassthroughOAuthBody(body, isOpenAIResponsesCompactPath(c)) if err != nil { return nil, err } if normalized { body = normalizedBody - reqStream = true } + reqStream = gjson.GetBytes(body, "stream").Bool() } logger.LegacyPrintf("service.openai_gateway", @@ -2137,6 +2338,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( RequestID: resp.Header.Get("x-request-id"), Usage: *usage, Model: reqModel, + ServiceTier: extractOpenAIServiceTierFromBody(body), ReasoningEffort: reasoningEffort, Stream: reqStream, OpenAIWSMode: false, @@ -2197,6 +2399,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough( targetURL = buildOpenAIResponsesURL(validatedURL) } } + targetURL = appendOpenAIResponsesRequestPathSuffix(targetURL, openAIResponsesRequestPathSuffix(c)) 
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) if err != nil { @@ -2230,7 +2433,15 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough( if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" { req.Header.Set("chatgpt-account-id", chatgptAccountID) } - if req.Header.Get("accept") == "" { + if isOpenAIResponsesCompactPath(c) { + req.Header.Set("accept", "application/json") + if req.Header.Get("version") == "" { + req.Header.Set("version", codexCLIVersion) + } + if req.Header.Get("session_id") == "" { + req.Header.Set("session_id", resolveOpenAICompactSessionID(c)) + } + } else if req.Header.Get("accept") == "" { req.Header.Set("accept", "text/event-stream") } if req.Header.Get("OpenAI-Beta") == "" { @@ -2577,6 +2788,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. default: targetURL = openaiPlatformAPIURL } + targetURL = appendOpenAIResponsesRequestPathSuffix(targetURL, openAIResponsesRequestPathSuffix(c)) req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) if err != nil { @@ -2608,12 +2820,18 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. 
} if account.Type == AccountTypeOAuth { req.Header.Set("OpenAI-Beta", "responses=experimental") - if isCodexCLI { - req.Header.Set("originator", "codex_cli_rs") + req.Header.Set("originator", resolveOpenAIUpstreamOriginator(c, isCodexCLI)) + if isOpenAIResponsesCompactPath(c) { + req.Header.Set("accept", "application/json") + if req.Header.Get("version") == "" { + req.Header.Set("version", codexCLIVersion) + } + if req.Header.Get("session_id") == "" { + req.Header.Set("session_id", resolveOpenAICompactSessionID(c)) + } } else { - req.Header.Set("originator", "opencode") + req.Header.Set("accept", "text/event-stream") } - req.Header.Set("accept", "text/event-stream") if promptCacheKey != "" { req.Header.Set("conversation_id", promptCacheKey) req.Header.Set("session_id", promptCacheKey) @@ -2741,7 +2959,11 @@ func (s *OpenAIGatewayService) handleErrorResponse( Detail: upstreamDetail, }) if shouldDisable { - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body} + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: body, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + } } // Return appropriate error response @@ -2784,6 +3006,120 @@ func (s *OpenAIGatewayService) handleErrorResponse( return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) } +// compatErrorWriter is the signature for format-specific error writers used by +// the compat paths (Chat Completions and Anthropic Messages). +type compatErrorWriter func(c *gin.Context, statusCode int, errType, message string) + +// handleCompatErrorResponse is the shared non-failover error handler for the +// Chat Completions and Anthropic Messages compat paths. It mirrors the logic of +// handleErrorResponse (passthrough rules, ShouldHandleErrorCode, rate-limit +// tracking, secondary failover) but delegates the final error write to the +// format-specific writer function. 
+func (s *OpenAIGatewayService) handleCompatErrorResponse( + resp *http.Response, + c *gin.Context, + account *Account, + writeError compatErrorWriter, +) (*OpenAIForwardResult, error) { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) + if upstreamMsg == "" { + upstreamMsg = fmt.Sprintf("Upstream error: %d", resp.StatusCode) + } + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(body), maxBytes) + } + setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + + // Apply error passthrough rules + if status, errType, errMsg, matched := applyErrorPassthroughRule( + c, account.Platform, resp.StatusCode, body, + http.StatusBadGateway, "api_error", "Upstream request failed", + ); matched { + writeError(c, status, errType, errMsg) + if upstreamMsg == "" { + upstreamMsg = errMsg + } + if upstreamMsg == "" { + return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg) + } + + // Check custom error codes — if the account does not handle this status, + // return a generic error without exposing upstream details. 
+ if !account.ShouldHandleErrorCode(resp.StatusCode) { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "http_error", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + writeError(c, http.StatusInternalServerError, "api_error", "Upstream gateway error") + if upstreamMsg == "" { + return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d (not in custom error codes) message=%s", resp.StatusCode, upstreamMsg) + } + + // Track rate limits and decide whether to trigger secondary failover. + shouldDisable := false + if s.rateLimitService != nil { + shouldDisable = s.rateLimitService.HandleUpstreamError( + c.Request.Context(), account, resp.StatusCode, resp.Header, body, + ) + } + kind := "http_error" + if shouldDisable { + kind = "failover" + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: kind, + Message: upstreamMsg, + Detail: upstreamDetail, + }) + if shouldDisable { + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: body, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + } + } + + // Map status code to error type and write response + errType := "api_error" + switch { + case resp.StatusCode == 400: + errType = "invalid_request_error" + case resp.StatusCode == 404: + errType = "not_found_error" + case resp.StatusCode == 429: + errType = "rate_limit_error" + case resp.StatusCode >= 500: + errType = "api_error" + } + + writeError(c, resp.StatusCode, errType, upstreamMsg) + return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, 
upstreamMsg) +} + // openaiStreamingResult streaming response result type openaiStreamingResult struct { usage *OpenAIUsage @@ -3249,6 +3585,14 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin. // Correct tool calls in final response body = s.correctToolCallsInResponseBody(body) } else { + terminalType, terminalPayload, terminalOK := extractOpenAISSETerminalEvent(bodyText) + if terminalOK && terminalType == "response.failed" { + msg := extractOpenAISSEErrorMessage(terminalPayload) + if msg == "" { + msg = "Upstream compact response failed" + } + return nil, s.writeOpenAINonStreamingProtocolError(resp, c, msg) + } usage = s.parseSSEUsageFromBody(bodyText) if originalModel != mappedModel { bodyText = s.replaceModelInSSEBody(bodyText, mappedModel, originalModel) @@ -3270,6 +3614,51 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin. return usage, nil } +func extractOpenAISSETerminalEvent(body string) (string, []byte, bool) { + lines := strings.Split(body, "\n") + for _, line := range lines { + data, ok := extractOpenAISSEDataLine(line) + if !ok || data == "" || data == "[DONE]" { + continue + } + eventType := strings.TrimSpace(gjson.Get(data, "type").String()) + switch eventType { + case "response.completed", "response.done", "response.failed": + return eventType, []byte(data), true + } + } + return "", nil, false +} + +func extractOpenAISSEErrorMessage(payload []byte) string { + if len(payload) == 0 { + return "" + } + for _, path := range []string{"response.error.message", "error.message", "message"} { + if msg := strings.TrimSpace(gjson.GetBytes(payload, path).String()); msg != "" { + return sanitizeUpstreamErrorMessage(msg) + } + } + return sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(payload))) +} + +func (s *OpenAIGatewayService) writeOpenAINonStreamingProtocolError(resp *http.Response, c *gin.Context, message string) error { + message = 
sanitizeUpstreamErrorMessage(strings.TrimSpace(message)) + if message == "" { + message = "Upstream returned an invalid non-streaming response" + } + setOpsUpstreamError(c, http.StatusBadGateway, message, "") + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8") + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": message, + }, + }) + return fmt.Errorf("non-streaming openai protocol error: %s", message) +} + func extractCodexFinalResponse(body string) ([]byte, bool) { lines := strings.Split(body, "\n") for _, line := range lines { @@ -3351,6 +3740,95 @@ func buildOpenAIResponsesURL(base string) string { return normalized + "/v1/responses" } +func IsOpenAIResponsesCompactPathForTest(c *gin.Context) bool { + return isOpenAIResponsesCompactPath(c) +} + +func OpenAICompactSessionSeedKeyForTest() string { + return openAICompactSessionSeedKey +} + +func NormalizeOpenAICompactRequestBodyForTest(body []byte) ([]byte, bool, error) { + return normalizeOpenAICompactRequestBody(body) +} + +func isOpenAIResponsesCompactPath(c *gin.Context) bool { + suffix := strings.TrimSpace(openAIResponsesRequestPathSuffix(c)) + return suffix == "/compact" || strings.HasPrefix(suffix, "/compact/") +} + +func normalizeOpenAICompactRequestBody(body []byte) ([]byte, bool, error) { + if len(body) == 0 { + return body, false, nil + } + + normalized := []byte(`{}`) + for _, field := range []string{"model", "input", "instructions", "previous_response_id"} { + value := gjson.GetBytes(body, field) + if !value.Exists() { + continue + } + next, err := sjson.SetRawBytes(normalized, field, []byte(value.Raw)) + if err != nil { + return body, false, fmt.Errorf("normalize compact body %s: %w", field, err) + } + normalized = next + } + + if bytes.Equal(bytes.TrimSpace(body), bytes.TrimSpace(normalized)) { + return body, false, nil + } + return 
normalized, true, nil +} + +func resolveOpenAICompactSessionID(c *gin.Context) string { + if c != nil { + if sessionID := strings.TrimSpace(c.GetHeader("session_id")); sessionID != "" { + return sessionID + } + if conversationID := strings.TrimSpace(c.GetHeader("conversation_id")); conversationID != "" { + return conversationID + } + if seed, ok := c.Get(openAICompactSessionSeedKey); ok { + if seedStr, ok := seed.(string); ok && strings.TrimSpace(seedStr) != "" { + return strings.TrimSpace(seedStr) + } + } + } + return uuid.NewString() +} + +func openAIResponsesRequestPathSuffix(c *gin.Context) string { + if c == nil || c.Request == nil || c.Request.URL == nil { + return "" + } + normalizedPath := strings.TrimRight(strings.TrimSpace(c.Request.URL.Path), "/") + if normalizedPath == "" { + return "" + } + idx := strings.LastIndex(normalizedPath, "/responses") + if idx < 0 { + return "" + } + suffix := normalizedPath[idx+len("/responses"):] + if suffix == "" || suffix == "/" { + return "" + } + if !strings.HasPrefix(suffix, "/") { + return "" + } + return suffix +} + +func appendOpenAIResponsesRequestPathSuffix(baseURL, suffix string) string { + trimmedBase := strings.TrimRight(strings.TrimSpace(baseURL), "/") + trimmedSuffix := strings.TrimSpace(suffix) + if trimmedBase == "" || trimmedSuffix == "" { + return trimmedBase + } + return trimmedBase + trimmedSuffix +} + func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte { // 使用 gjson/sjson 精确替换 model 字段,避免全量 JSON 反序列化 if m := gjson.GetBytes(body, "model"); m.Exists() && m.Str == fromModel { @@ -3378,6 +3856,13 @@ type OpenAIRecordUsageInput struct { // RecordUsage records usage and deducts balance func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error { result := input.Result + + // 跳过所有 token 均为零的用量记录——上游未返回 usage 时不应写入数据库 + if result.Usage.InputTokens == 0 && result.Usage.OutputTokens == 0 && + 
result.Usage.CacheCreationInputTokens == 0 && result.Usage.CacheReadInputTokens == 0 { + return nil + } + apiKey := input.APIKey user := input.User account := input.Account @@ -3401,10 +3886,22 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec // Get rate multiplier multiplier := s.cfg.Default.RateMultiplier if apiKey.GroupID != nil && apiKey.Group != nil { - multiplier = apiKey.Group.RateMultiplier + resolver := s.userGroupRateResolver + if resolver == nil { + resolver = newUserGroupRateResolver(nil, nil, resolveUserGroupRateCacheTTL(s.cfg), nil, "service.openai_gateway") + } + multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier) } - cost, err := s.billingService.CalculateCost(result.Model, tokens, multiplier) + billingModel := result.Model + if result.BillingModel != "" { + billingModel = result.BillingModel + } + serviceTier := "" + if result.ServiceTier != nil { + serviceTier = strings.TrimSpace(*result.ServiceTier) + } + cost, err := s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier) if err != nil { cost = &CostBreakdown{ActualCost: 0} } @@ -3424,7 +3921,8 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec APIKeyID: apiKey.ID, AccountID: account.ID, RequestID: result.RequestID, - Model: result.Model, + Model: billingModel, + ServiceTier: result.ServiceTier, ReasoningEffort: result.ReasoningEffort, InputTokens: actualInputTokens, OutputTokens: result.Usage.OutputTokens, @@ -3472,37 +3970,21 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec shouldBill := inserted || err != nil - // Deduct based on billing type - if isSubscriptionBilling { - if shouldBill && cost.TotalCost > 0 { - _ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost) - s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) - } + if shouldBill { + 
postUsageBilling(ctx, &postUsageBillingParams{ + Cost: cost, + User: user, + APIKey: apiKey, + Account: account, + Subscription: subscription, + IsSubscriptionBill: isSubscriptionBilling, + AccountRateMultiplier: accountRateMultiplier, + APIKeyService: input.APIKeyService, + }, s.billingDeps()) } else { - if shouldBill && cost.ActualCost > 0 { - _ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost) - s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) - } + s.deferredService.ScheduleLastUsedUpdate(account.ID) } - // Update API key quota if applicable (only for balance mode with quota set) - if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil { - if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil { - logger.LegacyPrintf("service.openai_gateway", "Update API key quota failed: %v", err) - } - } - - // Update API Key rate limit usage - if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil { - if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil { - logger.LegacyPrintf("service.openai_gateway", "Update API key rate limit usage failed: %v", err) - } - s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost) - } - - // Schedule batch update for account last_used_at - s.deferredService.ScheduleLastUsedUpdate(account.ID) - return nil } @@ -3663,6 +4145,69 @@ func buildCodexUsageExtraUpdates(snapshot *OpenAICodexUsageSnapshot, fallbackNow return updates } +func codexUsagePercentExhausted(value *float64) bool { + return value != nil && *value >= 100-1e-9 +} + +func codexRateLimitResetAtFromSnapshot(snapshot *OpenAICodexUsageSnapshot, fallbackNow time.Time) *time.Time { + if snapshot == nil { + return nil + } + normalized := snapshot.Normalize() + if normalized == nil { + return nil + } + baseTime := codexSnapshotBaseTime(snapshot, fallbackNow) + if 
codexUsagePercentExhausted(normalized.Used7dPercent) && normalized.Reset7dSeconds != nil { + resetAt := baseTime.Add(time.Duration(*normalized.Reset7dSeconds) * time.Second) + return &resetAt + } + if codexUsagePercentExhausted(normalized.Used5hPercent) && normalized.Reset5hSeconds != nil { + resetAt := baseTime.Add(time.Duration(*normalized.Reset5hSeconds) * time.Second) + return &resetAt + } + return nil +} + +func codexRateLimitResetAtFromExtra(extra map[string]any, now time.Time) *time.Time { + if len(extra) == 0 { + return nil + } + if progress := buildCodexUsageProgressFromExtra(extra, "7d", now); progress != nil && codexUsagePercentExhausted(&progress.Utilization) && progress.ResetsAt != nil && now.Before(*progress.ResetsAt) { + resetAt := progress.ResetsAt.UTC() + return &resetAt + } + if progress := buildCodexUsageProgressFromExtra(extra, "5h", now); progress != nil && codexUsagePercentExhausted(&progress.Utilization) && progress.ResetsAt != nil && now.Before(*progress.ResetsAt) { + resetAt := progress.ResetsAt.UTC() + return &resetAt + } + return nil +} + +func applyOpenAICodexRateLimitFromExtra(account *Account, now time.Time) (*time.Time, bool) { + if account == nil || !account.IsOpenAI() { + return nil, false + } + resetAt := codexRateLimitResetAtFromExtra(account.Extra, now) + if resetAt == nil { + return nil, false + } + if account.RateLimitResetAt != nil && now.Before(*account.RateLimitResetAt) && !account.RateLimitResetAt.Before(*resetAt) { + return account.RateLimitResetAt, false + } + account.RateLimitResetAt = resetAt + return resetAt, true +} + +func syncOpenAICodexRateLimitFromExtra(ctx context.Context, repo AccountRepository, account *Account, now time.Time) *time.Time { + resetAt, changed := applyOpenAICodexRateLimitFromExtra(account, now) + if !changed || resetAt == nil || repo == nil || account == nil || account.ID <= 0 { + return resetAt + } + _ = repo.SetRateLimited(ctx, account.ID, *resetAt) + return resetAt +} + // 
updateCodexUsageSnapshot saves the Codex usage snapshot to account's Extra field func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, accountID int64, snapshot *OpenAICodexUsageSnapshot) { if snapshot == nil { @@ -3672,19 +4217,38 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc return } - updates := buildCodexUsageExtraUpdates(snapshot, time.Now()) - if len(updates) == 0 { + now := time.Now() + updates := buildCodexUsageExtraUpdates(snapshot, now) + resetAt := codexRateLimitResetAtFromSnapshot(snapshot, now) + if len(updates) == 0 && resetAt == nil { + return + } + shouldPersistUpdates := len(updates) > 0 && s.getCodexSnapshotThrottle().Allow(accountID, now) + if !shouldPersistUpdates && resetAt == nil { return } - // Update account's Extra field asynchronously go func() { updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - _ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates) + if shouldPersistUpdates { + _ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates) + } + if resetAt != nil { + _ = s.accountRepo.SetRateLimited(updateCtx, accountID, *resetAt) + } }() } +func (s *OpenAIGatewayService) UpdateCodexUsageSnapshotFromHeaders(ctx context.Context, accountID int64, headers http.Header) { + if accountID <= 0 || headers == nil { + return + } + if snapshot := ParseCodexRateLimitHeaders(headers); snapshot != nil { + s.updateCodexUsageSnapshot(ctx, accountID, snapshot) + } +} + func getOpenAIReasoningEffortFromReqBody(reqBody map[string]any) (value string, present bool) { if reqBody == nil { return "", false @@ -3743,8 +4307,8 @@ func extractOpenAIRequestMetaFromBody(body []byte) (model string, stream bool, p } // normalizeOpenAIPassthroughOAuthBody 将透传 OAuth 请求体收敛为旧链路关键行为: -// 1) store=false 2) stream=true -func normalizeOpenAIPassthroughOAuthBody(body []byte) ([]byte, bool, error) { +// 1) store=false 2) 非 compact 保持 stream=true;compact 强制 stream=false 
+func normalizeOpenAIPassthroughOAuthBody(body []byte, compact bool) ([]byte, bool, error) { if len(body) == 0 { return body, false, nil } @@ -3752,22 +4316,40 @@ func normalizeOpenAIPassthroughOAuthBody(body []byte) ([]byte, bool, error) { normalized := body changed := false - if store := gjson.GetBytes(normalized, "store"); !store.Exists() || store.Type != gjson.False { - next, err := sjson.SetBytes(normalized, "store", false) - if err != nil { - return body, false, fmt.Errorf("normalize passthrough body store=false: %w", err) + if compact { + if store := gjson.GetBytes(normalized, "store"); store.Exists() { + next, err := sjson.DeleteBytes(normalized, "store") + if err != nil { + return body, false, fmt.Errorf("normalize passthrough body delete store: %w", err) + } + normalized = next + changed = true } - normalized = next - changed = true - } - - if stream := gjson.GetBytes(normalized, "stream"); !stream.Exists() || stream.Type != gjson.True { - next, err := sjson.SetBytes(normalized, "stream", true) - if err != nil { - return body, false, fmt.Errorf("normalize passthrough body stream=true: %w", err) + if stream := gjson.GetBytes(normalized, "stream"); stream.Exists() { + next, err := sjson.DeleteBytes(normalized, "stream") + if err != nil { + return body, false, fmt.Errorf("normalize passthrough body delete stream: %w", err) + } + normalized = next + changed = true + } + } else { + if store := gjson.GetBytes(normalized, "store"); !store.Exists() || store.Type != gjson.False { + next, err := sjson.SetBytes(normalized, "store", false) + if err != nil { + return body, false, fmt.Errorf("normalize passthrough body store=false: %w", err) + } + normalized = next + changed = true + } + if stream := gjson.GetBytes(normalized, "stream"); !stream.Exists() || stream.Type != gjson.True { + next, err := sjson.SetBytes(normalized, "stream", true) + if err != nil { + return body, false, fmt.Errorf("normalize passthrough body stream=true: %w", err) + } + normalized = next + 
changed = true } - normalized = next - changed = true } return normalized, changed, nil @@ -3812,6 +4394,40 @@ func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *s return &value } +func extractOpenAIServiceTier(reqBody map[string]any) *string { + if reqBody == nil { + return nil + } + raw, ok := reqBody["service_tier"].(string) + if !ok { + return nil + } + return normalizeOpenAIServiceTier(raw) +} + +func extractOpenAIServiceTierFromBody(body []byte) *string { + if len(body) == 0 { + return nil + } + return normalizeOpenAIServiceTier(gjson.GetBytes(body, "service_tier").String()) +} + +func normalizeOpenAIServiceTier(raw string) *string { + value := strings.ToLower(strings.TrimSpace(raw)) + if value == "" { + return nil + } + if value == "fast" { + value = "priority" + } + switch value { + case "priority", "flex": + return &value + default: + return nil + } +} + func getOpenAIRequestBodyMap(c *gin.Context, body []byte) (map[string]any, error) { if c != nil { if cached, ok := c.Get(OpenAIParsedRequestBodyKey); ok { diff --git a/backend/internal/service/openai_gateway_service_codex_cli_only_test.go b/backend/internal/service/openai_gateway_service_codex_cli_only_test.go index d7c95ada..fe58e92f 100644 --- a/backend/internal/service/openai_gateway_service_codex_cli_only_test.go +++ b/backend/internal/service/openai_gateway_service_codex_cli_only_test.go @@ -211,6 +211,26 @@ func TestLogOpenAIInstructionsRequiredDebug_NonTargetErrorSkipped(t *testing.T) require.False(t, logSink.ContainsMessage("OpenAI 上游返回 Instructions are required,已记录请求详情用于排查")) } +func TestIsOpenAITransientProcessingError(t *testing.T) { + require.True(t, isOpenAITransientProcessingError( + http.StatusBadRequest, + "An error occurred while processing your request.", + nil, + )) + + require.True(t, isOpenAITransientProcessingError( + http.StatusBadRequest, + "", + []byte(`{"error":{"message":"An error occurred while processing your request. 
You can retry your request, or contact us through our help center at help.openai.com if the error persists. Please include the request ID req_123 in your message."}}`), + )) + + require.False(t, isOpenAITransientProcessingError( + http.StatusBadRequest, + "Missing required parameter: 'instructions'", + []byte(`{"error":{"message":"Missing required parameter: 'instructions'"}}`), + )) +} + func TestOpenAIGatewayService_Forward_LogsInstructionsRequiredDetails(t *testing.T) { gin.SetMode(gin.TestMode) logSink, restore := captureStructuredLog(t) @@ -264,3 +284,51 @@ func TestOpenAIGatewayService_Forward_LogsInstructionsRequiredDetails(t *testing require.True(t, logSink.ContainsField("request_body_size")) require.False(t, logSink.ContainsField("request_body_preview")) } + +func TestOpenAIGatewayService_Forward_TransientProcessingErrorTriggersFailover(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + c.Request.Header.Set("Content-Type", "application/json") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusBadRequest, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "x-request-id": []string{"rid-processing-400"}, + }, + Body: io.NopCloser(strings.NewReader(`{"error":{"message":"An error occurred while processing your request. You can retry your request, or contact us through our help center at help.openai.com if the error persists. 
Please include the request ID req_123 in your message.","type":"invalid_request_error"}}`)), + }, + } + svc := &OpenAIGatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ForceCodexCLI: false}, + }, + httpUpstream: upstream, + } + account := &Account{ + ID: 1001, + Name: "codex max套餐", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{"api_key": "sk-test"}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + body := []byte(`{"model":"gpt-5.1-codex","stream":false,"input":[{"type":"text","text":"hello"}]}`) + + _, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + + var failoverErr *UpstreamFailoverError + require.ErrorAs(t, err, &failoverErr) + require.Equal(t, http.StatusBadRequest, failoverErr.StatusCode) + require.Contains(t, string(failoverErr.ResponseBody), "An error occurred while processing your request") + require.False(t, c.Writer.Written(), "service 层应返回 failover 错误给上层换号,而不是直接向客户端写响应") +} diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 4f5f7f3c..43e2f39d 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -14,6 +14,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/cespare/xxhash/v2" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" @@ -28,6 +29,22 @@ type stubOpenAIAccountRepo struct { accounts []Account } +type snapshotUpdateAccountRepo struct { + stubOpenAIAccountRepo + updateExtraCalls chan map[string]any +} + +func (r *snapshotUpdateAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error { + if r.updateExtraCalls != nil { + copied := make(map[string]any, len(updates)) + for k, v := range updates { + copied[k] = v + } + 
r.updateExtraCalls <- copied + } + return nil +} + func (r stubOpenAIAccountRepo) GetByID(ctx context.Context, id int64) (*Account, error) { for i := range r.accounts { if r.accounts[i].ID == id { @@ -1248,8 +1265,157 @@ func TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist(t *testing.T) { } } +func TestOpenAIUpdateCodexUsageSnapshotFromHeaders(t *testing.T) { + repo := &snapshotUpdateAccountRepo{updateExtraCalls: make(chan map[string]any, 1)} + svc := &OpenAIGatewayService{accountRepo: repo} + headers := http.Header{} + headers.Set("x-codex-primary-used-percent", "12") + headers.Set("x-codex-secondary-used-percent", "34") + headers.Set("x-codex-primary-window-minutes", "300") + headers.Set("x-codex-secondary-window-minutes", "10080") + headers.Set("x-codex-primary-reset-after-seconds", "600") + headers.Set("x-codex-secondary-reset-after-seconds", "86400") + + svc.UpdateCodexUsageSnapshotFromHeaders(context.Background(), 123, headers) + + select { + case updates := <-repo.updateExtraCalls: + require.Equal(t, 12.0, updates["codex_5h_used_percent"]) + require.Equal(t, 34.0, updates["codex_7d_used_percent"]) + require.Equal(t, 600, updates["codex_5h_reset_after_seconds"]) + require.Equal(t, 86400, updates["codex_7d_reset_after_seconds"]) + case <-time.After(2 * time.Second): + t.Fatal("expected UpdateExtra to be called") + } +} + +func TestOpenAIResponsesRequestPathSuffix(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + tests := []struct { + name string + path string + want string + }{ + {name: "exact v1 responses", path: "/v1/responses", want: ""}, + {name: "compact v1 responses", path: "/v1/responses/compact", want: "/compact"}, + {name: "compact alias responses", path: "/responses/compact/", want: "/compact"}, + {name: "nested suffix", path: "/openai/v1/responses/compact/detail", want: "/compact/detail"}, + {name: "unrelated path", path: "/v1/chat/completions", want: ""}, + } + + for _, tt 
:= range tests { + t.Run(tt.name, func(t *testing.T) { + c.Request = httptest.NewRequest(http.MethodPost, tt.path, nil) + require.Equal(t, tt.want, openAIResponsesRequestPathSuffix(c)) + }) + } +} + +func TestOpenAIBuildUpstreamRequestOpenAIPassthroughPreservesCompactPath(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", bytes.NewReader([]byte(`{"model":"gpt-5"}`))) + + svc := &OpenAIGatewayService{} + account := &Account{Type: AccountTypeOAuth} + + req, err := svc.buildUpstreamRequestOpenAIPassthrough(c.Request.Context(), c, account, []byte(`{"model":"gpt-5"}`), "token") + require.NoError(t, err) + require.Equal(t, chatgptCodexURL+"/compact", req.URL.String()) + require.Equal(t, "application/json", req.Header.Get("Accept")) + require.Equal(t, codexCLIVersion, req.Header.Get("Version")) + require.NotEmpty(t, req.Header.Get("Session_Id")) +} + +func TestOpenAIBuildUpstreamRequestCompactForcesJSONAcceptForOAuth(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", bytes.NewReader([]byte(`{"model":"gpt-5"}`))) + + svc := &OpenAIGatewayService{} + account := &Account{ + Type: AccountTypeOAuth, + Credentials: map[string]any{"chatgpt_account_id": "chatgpt-acc"}, + } + + req, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte(`{"model":"gpt-5"}`), "token", false, "", true) + require.NoError(t, err) + require.Equal(t, chatgptCodexURL+"/compact", req.URL.String()) + require.Equal(t, "application/json", req.Header.Get("Accept")) + require.Equal(t, codexCLIVersion, req.Header.Get("Version")) + require.NotEmpty(t, req.Header.Get("Session_Id")) +} + +func TestOpenAIBuildUpstreamRequestPreservesCompactPathForAPIKeyBaseURL(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := 
httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/responses/compact", bytes.NewReader([]byte(`{"model":"gpt-5"}`))) + + svc := &OpenAIGatewayService{cfg: &config.Config{ + Security: config.SecurityConfig{ + URLAllowlist: config.URLAllowlistConfig{Enabled: false}, + }, + }} + account := &Account{ + Type: AccountTypeAPIKey, + Platform: PlatformOpenAI, + Credentials: map[string]any{"base_url": "https://example.com/v1"}, + } + + req, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte(`{"model":"gpt-5"}`), "token", false, "", false) + require.NoError(t, err) + require.Equal(t, "https://example.com/v1/responses/compact", req.URL.String()) +} + +func TestOpenAIBuildUpstreamRequestOAuthOfficialClientOriginatorCompatibility(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + userAgent string + originator string + wantOriginator string + }{ + {name: "desktop originator preserved", originator: "Codex Desktop", wantOriginator: "Codex Desktop"}, + {name: "vscode originator preserved", originator: "codex_vscode", wantOriginator: "codex_vscode"}, + {name: "official ua fallback to codex_cli_rs", userAgent: "Codex Desktop/1.2.3", wantOriginator: "codex_cli_rs"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader([]byte(`{"model":"gpt-5"}`))) + if tt.userAgent != "" { + c.Request.Header.Set("User-Agent", tt.userAgent) + } + if tt.originator != "" { + c.Request.Header.Set("originator", tt.originator) + } + + svc := &OpenAIGatewayService{} + account := &Account{ + Type: AccountTypeOAuth, + Credentials: map[string]any{"chatgpt_account_id": "chatgpt-acc"}, + } + + isCodexCLI := openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator")) + req, err := 
svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte(`{"model":"gpt-5"}`), "token", false, "", isCodexCLI) + require.NoError(t, err) + require.Equal(t, tt.wantOriginator, req.Header.Get("originator")) + }) + } +} + // ==================== P1-08 修复:model 替换性能优化测试 ==================== +// ==================== P1-08 修复:model 替换性能优化测试 ============= func TestReplaceModelInSSELine(t *testing.T) { svc := &OpenAIGatewayService{} @@ -1576,3 +1742,27 @@ func TestHandleOAuthSSEToJSON_NoFinalResponseKeepsSSEBody(t *testing.T) { require.Contains(t, rec.Header().Get("Content-Type"), "text/event-stream") require.Contains(t, rec.Body.String(), `data: {"type":"response.in_progress"`) } + +func TestHandleOAuthSSEToJSON_ResponseFailedReturnsProtocolError(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + svc := &OpenAIGatewayService{cfg: &config.Config{}} + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + } + body := []byte(strings.Join([]string{ + `data: {"type":"response.failed","error":{"message":"upstream rejected request"}}`, + `data: [DONE]`, + }, "\n")) + + usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o") + require.Nil(t, usage) + require.Error(t, err) + require.Equal(t, http.StatusBadGateway, rec.Code) + require.Contains(t, rec.Body.String(), "upstream rejected request") + require.Contains(t, rec.Header().Get("Content-Type"), "application/json") +} diff --git a/backend/internal/service/openai_oauth_passthrough_test.go b/backend/internal/service/openai_oauth_passthrough_test.go index 0840d3b1..6fbd2469 100644 --- a/backend/internal/service/openai_oauth_passthrough_test.go +++ b/backend/internal/service/openai_oauth_passthrough_test.go @@ -236,6 +236,60 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyNormali 
require.NotContains(t, body, "\"name\":\"edit\"") } +func TestOpenAIGatewayService_OAuthPassthrough_CompactUsesJSONAndKeepsNonStreaming(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + c.Request.Header.Set("Content-Type", "application/json") + + originalBody := []byte(`{"model":"gpt-5.1-codex","stream":true,"store":true,"instructions":"local-test-instructions","input":[{"type":"text","text":"compact me"}]}`) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid-compact"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"cmp_123","usage":{"input_tokens":11,"output_tokens":22}}`)), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + result, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.Stream) + + require.False(t, gjson.GetBytes(upstream.lastBody, "store").Exists()) + require.False(t, gjson.GetBytes(upstream.lastBody, "stream").Exists()) + require.Equal(t, "gpt-5.1-codex", gjson.GetBytes(upstream.lastBody, "model").String()) + require.Equal(t, "compact me", gjson.GetBytes(upstream.lastBody, "input.0.text").String()) + require.Equal(t, "local-test-instructions", 
strings.TrimSpace(gjson.GetBytes(upstream.lastBody, "instructions").String())) + require.Equal(t, "application/json", upstream.lastReq.Header.Get("Accept")) + require.Equal(t, codexCLIVersion, upstream.lastReq.Header.Get("Version")) + require.NotEmpty(t, upstream.lastReq.Header.Get("Session_Id")) + require.Equal(t, "chatgpt.com", upstream.lastReq.Host) + require.Equal(t, "chatgpt-acc", upstream.lastReq.Header.Get("chatgpt-account-id")) + require.Contains(t, rec.Body.String(), `"id":"cmp_123"`) +} + func TestOpenAIGatewayService_OAuthPassthrough_CodexMissingInstructionsRejectedBeforeUpstream(t *testing.T) { gin.SetMode(gin.TestMode) logSink, restore := captureStructuredLog(t) @@ -617,7 +671,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamingSetsFirstTokenMs(t *test c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") - originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`) + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"service_tier":"fast","input":[{"type":"text","text":"hi"}]}`) upstreamSSE := strings.Join([]string{ `data: {"type":"response.output_text.delta","delta":"h"}`, @@ -657,6 +711,8 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamingSetsFirstTokenMs(t *test require.GreaterOrEqual(t, time.Since(start), time.Duration(0)) require.NotNil(t, result.FirstTokenMs) require.GreaterOrEqual(t, *result.FirstTokenMs, 0) + require.NotNil(t, result.ServiceTier) + require.Equal(t, "priority", *result.ServiceTier) } func TestOpenAIGatewayService_OAuthPassthrough_StreamClientDisconnectStillCollectsUsage(t *testing.T) { @@ -723,7 +779,7 @@ func TestOpenAIGatewayService_APIKeyPassthrough_PreservesBodyAndUsesResponsesEnd c.Request.Header.Set("User-Agent", "curl/8.0") c.Request.Header.Set("X-Test", "keep") - originalBody := 
[]byte(`{"model":"gpt-5.2","stream":false,"max_output_tokens":128,"input":[{"type":"text","text":"hi"}]}`) + originalBody := []byte(`{"model":"gpt-5.2","stream":false,"service_tier":"flex","max_output_tokens":128,"input":[{"type":"text","text":"hi"}]}`) resp := &http.Response{ StatusCode: http.StatusOK, Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid"}}, @@ -749,8 +805,11 @@ func TestOpenAIGatewayService_APIKeyPassthrough_PreservesBodyAndUsesResponsesEnd RateMultiplier: f64p(1), } - _, err := svc.Forward(context.Background(), c, account, originalBody) + result, err := svc.Forward(context.Background(), c, account, originalBody) require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.ServiceTier) + require.Equal(t, "flex", *result.ServiceTier) require.NotNil(t, upstream.lastReq) require.Equal(t, originalBody, upstream.lastBody) require.Equal(t, "https://api.openai.com/v1/responses", upstream.lastReq.URL.String()) diff --git a/backend/internal/service/openai_oauth_service.go b/backend/internal/service/openai_oauth_service.go index 72f4bbb0..bd82e107 100644 --- a/backend/internal/service/openai_oauth_service.go +++ b/backend/internal/service/openai_oauth_service.go @@ -130,6 +130,7 @@ type OpenAITokenInfo struct { ChatGPTAccountID string `json:"chatgpt_account_id,omitempty"` ChatGPTUserID string `json:"chatgpt_user_id,omitempty"` OrganizationID string `json:"organization_id,omitempty"` + PlanType string `json:"plan_type,omitempty"` } // ExchangeCode exchanges authorization code for tokens @@ -202,6 +203,7 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch tokenInfo.ChatGPTAccountID = userInfo.ChatGPTAccountID tokenInfo.ChatGPTUserID = userInfo.ChatGPTUserID tokenInfo.OrganizationID = userInfo.OrganizationID + tokenInfo.PlanType = userInfo.PlanType } return tokenInfo, nil @@ -246,6 +248,7 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, 
refre tokenInfo.ChatGPTAccountID = userInfo.ChatGPTAccountID tokenInfo.ChatGPTUserID = userInfo.ChatGPTUserID tokenInfo.OrganizationID = userInfo.OrganizationID + tokenInfo.PlanType = userInfo.PlanType } return tokenInfo, nil @@ -510,6 +513,9 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo) if tokenInfo.OrganizationID != "" { creds["organization_id"] = tokenInfo.OrganizationID } + if tokenInfo.PlanType != "" { + creds["plan_type"] = tokenInfo.PlanType + } if strings.TrimSpace(tokenInfo.ClientID) != "" { creds["client_id"] = strings.TrimSpace(tokenInfo.ClientID) } diff --git a/backend/internal/service/openai_sticky_compat.go b/backend/internal/service/openai_sticky_compat.go index e897debc..fe0f1309 100644 --- a/backend/internal/service/openai_sticky_compat.go +++ b/backend/internal/service/openai_sticky_compat.go @@ -29,6 +29,13 @@ func openAIStickyCompatStats() (legacyReadFallbackTotal, legacyReadFallbackHit, openAIStickyLegacyDualWriteTotal.Load() } +// DeriveSessionHashFromSeed computes the current-format sticky-session hash +// from an arbitrary seed string. 
+func DeriveSessionHashFromSeed(seed string) string { + currentHash, _ := deriveOpenAISessionHashes(seed) + return currentHash +} + func deriveOpenAISessionHashes(sessionID string) (currentHash string, legacyHash string) { normalized := strings.TrimSpace(sessionID) if normalized == "" { diff --git a/backend/internal/service/openai_ws_account_sticky_test.go b/backend/internal/service/openai_ws_account_sticky_test.go index 3fe08179..9a8803d3 100644 --- a/backend/internal/service/openai_ws_account_sticky_test.go +++ b/backend/internal/service/openai_ws_account_sticky_test.go @@ -48,6 +48,43 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Hit(t *testing.T } } +func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_RateLimitedMiss(t *testing.T) { + ctx := context.Background() + groupID := int64(23) + rateLimitedUntil := time.Now().Add(30 * time.Minute) + account := Account{ + ID: 12, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + RateLimitResetAt: &rateLimitedUntil, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + cache := &stubGatewayCache{} + store := NewOpenAIWSStateStore(cache) + cfg := newOpenAIWSV2TestConfig() + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + openaiWSStateStore: store, + } + + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_rl", account.ID, time.Hour)) + + selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_rl", "gpt-5.1", nil) + require.NoError(t, err) + require.Nil(t, selection, "限额中的账号不应继续命中 previous_response_id 粘连") + boundAccountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_prev_rl") + require.NoError(t, getErr) + require.Zero(t, boundAccountID) +} + func 
TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded(t *testing.T) { ctx := context.Background() groupID := int64(23) diff --git a/backend/internal/service/openai_ws_client.go b/backend/internal/service/openai_ws_client.go index 9f3c47b7..80b75530 100644 --- a/backend/internal/service/openai_ws_client.go +++ b/backend/internal/service/openai_ws_client.go @@ -11,6 +11,7 @@ import ( "sync/atomic" "time" + openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2" coderws "github.com/coder/websocket" "github.com/coder/websocket/wsjson" ) @@ -234,6 +235,8 @@ type coderOpenAIWSClientConn struct { conn *coderws.Conn } +var _ openaiwsv2.FrameConn = (*coderOpenAIWSClientConn)(nil) + func (c *coderOpenAIWSClientConn) WriteJSON(ctx context.Context, value any) error { if c == nil || c.conn == nil { return errOpenAIWSConnClosed @@ -264,6 +267,30 @@ func (c *coderOpenAIWSClientConn) ReadMessage(ctx context.Context) ([]byte, erro } } +func (c *coderOpenAIWSClientConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { + if c == nil || c.conn == nil { + return coderws.MessageText, nil, errOpenAIWSConnClosed + } + if ctx == nil { + ctx = context.Background() + } + msgType, payload, err := c.conn.Read(ctx) + if err != nil { + return coderws.MessageText, nil, err + } + return msgType, payload, nil +} + +func (c *coderOpenAIWSClientConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error { + if c == nil || c.conn == nil { + return errOpenAIWSConnClosed + } + if ctx == nil { + ctx = context.Background() + } + return c.conn.Write(ctx, msgType, payload) +} + func (c *coderOpenAIWSClientConn) Ping(ctx context.Context) error { if c == nil || c.conn == nil { return errOpenAIWSConnClosed diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index 74ba472f..52bb8590 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ 
b/backend/internal/service/openai_ws_forwarder.go @@ -46,9 +46,10 @@ const ( openAIWSPayloadSizeEstimateMaxBytes = 64 * 1024 openAIWSPayloadSizeEstimateMaxItems = 16 - openAIWSEventFlushBatchSizeDefault = 4 - openAIWSEventFlushIntervalDefault = 25 * time.Millisecond - openAIWSPayloadLogSampleDefault = 0.2 + openAIWSEventFlushBatchSizeDefault = 4 + openAIWSEventFlushIntervalDefault = 25 * time.Millisecond + openAIWSPayloadLogSampleDefault = 0.2 + openAIWSPassthroughIdleTimeoutDefault = time.Hour openAIWSStoreDisabledConnModeStrict = "strict" openAIWSStoreDisabledConnModeAdaptive = "adaptive" @@ -863,7 +864,8 @@ func isOpenAIWSClientDisconnectError(err error) bool { strings.Contains(message, "unexpected eof") || strings.Contains(message, "use of closed network connection") || strings.Contains(message, "connection reset by peer") || - strings.Contains(message, "broken pipe") + strings.Contains(message, "broken pipe") || + strings.Contains(message, "an established connection was aborted") } func classifyOpenAIWSReadFallbackReason(err error) string { @@ -904,6 +906,18 @@ func (s *OpenAIGatewayService) getOpenAIWSConnPool() *openAIWSConnPool { return s.openaiWSPool } +func (s *OpenAIGatewayService) getOpenAIWSPassthroughDialer() openAIWSClientDialer { + if s == nil { + return nil + } + s.openaiWSPassthroughDialerOnce.Do(func() { + if s.openaiWSPassthroughDialer == nil { + s.openaiWSPassthroughDialer = newDefaultOpenAIWSClientDialer() + } + }) + return s.openaiWSPassthroughDialer +} + func (s *OpenAIGatewayService) SnapshotOpenAIWSPoolMetrics() OpenAIWSPoolMetricsSnapshot { pool := s.getOpenAIWSConnPool() if pool == nil { @@ -967,6 +981,13 @@ func (s *OpenAIGatewayService) openAIWSReadTimeout() time.Duration { return 15 * time.Minute } +func (s *OpenAIGatewayService) openAIWSPassthroughIdleTimeout() time.Duration { + if timeout := s.openAIWSReadTimeout(); timeout > 0 { + return timeout + } + return openAIWSPassthroughIdleTimeoutDefault +} + func (s *OpenAIGatewayService) 
openAIWSWriteTimeout() time.Duration { if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds > 0 { return time.Duration(s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds) * time.Second @@ -1120,11 +1141,7 @@ func (s *OpenAIGatewayService) buildOpenAIWSHeaders( if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" { headers.Set("chatgpt-account-id", chatgptAccountID) } - if isCodexCLI { - headers.Set("originator", "codex_cli_rs") - } else { - headers.Set("originator", "opencode") - } + headers.Set("originator", resolveOpenAIUpstreamOriginator(c, isCodexCLI)) } betaValue := openAIWSBetaV2Value @@ -1836,6 +1853,10 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( wsPath, account.ProxyID != nil && account.Proxy != nil, ) + var dialErr *openAIWSDialError + if errors.As(err, &dialErr) && dialErr != nil && dialErr.StatusCode == http.StatusTooManyRequests { + s.persistOpenAIWSRateLimitSignal(ctx, account, dialErr.ResponseHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(err.Error())) + } return nil, wrapOpenAIWSFallback(classifyOpenAIWSAcquireError(err), err) } defer lease.Release() @@ -2119,6 +2140,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( if eventType == "error" { errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message) + s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), message, errCodeRaw, errTypeRaw, errMsgRaw) errMsg := strings.TrimSpace(errMsgRaw) if errMsg == "" { errMsg = "Upstream websocket error" @@ -2285,9 +2307,11 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( RequestID: responseID, Usage: *usage, Model: originalModel, + ServiceTier: extractOpenAIServiceTier(reqBody), ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel), Stream: reqStream, OpenAIWSMode: true, + ResponseHeaders: lease.HandshakeHeaders(), Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, }, nil @@ -2322,7 +2346,7 @@ func (s 
*OpenAIGatewayService) ProxyResponsesWebSocketFromClient( wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account) modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled - ingressMode := OpenAIWSIngressModeShared + ingressMode := OpenAIWSIngressModeCtxPool if modeRouterV2Enabled { ingressMode = account.ResolveOpenAIResponsesWebSocketV2Mode(s.cfg.Gateway.OpenAIWS.IngressModeDefault) if ingressMode == OpenAIWSIngressModeOff { @@ -2332,6 +2356,30 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( nil, ) } + switch ingressMode { + case OpenAIWSIngressModePassthrough: + if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { + return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport) + } + return s.proxyResponsesWebSocketV2Passthrough( + ctx, + c, + clientConn, + account, + token, + firstClientMessage, + hooks, + wsDecision, + ) + case OpenAIWSIngressModeCtxPool, OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated: + // continue + default: + return NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "websocket mode only supports ctx_pool/passthrough", + nil, + ) + } } if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport) @@ -2497,7 +2545,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( } } - isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) + isCodexCLI := openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) wsHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), firstPayload.promptCacheKey) baseAcquireReq := openAIWSAcquireRequest{ Account: account, @@ -2597,6 
+2645,10 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( wsPath, account.ProxyID != nil && account.Proxy != nil, ) + var dialErr *openAIWSDialError + if errors.As(acquireErr, &dialErr) && dialErr != nil && dialErr.StatusCode == http.StatusTooManyRequests { + s.persistOpenAIWSRateLimitSignal(ctx, account, dialErr.ResponseHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(acquireErr.Error())) + } if errors.Is(acquireErr, errOpenAIWSPreferredConnUnavailable) { return nil, NewOpenAIWSClientCloseError( coderws.StatusPolicyViolation, @@ -2735,6 +2787,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( } if eventType == "error" { errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(upstreamMessage) + s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), upstreamMessage, errCodeRaw, errTypeRaw, errMsgRaw) fallbackReason, _ := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) recoverablePrevNotFound := fallbackReason == openAIWSIngressStagePreviousResponseNotFound && @@ -2871,9 +2924,11 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( RequestID: responseID, Usage: usage, Model: originalModel, + ServiceTier: extractOpenAIServiceTierFromBody(payload), ReasoningEffort: extractOpenAIReasoningEffortFromBody(payload, originalModel), Stream: reqStream, OpenAIWSMode: true, + ResponseHeaders: lease.HandshakeHeaders(), Duration: time.Since(turnStart), FirstTokenMs: firstTokenMs, }, nil @@ -3561,6 +3616,7 @@ func (s *OpenAIGatewayService) performOpenAIWSGeneratePrewarm( if eventType == "error" { errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message) + s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), message, errCodeRaw, errTypeRaw, errMsgRaw) errMsg := strings.TrimSpace(errMsgRaw) if errMsg == "" { 
errMsg = "OpenAI websocket prewarm error" @@ -3755,7 +3811,7 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID( if s.getOpenAIWSProtocolResolver().Resolve(account).Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { return nil, nil } - if shouldClearStickySession(account, requestedModel) || !account.IsOpenAI() { + if shouldClearStickySession(account, requestedModel) || !account.IsOpenAI() || !account.IsSchedulable() { _ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID) return nil, nil } @@ -3824,6 +3880,36 @@ func classifyOpenAIWSAcquireError(err error) string { return "acquire_conn" } +func isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw string) bool { + code := strings.ToLower(strings.TrimSpace(codeRaw)) + errType := strings.ToLower(strings.TrimSpace(errTypeRaw)) + msg := strings.ToLower(strings.TrimSpace(msgRaw)) + + if strings.Contains(errType, "rate_limit") || strings.Contains(errType, "usage_limit") { + return true + } + if strings.Contains(code, "rate_limit") || strings.Contains(code, "usage_limit") || strings.Contains(code, "insufficient_quota") { + return true + } + if strings.Contains(msg, "usage limit") && strings.Contains(msg, "reached") { + return true + } + if strings.Contains(msg, "rate limit") && (strings.Contains(msg, "reached") || strings.Contains(msg, "exceeded")) { + return true + } + return false +} + +func (s *OpenAIGatewayService) persistOpenAIWSRateLimitSignal(ctx context.Context, account *Account, headers http.Header, responseBody []byte, codeRaw, errTypeRaw, msgRaw string) { + if s == nil || s.rateLimitService == nil || account == nil || account.Platform != PlatformOpenAI { + return + } + if !isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) { + return + } + s.rateLimitService.HandleUpstreamError(ctx, account, http.StatusTooManyRequests, headers, responseBody) +} + func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) { code := 
strings.ToLower(strings.TrimSpace(codeRaw)) errType := strings.ToLower(strings.TrimSpace(errTypeRaw)) @@ -3839,6 +3925,9 @@ func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (stri case "previous_response_not_found": return "previous_response_not_found", true } + if isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) { + return "upstream_rate_limited", false + } if strings.Contains(msg, "upgrade required") || strings.Contains(msg, "status 426") { return "upgrade_required", true } @@ -3884,9 +3973,7 @@ func openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw string) int { case strings.Contains(errType, "permission"), strings.Contains(code, "forbidden"): return http.StatusForbidden - case strings.Contains(errType, "rate_limit"), - strings.Contains(code, "rate_limit"), - strings.Contains(code, "insufficient_quota"): + case isOpenAIWSRateLimitError(codeRaw, errTypeRaw, ""): return http.StatusTooManyRequests default: return http.StatusBadGateway diff --git a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go index 5a3c12c3..c527f2eb 100644 --- a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go +++ b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go @@ -149,7 +149,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_KeepLeaseAcrossT require.True(t, <-turnWSModeCh, "首轮 turn 应标记为 WS 模式") require.True(t, <-turnWSModeCh, "第二轮 turn 应标记为 WS 模式") - require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + _ = clientConn.Close(coderws.StatusNormalClosure, "done") select { case serverErr := <-serverErrCh: @@ -298,6 +298,142 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_DedicatedModeDoe require.Equal(t, 2, dialer.DialCount(), "dedicated 模式下跨客户端会话不应复用上游连接") } +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeRelaysByCaddyAdapter(t *testing.T) { + 
gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + upstreamConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_passthrough_turn_1","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":3}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: upstreamConn} + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPassthroughDialer: captureDialer, + } + + account := &Account{ + ID: 452, + Name: "openai-ingress-passthrough", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough, + }, + } + + serverErrCh := make(chan error, 1) + resultCh := make(chan *OpenAIForwardResult, 1) + hooks := &OpenAIWSIngressHooks{ + AfterTurn: func(_ int, result *OpenAIForwardResult, turnErr error) { + if turnErr == nil && result != nil { + resultCh <- result + } + }, + } + + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { 
+ serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, hooks) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false,"service_tier":"fast"}`)) + cancelWrite() + require.NoError(t, err) + + readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + _, event, readErr := clientConn.Read(readCtx) + cancelRead() + require.NoError(t, readErr) + require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String()) + require.Equal(t, "resp_passthrough_turn_1", gjson.GetBytes(event, "response.id").String()) + _ = clientConn.Close(coderws.StatusNormalClosure, "done") + + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 passthrough websocket 结束超时") + } + + select { + case result := <-resultCh: + require.Equal(t, 
"resp_passthrough_turn_1", result.RequestID) + require.True(t, result.OpenAIWSMode) + require.Equal(t, 2, result.Usage.InputTokens) + require.Equal(t, 3, result.Usage.OutputTokens) + require.NotNil(t, result.ServiceTier) + require.Equal(t, "priority", *result.ServiceTier) + case <-time.After(2 * time.Second): + t.Fatal("未收到 passthrough turn 结果回调") + } + + require.Equal(t, 1, captureDialer.DialCount(), "passthrough 模式应直接建立上游 websocket") + require.Len(t, upstreamConn.writes, 1, "passthrough 模式应透传首条 response.create") +} + func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeOffReturnsPolicyViolation(t *testing.T) { gin.SetMode(gin.TestMode) @@ -2459,7 +2595,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ClientDisconnect require.NoError(t, err) writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) - err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"custom-original-model","stream":false}`)) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"custom-original-model","stream":false,"service_tier":"flex"}`)) cancelWrite() require.NoError(t, err) // 立即关闭客户端,模拟客户端在 relay 期间断连。 @@ -2477,6 +2613,8 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ClientDisconnect require.Equal(t, "resp_ingress_disconnect", result.RequestID) require.Equal(t, 2, result.Usage.InputTokens) require.Equal(t, 1, result.Usage.OutputTokens) + require.NotNil(t, result.ServiceTier) + require.Equal(t, "flex", *result.ServiceTier) case <-time.After(2 * time.Second): t.Fatal("未收到断连后的 turn 结果回调") } diff --git a/backend/internal/service/openai_ws_forwarder_success_test.go b/backend/internal/service/openai_ws_forwarder_success_test.go index 592801f6..912fade9 100644 --- a/backend/internal/service/openai_ws_forwarder_success_test.go +++ b/backend/internal/service/openai_ws_forwarder_success_test.go @@ -15,6 +15,7 @@ import ( "time" 
"github.com/Wei-Shaw/sub2api/internal/config" + coderws "github.com/coder/websocket" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "github.com/stretchr/testify/require" @@ -457,6 +458,86 @@ func TestOpenAIGatewayService_Forward_WSv2_OAuthStoreFalseByDefault(t *testing.T require.Equal(t, "conv-oauth-1", captureDialer.lastHeaders.Get("conversation_id")) } +func TestOpenAIGatewayService_Forward_WSv2_OAuthOriginatorCompatibility(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + userAgent string + originator string + wantOriginator string + }{ + {name: "desktop originator preserved", originator: "Codex Desktop", wantOriginator: "Codex Desktop"}, + {name: "vscode originator preserved", originator: "codex_vscode", wantOriginator: "codex_vscode"}, + {name: "official ua fallback to codex_cli_rs", userAgent: "Codex Desktop/1.2.3", wantOriginator: "codex_cli_rs"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + if tt.userAgent != "" { + c.Request.Header.Set("User-Agent", tt.userAgent) + } + if tt.originator != "" { + c.Request.Header.Set("originator", tt.originator) + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.AllowStoreRecovery = false + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_oauth_originator","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, 
+ } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + account := &Account{ + ID: 129, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token-1", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, tt.wantOriginator, captureDialer.lastHeaders.Get("originator")) + }) + } +} + func TestOpenAIGatewayService_Forward_WSv2_HeaderSessionFallbackFromPromptCacheKey(t *testing.T) { gin.SetMode(gin.TestMode) @@ -1282,6 +1363,18 @@ func (c *openAIWSCaptureConn) ReadMessage(ctx context.Context) ([]byte, error) { return event, nil } +func (c *openAIWSCaptureConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { + payload, err := c.ReadMessage(ctx) + if err != nil { + return coderws.MessageText, nil, err + } + return coderws.MessageText, payload, nil +} + +func (c *openAIWSCaptureConn) WriteFrame(ctx context.Context, _ coderws.MessageType, payload []byte) error { + return c.WriteJSON(ctx, json.RawMessage(payload)) +} + func (c *openAIWSCaptureConn) Ping(ctx context.Context) error { _ = ctx return nil diff --git a/backend/internal/service/openai_ws_pool.go b/backend/internal/service/openai_ws_pool.go index db6a96a7..5950e028 100644 --- a/backend/internal/service/openai_ws_pool.go +++ 
b/backend/internal/service/openai_ws_pool.go @@ -126,6 +126,13 @@ func (l *openAIWSConnLease) HandshakeHeader(name string) string { return l.conn.handshakeHeader(name) } +func (l *openAIWSConnLease) HandshakeHeaders() http.Header { + if l == nil || l.conn == nil { + return nil + } + return cloneHeader(l.conn.handshakeHeaders) +} + func (l *openAIWSConnLease) IsPrewarmed() bool { if l == nil || l.conn == nil { return false diff --git a/backend/internal/service/openai_ws_protocol_forward_test.go b/backend/internal/service/openai_ws_protocol_forward_test.go index df4d4871..7295b13d 100644 --- a/backend/internal/service/openai_ws_protocol_forward_test.go +++ b/backend/internal/service/openai_ws_protocol_forward_test.go @@ -391,6 +391,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) { nil, nil, nil, + nil, cfg, nil, nil, diff --git a/backend/internal/service/openai_ws_protocol_resolver.go b/backend/internal/service/openai_ws_protocol_resolver.go index 368643be..7266759c 100644 --- a/backend/internal/service/openai_ws_protocol_resolver.go +++ b/backend/internal/service/openai_ws_protocol_resolver.go @@ -69,8 +69,11 @@ func (r *defaultOpenAIWSProtocolResolver) Resolve(account *Account) OpenAIWSProt switch mode { case OpenAIWSIngressModeOff: return openAIWSHTTPDecision("account_mode_off") - case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated: + case OpenAIWSIngressModeCtxPool, OpenAIWSIngressModePassthrough: // continue + case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated: + // 历史值兼容:按 ctx_pool 处理。 + mode = OpenAIWSIngressModeCtxPool default: return openAIWSHTTPDecision("account_mode_off") } diff --git a/backend/internal/service/openai_ws_protocol_resolver_test.go b/backend/internal/service/openai_ws_protocol_resolver_test.go index 5be76e28..4d5dc5f1 100644 --- a/backend/internal/service/openai_ws_protocol_resolver_test.go +++ b/backend/internal/service/openai_ws_protocol_resolver_test.go @@ -143,21 +143,21 @@ func 
TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) { cfg.Gateway.OpenAIWS.APIKeyEnabled = true cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true - cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared + cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool account := &Account{ Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 1, Extra: map[string]any{ - "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated, + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool, }, } - t.Run("dedicated mode routes to ws v2", func(t *testing.T) { + t.Run("ctx_pool mode routes to ws v2", func(t *testing.T) { decision := NewOpenAIWSProtocolResolver(cfg).Resolve(account) require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport) - require.Equal(t, "ws_v2_mode_dedicated", decision.Reason) + require.Equal(t, "ws_v2_mode_ctx_pool", decision.Reason) }) t.Run("off mode routes to http", func(t *testing.T) { @@ -174,7 +174,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) { require.Equal(t, "account_mode_off", decision.Reason) }) - t.Run("legacy boolean maps to shared in v2 router", func(t *testing.T) { + t.Run("legacy boolean maps to ctx_pool in v2 router", func(t *testing.T) { legacyAccount := &Account{ Platform: PlatformOpenAI, Type: AccountTypeAPIKey, @@ -185,7 +185,21 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) { } decision := NewOpenAIWSProtocolResolver(cfg).Resolve(legacyAccount) require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport) - require.Equal(t, "ws_v2_mode_shared", decision.Reason) + require.Equal(t, "ws_v2_mode_ctx_pool", decision.Reason) + }) + + t.Run("passthrough mode routes to ws v2", func(t *testing.T) { + passthroughAccount := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Extra: map[string]any{ 
+ "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough, + }, + } + decision := NewOpenAIWSProtocolResolver(cfg).Resolve(passthroughAccount) + require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport) + require.Equal(t, "ws_v2_mode_passthrough", decision.Reason) }) t.Run("non-positive concurrency is rejected in v2 router", func(t *testing.T) { @@ -193,7 +207,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) { Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{ - "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared, + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool, }, } decision := NewOpenAIWSProtocolResolver(cfg).Resolve(invalidConcurrency) diff --git a/backend/internal/service/openai_ws_ratelimit_signal_test.go b/backend/internal/service/openai_ws_ratelimit_signal_test.go new file mode 100644 index 00000000..f5c79923 --- /dev/null +++ b/backend/internal/service/openai_ws_ratelimit_signal_test.go @@ -0,0 +1,511 @@ +package service + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + coderws "github.com/coder/websocket" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/require" +) + +type openAIWSRateLimitSignalRepo struct { + stubOpenAIAccountRepo + rateLimitCalls []time.Time + updateExtra []map[string]any +} + +type openAICodexSnapshotAsyncRepo struct { + stubOpenAIAccountRepo + updateExtraCh chan map[string]any + rateLimitCh chan time.Time +} + +type openAICodexExtraListRepo struct { + stubOpenAIAccountRepo + rateLimitCh chan time.Time +} + +func (r *openAIWSRateLimitSignalRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error { + r.rateLimitCalls = append(r.rateLimitCalls, resetAt) + return nil +} + +func (r *openAIWSRateLimitSignalRepo) UpdateExtra(_ 
context.Context, _ int64, updates map[string]any) error { + copied := make(map[string]any, len(updates)) + for k, v := range updates { + copied[k] = v + } + r.updateExtra = append(r.updateExtra, copied) + return nil +} + +func (r *openAICodexSnapshotAsyncRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error { + if r.rateLimitCh != nil { + r.rateLimitCh <- resetAt + } + return nil +} + +func (r *openAICodexSnapshotAsyncRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error { + if r.updateExtraCh != nil { + copied := make(map[string]any, len(updates)) + for k, v := range updates { + copied[k] = v + } + r.updateExtraCh <- copied + } + return nil +} + +func (r *openAICodexExtraListRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error { + if r.rateLimitCh != nil { + r.rateLimitCh <- resetAt + } + return nil +} + +func (r *openAICodexExtraListRepo) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { + _ = platform + _ = accountType + _ = status + _ = search + _ = groupID + return r.accounts, &pagination.PaginationResult{Total: int64(len(r.accounts)), Page: params.Page, PageSize: params.PageSize}, nil +} + +func TestOpenAIGatewayService_Forward_WSv2ErrorEventUsageLimitPersistsRateLimit(t *testing.T) { + gin.SetMode(gin.TestMode) + + resetAt := time.Now().Add(2 * time.Hour).Unix() + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { _ = conn.Close() }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + _ = 
conn.WriteJSON(map[string]any{ + "type": "error", + "error": map[string]any{ + "code": "rate_limit_exceeded", + "type": "usage_limit_reached", + "message": "The usage limit has been reached", + "resets_at": resetAt, + }, + }) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "unit-test-agent/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_should_not_run"}`)), + }, + } + + cfg := newOpenAIWSV2TestConfig() + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + + account := Account{ + ID: 501, + Name: "openai-ws-rate-limit-event", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + repo := &openAIWSRateLimitSignalRepo{stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}} + rateSvc := &RateLimitService{accountRepo: repo} + svc := &OpenAIGatewayService{ + accountRepo: repo, + rateLimitService: rateSvc, + httpUpstream: upstream, + cache: &stubGatewayCache{}, + cfg: cfg, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, &account, body) + require.Error(t, err) + require.Nil(t, result) + require.Equal(t, http.StatusTooManyRequests, rec.Code) + require.Nil(t, upstream.lastReq, "WS 限流 error event 不应回退到同账号 HTTP") + require.Len(t, 
repo.rateLimitCalls, 1) + require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second) +} + +func TestOpenAIGatewayService_Forward_WSv2Handshake429PersistsRateLimit(t *testing.T) { + gin.SetMode(gin.TestMode) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("x-codex-primary-used-percent", "100") + w.Header().Set("x-codex-primary-reset-after-seconds", "7200") + w.Header().Set("x-codex-primary-window-minutes", "10080") + w.Header().Set("x-codex-secondary-used-percent", "3") + w.Header().Set("x-codex-secondary-reset-after-seconds", "1800") + w.Header().Set("x-codex-secondary-window-minutes", "300") + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte(`{"error":{"type":"rate_limit_exceeded","message":"rate limited"}}`)) + })) + defer server.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "unit-test-agent/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_should_not_run"}`)), + }, + } + + cfg := newOpenAIWSV2TestConfig() + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + + account := Account{ + ID: 502, + Name: "openai-ws-rate-limit-handshake", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": server.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + repo := &openAIWSRateLimitSignalRepo{stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}} + rateSvc := &RateLimitService{accountRepo: repo} + svc := 
&OpenAIGatewayService{ + accountRepo: repo, + rateLimitService: rateSvc, + httpUpstream: upstream, + cache: &stubGatewayCache{}, + cfg: cfg, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, &account, body) + require.Error(t, err) + require.Nil(t, result) + require.Equal(t, http.StatusTooManyRequests, rec.Code) + require.Nil(t, upstream.lastReq, "WS 握手 429 不应回退到同账号 HTTP") + require.Len(t, repo.rateLimitCalls, 1) + require.NotEmpty(t, repo.updateExtra, "握手 429 的 x-codex 头应立即落库") + require.Contains(t, repo.updateExtra[0], "codex_usage_updated_at") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ErrorEventUsageLimitPersistsRateLimit(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := newOpenAIWSV2TestConfig() + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + resetAt := time.Now().Add(90 * time.Minute).Unix() + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"error","error":{"code":"rate_limit_exceeded","type":"usage_limit_reached","message":"The usage limit has been reached","resets_at":PLACEHOLDER}}`), + }, + } + captureConn.events[0] = []byte(strings.ReplaceAll(string(captureConn.events[0]), "PLACEHOLDER", strconv.FormatInt(resetAt, 10))) + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + account := Account{ + ID: 503, + Name: "openai-ingress-rate-limit", 
+ Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + repo := &openAIWSRateLimitSignalRepo{stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}} + rateSvc := &RateLimitService{accountRepo: repo} + svc := &OpenAIGatewayService{ + accountRepo: repo, + rateLimitService: rateSvc, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + cfg: cfg, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{CompressionMode: coderws.CompressionContextTakeover}) + if err != nil { + serverErrCh <- err + return + } + defer func() { _ = conn.CloseNow() }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- io.ErrUnexpectedEOF + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, &account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { _ = 
clientConn.CloseNow() }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`)) + cancelWrite() + require.NoError(t, err) + + select { + case serverErr := <-serverErrCh: + require.Error(t, serverErr) + require.Len(t, repo.rateLimitCalls, 1) + require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } +} + +func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ExhaustedSnapshotSetsRateLimit(t *testing.T) { + repo := &openAICodexSnapshotAsyncRepo{ + updateExtraCh: make(chan map[string]any, 1), + rateLimitCh: make(chan time.Time, 1), + } + svc := &OpenAIGatewayService{accountRepo: repo} + snapshot := &OpenAICodexUsageSnapshot{ + PrimaryUsedPercent: ptrFloat64WS(100), + PrimaryResetAfterSeconds: ptrIntWS(3600), + PrimaryWindowMinutes: ptrIntWS(10080), + SecondaryUsedPercent: ptrFloat64WS(12), + SecondaryResetAfterSeconds: ptrIntWS(1200), + SecondaryWindowMinutes: ptrIntWS(300), + } + before := time.Now() + svc.updateCodexUsageSnapshot(context.Background(), 601, snapshot) + + select { + case updates := <-repo.updateExtraCh: + require.Equal(t, 100.0, updates["codex_7d_used_percent"]) + case <-time.After(2 * time.Second): + t.Fatal("等待 codex 快照落库超时") + } + + select { + case resetAt := <-repo.rateLimitCh: + require.WithinDuration(t, before.Add(time.Hour), resetAt, 2*time.Second) + case <-time.After(2 * time.Second): + t.Fatal("等待 codex 100% 自动切换限流超时") + } +} + +func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_NonExhaustedSnapshotDoesNotSetRateLimit(t *testing.T) { + repo := &openAICodexSnapshotAsyncRepo{ + updateExtraCh: make(chan map[string]any, 1), + rateLimitCh: make(chan time.Time, 1), + } + svc := &OpenAIGatewayService{accountRepo: repo} + snapshot := &OpenAICodexUsageSnapshot{ + 
PrimaryUsedPercent: ptrFloat64WS(94), + PrimaryResetAfterSeconds: ptrIntWS(3600), + PrimaryWindowMinutes: ptrIntWS(10080), + SecondaryUsedPercent: ptrFloat64WS(22), + SecondaryResetAfterSeconds: ptrIntWS(1200), + SecondaryWindowMinutes: ptrIntWS(300), + } + svc.updateCodexUsageSnapshot(context.Background(), 602, snapshot) + + select { + case <-repo.updateExtraCh: + case <-time.After(2 * time.Second): + t.Fatal("等待 codex 快照落库超时") + } + + select { + case resetAt := <-repo.rateLimitCh: + t.Fatalf("unexpected rate limit reset at: %v", resetAt) + case <-time.After(200 * time.Millisecond): + } +} + +func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ThrottlesExtraWrites(t *testing.T) { + repo := &openAICodexSnapshotAsyncRepo{ + updateExtraCh: make(chan map[string]any, 2), + rateLimitCh: make(chan time.Time, 2), + } + svc := &OpenAIGatewayService{ + accountRepo: repo, + codexSnapshotThrottle: newAccountWriteThrottle(time.Hour), + } + snapshot := &OpenAICodexUsageSnapshot{ + PrimaryUsedPercent: ptrFloat64WS(94), + PrimaryResetAfterSeconds: ptrIntWS(3600), + PrimaryWindowMinutes: ptrIntWS(10080), + SecondaryUsedPercent: ptrFloat64WS(22), + SecondaryResetAfterSeconds: ptrIntWS(1200), + SecondaryWindowMinutes: ptrIntWS(300), + } + + svc.updateCodexUsageSnapshot(context.Background(), 777, snapshot) + svc.updateCodexUsageSnapshot(context.Background(), 777, snapshot) + + select { + case <-repo.updateExtraCh: + case <-time.After(2 * time.Second): + t.Fatal("等待第一次 codex 快照落库超时") + } + + select { + case updates := <-repo.updateExtraCh: + t.Fatalf("unexpected second codex snapshot write: %v", updates) + case <-time.After(200 * time.Millisecond): + } +} + +func ptrFloat64WS(v float64) *float64 { return &v } +func ptrIntWS(v int) *int { return &v } + +func TestOpenAIGatewayService_GetSchedulableAccount_ExhaustedCodexExtraSetsRateLimit(t *testing.T) { + resetAt := time.Now().Add(6 * 24 * time.Hour) + account := Account{ + ID: 701, + Platform: PlatformOpenAI, + Type: 
AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Extra: map[string]any{ + "codex_7d_used_percent": 100.0, + "codex_7d_reset_at": resetAt.UTC().Format(time.RFC3339), + }, + } + repo := &openAICodexExtraListRepo{stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, rateLimitCh: make(chan time.Time, 1)} + svc := &OpenAIGatewayService{accountRepo: repo} + + fresh, err := svc.getSchedulableAccount(context.Background(), account.ID) + require.NoError(t, err) + require.NotNil(t, fresh) + require.NotNil(t, fresh.RateLimitResetAt) + require.WithinDuration(t, resetAt.UTC(), *fresh.RateLimitResetAt, time.Second) + select { + case persisted := <-repo.rateLimitCh: + require.WithinDuration(t, resetAt.UTC(), persisted, time.Second) + case <-time.After(2 * time.Second): + t.Fatal("等待旧快照补写限流状态超时") + } +} + +func TestAdminService_ListAccounts_ExhaustedCodexExtraReturnsRateLimitedAccount(t *testing.T) { + resetAt := time.Now().Add(4 * 24 * time.Hour) + repo := &openAICodexExtraListRepo{ + stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{{ + ID: 702, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Extra: map[string]any{ + "codex_7d_used_percent": 100.0, + "codex_7d_reset_at": resetAt.UTC().Format(time.RFC3339), + }, + }}}, + rateLimitCh: make(chan time.Time, 1), + } + svc := &adminServiceImpl{accountRepo: repo} + + accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, "", "", 0) + require.NoError(t, err) + require.Equal(t, int64(1), total) + require.Len(t, accounts, 1) + require.NotNil(t, accounts[0].RateLimitResetAt) + require.WithinDuration(t, resetAt.UTC(), *accounts[0].RateLimitResetAt, time.Second) + select { + case persisted := <-repo.rateLimitCh: + require.WithinDuration(t, resetAt.UTC(), persisted, time.Second) + case <-time.After(2 * time.Second): + t.Fatal("等待列表补写限流状态超时") + } +} + +func 
TestOpenAIWSErrorHTTPStatusFromRaw_UsageLimitReachedIs429(t *testing.T) { + require.Equal(t, http.StatusTooManyRequests, openAIWSErrorHTTPStatusFromRaw("", "usage_limit_reached")) + require.Equal(t, http.StatusTooManyRequests, openAIWSErrorHTTPStatusFromRaw("rate_limit_exceeded", "")) +} diff --git a/backend/internal/service/openai_ws_v2/caddy_adapter.go b/backend/internal/service/openai_ws_v2/caddy_adapter.go new file mode 100644 index 00000000..1fecc231 --- /dev/null +++ b/backend/internal/service/openai_ws_v2/caddy_adapter.go @@ -0,0 +1,24 @@ +package openai_ws_v2 + +import ( + "context" +) + +// runCaddyStyleRelay 采用 Caddy reverseproxy 的双向隧道思想: +// 连接建立后并发复制两个方向,任一方向退出触发收敛关闭。 +// +// Reference: +// - Project: caddyserver/caddy (Apache-2.0) +// - Commit: f283062d37c50627d53ca682ebae2ce219b35515 +// - Files: +// - modules/caddyhttp/reverseproxy/streaming.go +// - modules/caddyhttp/reverseproxy/reverseproxy.go +func runCaddyStyleRelay( + ctx context.Context, + clientConn FrameConn, + upstreamConn FrameConn, + firstClientMessage []byte, + options RelayOptions, +) (RelayResult, *RelayExit) { + return Relay(ctx, clientConn, upstreamConn, firstClientMessage, options) +} diff --git a/backend/internal/service/openai_ws_v2/entry.go b/backend/internal/service/openai_ws_v2/entry.go new file mode 100644 index 00000000..176298fe --- /dev/null +++ b/backend/internal/service/openai_ws_v2/entry.go @@ -0,0 +1,23 @@ +package openai_ws_v2 + +import "context" + +// EntryInput 是 passthrough v2 数据面的入口参数。 +type EntryInput struct { + Ctx context.Context + ClientConn FrameConn + UpstreamConn FrameConn + FirstClientMessage []byte + Options RelayOptions +} + +// RunEntry 是 openai_ws_v2 包对外的统一入口。 +func RunEntry(input EntryInput) (RelayResult, *RelayExit) { + return runCaddyStyleRelay( + input.Ctx, + input.ClientConn, + input.UpstreamConn, + input.FirstClientMessage, + input.Options, + ) +} diff --git a/backend/internal/service/openai_ws_v2/metrics.go 
b/backend/internal/service/openai_ws_v2/metrics.go new file mode 100644 index 00000000..3708befd --- /dev/null +++ b/backend/internal/service/openai_ws_v2/metrics.go @@ -0,0 +1,29 @@ +package openai_ws_v2 + +import ( + "sync/atomic" +) + +// MetricsSnapshot 是 OpenAI WS v2 passthrough 路径的轻量运行时指标快照。 +type MetricsSnapshot struct { + SemanticMutationTotal int64 `json:"semantic_mutation_total"` + UsageParseFailureTotal int64 `json:"usage_parse_failure_total"` +} + +var ( + // passthrough 路径默认不会做语义改写,该计数通常应保持为 0(保留用于未来防御性校验)。 + passthroughSemanticMutationTotal atomic.Int64 + passthroughUsageParseFailureTotal atomic.Int64 +) + +func recordUsageParseFailure() { + passthroughUsageParseFailureTotal.Add(1) +} + +// SnapshotMetrics 返回当前 passthrough 指标快照。 +func SnapshotMetrics() MetricsSnapshot { + return MetricsSnapshot{ + SemanticMutationTotal: passthroughSemanticMutationTotal.Load(), + UsageParseFailureTotal: passthroughUsageParseFailureTotal.Load(), + } +} diff --git a/backend/internal/service/openai_ws_v2/passthrough_relay.go b/backend/internal/service/openai_ws_v2/passthrough_relay.go new file mode 100644 index 00000000..af8ee195 --- /dev/null +++ b/backend/internal/service/openai_ws_v2/passthrough_relay.go @@ -0,0 +1,807 @@ +package openai_ws_v2 + +import ( + "context" + "errors" + "io" + "net" + "strconv" + "strings" + "sync/atomic" + "time" + + coderws "github.com/coder/websocket" + "github.com/tidwall/gjson" +) + +type FrameConn interface { + ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) + WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error + Close() error +} + +type Usage struct { + InputTokens int + OutputTokens int + CacheCreationInputTokens int + CacheReadInputTokens int +} + +type RelayResult struct { + RequestModel string + Usage Usage + RequestID string + TerminalEventType string + FirstTokenMs *int + Duration time.Duration + ClientToUpstreamFrames int64 + UpstreamToClientFrames int64 + 
DroppedDownstreamFrames int64 +} + +type RelayTurnResult struct { + RequestModel string + Usage Usage + RequestID string + TerminalEventType string + Duration time.Duration + FirstTokenMs *int +} + +type RelayExit struct { + Stage string + Err error + WroteDownstream bool +} + +type RelayOptions struct { + WriteTimeout time.Duration + IdleTimeout time.Duration + UpstreamDrainTimeout time.Duration + FirstMessageType coderws.MessageType + OnUsageParseFailure func(eventType string, usageRaw string) + OnTurnComplete func(turn RelayTurnResult) + OnTrace func(event RelayTraceEvent) + Now func() time.Time +} + +type RelayTraceEvent struct { + Stage string + Direction string + MessageType string + PayloadBytes int + Graceful bool + WroteDownstream bool + Error string +} + +type relayState struct { + usage Usage + requestModel string + lastResponseID string + terminalEventType string + firstTokenMs *int + turnTimingByID map[string]*relayTurnTiming +} + +type relayExitSignal struct { + stage string + err error + graceful bool + wroteDownstream bool +} + +type observedUpstreamEvent struct { + terminal bool + eventType string + responseID string + usage Usage + duration time.Duration + firstToken *int +} + +type relayTurnTiming struct { + startAt time.Time + firstTokenMs *int +} + +func Relay( + ctx context.Context, + clientConn FrameConn, + upstreamConn FrameConn, + firstClientMessage []byte, + options RelayOptions, +) (RelayResult, *RelayExit) { + result := RelayResult{RequestModel: strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())} + if clientConn == nil || upstreamConn == nil { + return result, &RelayExit{Stage: "relay_init", Err: errors.New("relay connection is nil")} + } + if ctx == nil { + ctx = context.Background() + } + + nowFn := options.Now + if nowFn == nil { + nowFn = time.Now + } + writeTimeout := options.WriteTimeout + if writeTimeout <= 0 { + writeTimeout = 2 * time.Minute + } + drainTimeout := options.UpstreamDrainTimeout + if 
drainTimeout <= 0 { + drainTimeout = 1200 * time.Millisecond + } + firstMessageType := options.FirstMessageType + if firstMessageType != coderws.MessageBinary { + firstMessageType = coderws.MessageText + } + startAt := nowFn() + state := &relayState{requestModel: result.RequestModel} + onTrace := options.OnTrace + + relayCtx, relayCancel := context.WithCancel(ctx) + defer relayCancel() + + lastActivity := atomic.Int64{} + lastActivity.Store(nowFn().UnixNano()) + markActivity := func() { + lastActivity.Store(nowFn().UnixNano()) + } + + writeUpstream := func(msgType coderws.MessageType, payload []byte) error { + writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout) + defer cancel() + return upstreamConn.WriteFrame(writeCtx, msgType, payload) + } + writeClient := func(msgType coderws.MessageType, payload []byte) error { + writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout) + defer cancel() + return clientConn.WriteFrame(writeCtx, msgType, payload) + } + + clientToUpstreamFrames := &atomic.Int64{} + upstreamToClientFrames := &atomic.Int64{} + droppedDownstreamFrames := &atomic.Int64{} + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "relay_start", + PayloadBytes: len(firstClientMessage), + MessageType: relayMessageTypeString(firstMessageType), + }) + + if err := writeUpstream(firstMessageType, firstClientMessage); err != nil { + result.Duration = nowFn().Sub(startAt) + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "write_first_message_failed", + Direction: "client_to_upstream", + MessageType: relayMessageTypeString(firstMessageType), + PayloadBytes: len(firstClientMessage), + Error: err.Error(), + }) + return result, &RelayExit{Stage: "write_upstream", Err: err} + } + clientToUpstreamFrames.Add(1) + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "write_first_message_ok", + Direction: "client_to_upstream", + MessageType: relayMessageTypeString(firstMessageType), + PayloadBytes: len(firstClientMessage), + }) + markActivity() + + exitCh := 
make(chan relayExitSignal, 3) + dropDownstreamWrites := atomic.Bool{} + go runClientToUpstream(relayCtx, clientConn, writeUpstream, markActivity, clientToUpstreamFrames, onTrace, exitCh) + go runUpstreamToClient( + relayCtx, + upstreamConn, + writeClient, + startAt, + nowFn, + state, + options.OnUsageParseFailure, + options.OnTurnComplete, + &dropDownstreamWrites, + upstreamToClientFrames, + droppedDownstreamFrames, + markActivity, + onTrace, + exitCh, + ) + go runIdleWatchdog(relayCtx, nowFn, options.IdleTimeout, &lastActivity, onTrace, exitCh) + + firstExit := <-exitCh + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "first_exit", + Direction: relayDirectionFromStage(firstExit.stage), + Graceful: firstExit.graceful, + WroteDownstream: firstExit.wroteDownstream, + Error: relayErrorString(firstExit.err), + }) + combinedWroteDownstream := firstExit.wroteDownstream + secondExit := relayExitSignal{graceful: true} + hasSecondExit := false + + // 客户端断开后尽力继续读取上游短窗口,捕获延迟 usage/terminal 事件用于计费。 + if firstExit.stage == "read_client" && firstExit.graceful { + dropDownstreamWrites.Store(true) + secondExit, hasSecondExit = waitRelayExit(exitCh, drainTimeout) + } else { + relayCancel() + _ = upstreamConn.Close() + secondExit, hasSecondExit = waitRelayExit(exitCh, 200*time.Millisecond) + } + if hasSecondExit { + combinedWroteDownstream = combinedWroteDownstream || secondExit.wroteDownstream + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "second_exit", + Direction: relayDirectionFromStage(secondExit.stage), + Graceful: secondExit.graceful, + WroteDownstream: secondExit.wroteDownstream, + Error: relayErrorString(secondExit.err), + }) + } + + relayCancel() + _ = upstreamConn.Close() + + enrichResult(&result, state, nowFn().Sub(startAt)) + result.ClientToUpstreamFrames = clientToUpstreamFrames.Load() + result.UpstreamToClientFrames = upstreamToClientFrames.Load() + result.DroppedDownstreamFrames = droppedDownstreamFrames.Load() + if firstExit.stage == "read_client" && 
firstExit.graceful { + stage := "client_disconnected" + exitErr := firstExit.err + if hasSecondExit && !secondExit.graceful { + stage = secondExit.stage + exitErr = secondExit.err + } + if exitErr == nil { + exitErr = io.EOF + } + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "relay_exit", + Direction: relayDirectionFromStage(stage), + Graceful: false, + WroteDownstream: combinedWroteDownstream, + Error: relayErrorString(exitErr), + }) + return result, &RelayExit{ + Stage: stage, + Err: exitErr, + WroteDownstream: combinedWroteDownstream, + } + } + if firstExit.graceful && (!hasSecondExit || secondExit.graceful) { + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "relay_complete", + Graceful: true, + WroteDownstream: combinedWroteDownstream, + }) + _ = clientConn.Close() + return result, nil + } + if !firstExit.graceful { + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "relay_exit", + Direction: relayDirectionFromStage(firstExit.stage), + Graceful: false, + WroteDownstream: combinedWroteDownstream, + Error: relayErrorString(firstExit.err), + }) + return result, &RelayExit{ + Stage: firstExit.stage, + Err: firstExit.err, + WroteDownstream: combinedWroteDownstream, + } + } + if hasSecondExit && !secondExit.graceful { + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "relay_exit", + Direction: relayDirectionFromStage(secondExit.stage), + Graceful: false, + WroteDownstream: combinedWroteDownstream, + Error: relayErrorString(secondExit.err), + }) + return result, &RelayExit{ + Stage: secondExit.stage, + Err: secondExit.err, + WroteDownstream: combinedWroteDownstream, + } + } + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "relay_complete", + Graceful: true, + WroteDownstream: combinedWroteDownstream, + }) + _ = clientConn.Close() + return result, nil +} + +func runClientToUpstream( + ctx context.Context, + clientConn FrameConn, + writeUpstream func(msgType coderws.MessageType, payload []byte) error, + markActivity func(), + forwardedFrames *atomic.Int64, 
+ onTrace func(event RelayTraceEvent), + exitCh chan<- relayExitSignal, +) { + for { + msgType, payload, err := clientConn.ReadFrame(ctx) + if err != nil { + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "read_client_failed", + Direction: "client_to_upstream", + Error: err.Error(), + Graceful: isDisconnectError(err), + }) + exitCh <- relayExitSignal{stage: "read_client", err: err, graceful: isDisconnectError(err)} + return + } + markActivity() + if err := writeUpstream(msgType, payload); err != nil { + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "write_upstream_failed", + Direction: "client_to_upstream", + MessageType: relayMessageTypeString(msgType), + PayloadBytes: len(payload), + Error: err.Error(), + }) + exitCh <- relayExitSignal{stage: "write_upstream", err: err} + return + } + if forwardedFrames != nil { + forwardedFrames.Add(1) + } + markActivity() + } +} + +func runUpstreamToClient( + ctx context.Context, + upstreamConn FrameConn, + writeClient func(msgType coderws.MessageType, payload []byte) error, + startAt time.Time, + nowFn func() time.Time, + state *relayState, + onUsageParseFailure func(eventType string, usageRaw string), + onTurnComplete func(turn RelayTurnResult), + dropDownstreamWrites *atomic.Bool, + forwardedFrames *atomic.Int64, + droppedFrames *atomic.Int64, + markActivity func(), + onTrace func(event RelayTraceEvent), + exitCh chan<- relayExitSignal, +) { + wroteDownstream := false + for { + msgType, payload, err := upstreamConn.ReadFrame(ctx) + if err != nil { + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "read_upstream_failed", + Direction: "upstream_to_client", + Error: err.Error(), + Graceful: isDisconnectError(err), + WroteDownstream: wroteDownstream, + }) + exitCh <- relayExitSignal{ + stage: "read_upstream", + err: err, + graceful: isDisconnectError(err), + wroteDownstream: wroteDownstream, + } + return + } + markActivity() + observedEvent := observedUpstreamEvent{} + switch msgType { + case coderws.MessageText: + 
observedEvent = observeUpstreamMessage(state, payload, startAt, nowFn, onUsageParseFailure) + case coderws.MessageBinary: + // binary frame 直接透传,不进入 JSON 观测路径(避免无效解析开销)。 + } + emitTurnComplete(onTurnComplete, state, observedEvent) + if dropDownstreamWrites != nil && dropDownstreamWrites.Load() { + if droppedFrames != nil { + droppedFrames.Add(1) + } + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "drop_downstream_frame", + Direction: "upstream_to_client", + MessageType: relayMessageTypeString(msgType), + PayloadBytes: len(payload), + WroteDownstream: wroteDownstream, + }) + if observedEvent.terminal { + exitCh <- relayExitSignal{ + stage: "drain_terminal", + graceful: true, + wroteDownstream: wroteDownstream, + } + return + } + markActivity() + continue + } + if err := writeClient(msgType, payload); err != nil { + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "write_client_failed", + Direction: "upstream_to_client", + MessageType: relayMessageTypeString(msgType), + PayloadBytes: len(payload), + WroteDownstream: wroteDownstream, + Error: err.Error(), + }) + exitCh <- relayExitSignal{stage: "write_client", err: err, wroteDownstream: wroteDownstream} + return + } + wroteDownstream = true + if forwardedFrames != nil { + forwardedFrames.Add(1) + } + markActivity() + } +} + +func runIdleWatchdog( + ctx context.Context, + nowFn func() time.Time, + idleTimeout time.Duration, + lastActivity *atomic.Int64, + onTrace func(event RelayTraceEvent), + exitCh chan<- relayExitSignal, +) { + if idleTimeout <= 0 { + return + } + checkInterval := minDuration(idleTimeout/4, 5*time.Second) + if checkInterval < time.Second { + checkInterval = time.Second + } + ticker := time.NewTicker(checkInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + last := time.Unix(0, lastActivity.Load()) + if nowFn().Sub(last) < idleTimeout { + continue + } + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "idle_timeout_triggered", + Direction: 
"watchdog", + Error: context.DeadlineExceeded.Error(), + }) + exitCh <- relayExitSignal{stage: "idle_timeout", err: context.DeadlineExceeded} + return + } + } +} + +func emitRelayTrace(onTrace func(event RelayTraceEvent), event RelayTraceEvent) { + if onTrace == nil { + return + } + onTrace(event) +} + +func relayMessageTypeString(msgType coderws.MessageType) string { + switch msgType { + case coderws.MessageText: + return "text" + case coderws.MessageBinary: + return "binary" + default: + return "unknown(" + strconv.Itoa(int(msgType)) + ")" + } +} + +func relayDirectionFromStage(stage string) string { + switch stage { + case "read_client", "write_upstream": + return "client_to_upstream" + case "read_upstream", "write_client", "drain_terminal": + return "upstream_to_client" + case "idle_timeout": + return "watchdog" + default: + return "" + } +} + +func relayErrorString(err error) string { + if err == nil { + return "" + } + return err.Error() +} + +func observeUpstreamMessage( + state *relayState, + message []byte, + startAt time.Time, + nowFn func() time.Time, + onUsageParseFailure func(eventType string, usageRaw string), +) observedUpstreamEvent { + if state == nil || len(message) == 0 { + return observedUpstreamEvent{} + } + values := gjson.GetManyBytes(message, "type", "response.id", "response_id", "id") + eventType := strings.TrimSpace(values[0].String()) + if eventType == "" { + return observedUpstreamEvent{} + } + responseID := strings.TrimSpace(values[1].String()) + if responseID == "" { + responseID = strings.TrimSpace(values[2].String()) + } + // 仅 terminal 事件兜底读取顶层 id,避免把 event_id 当成 response_id 关联到 turn。 + if responseID == "" && isTerminalEvent(eventType) { + responseID = strings.TrimSpace(values[3].String()) + } + now := nowFn() + + if state.firstTokenMs == nil && isTokenEvent(eventType) { + ms := int(now.Sub(startAt).Milliseconds()) + if ms >= 0 { + state.firstTokenMs = &ms + } + } + parsedUsage := parseUsageAndAccumulate(state, message, eventType, 
onUsageParseFailure) + observed := observedUpstreamEvent{ + eventType: eventType, + responseID: responseID, + usage: parsedUsage, + } + if responseID != "" { + turnTiming := openAIWSRelayGetOrInitTurnTiming(state, responseID, now) + if turnTiming != nil && turnTiming.firstTokenMs == nil && isTokenEvent(eventType) { + ms := int(now.Sub(turnTiming.startAt).Milliseconds()) + if ms >= 0 { + turnTiming.firstTokenMs = &ms + } + } + } + if !isTerminalEvent(eventType) { + return observed + } + observed.terminal = true + state.terminalEventType = eventType + if responseID != "" { + state.lastResponseID = responseID + if turnTiming, ok := openAIWSRelayDeleteTurnTiming(state, responseID); ok { + duration := now.Sub(turnTiming.startAt) + if duration < 0 { + duration = 0 + } + observed.duration = duration + observed.firstToken = openAIWSRelayCloneIntPtr(turnTiming.firstTokenMs) + } + } + return observed +} + +func emitTurnComplete( + onTurnComplete func(turn RelayTurnResult), + state *relayState, + observed observedUpstreamEvent, +) { + if onTurnComplete == nil || !observed.terminal { + return + } + responseID := strings.TrimSpace(observed.responseID) + if responseID == "" { + return + } + requestModel := "" + if state != nil { + requestModel = state.requestModel + } + onTurnComplete(RelayTurnResult{ + RequestModel: requestModel, + Usage: observed.usage, + RequestID: responseID, + TerminalEventType: observed.eventType, + Duration: observed.duration, + FirstTokenMs: openAIWSRelayCloneIntPtr(observed.firstToken), + }) +} + +func openAIWSRelayGetOrInitTurnTiming(state *relayState, responseID string, now time.Time) *relayTurnTiming { + if state == nil { + return nil + } + if state.turnTimingByID == nil { + state.turnTimingByID = make(map[string]*relayTurnTiming, 8) + } + timing, ok := state.turnTimingByID[responseID] + if !ok || timing == nil || timing.startAt.IsZero() { + timing = &relayTurnTiming{startAt: now} + state.turnTimingByID[responseID] = timing + return timing + } + 
return timing +} + +func openAIWSRelayDeleteTurnTiming(state *relayState, responseID string) (relayTurnTiming, bool) { + if state == nil || state.turnTimingByID == nil { + return relayTurnTiming{}, false + } + timing, ok := state.turnTimingByID[responseID] + if !ok || timing == nil { + return relayTurnTiming{}, false + } + delete(state.turnTimingByID, responseID) + return *timing, true +} + +func openAIWSRelayCloneIntPtr(v *int) *int { + if v == nil { + return nil + } + cloned := *v + return &cloned +} + +func parseUsageAndAccumulate( + state *relayState, + message []byte, + eventType string, + onParseFailure func(eventType string, usageRaw string), +) Usage { + if state == nil || len(message) == 0 || !shouldParseUsage(eventType) { + return Usage{} + } + usageResult := gjson.GetBytes(message, "response.usage") + if !usageResult.Exists() { + return Usage{} + } + usageRaw := strings.TrimSpace(usageResult.Raw) + if usageRaw == "" || !strings.HasPrefix(usageRaw, "{") { + recordUsageParseFailure() + if onParseFailure != nil { + onParseFailure(eventType, usageRaw) + } + return Usage{} + } + + inputResult := gjson.GetBytes(message, "response.usage.input_tokens") + outputResult := gjson.GetBytes(message, "response.usage.output_tokens") + cachedResult := gjson.GetBytes(message, "response.usage.input_tokens_details.cached_tokens") + + inputTokens, inputOK := parseUsageIntField(inputResult, true) + outputTokens, outputOK := parseUsageIntField(outputResult, true) + cachedTokens, cachedOK := parseUsageIntField(cachedResult, false) + if !inputOK || !outputOK || !cachedOK { + recordUsageParseFailure() + if onParseFailure != nil { + onParseFailure(eventType, usageRaw) + } + // 解析失败时不做部分字段累加,避免计费 usage 出现“半有效”状态。 + return Usage{} + } + parsedUsage := Usage{ + InputTokens: inputTokens, + OutputTokens: outputTokens, + CacheReadInputTokens: cachedTokens, + } + + state.usage.InputTokens += parsedUsage.InputTokens + state.usage.OutputTokens += parsedUsage.OutputTokens + 
state.usage.CacheReadInputTokens += parsedUsage.CacheReadInputTokens + return parsedUsage +} + +func parseUsageIntField(value gjson.Result, required bool) (int, bool) { + if !value.Exists() { + return 0, !required + } + if value.Type != gjson.Number { + return 0, false + } + return int(value.Int()), true +} + +func enrichResult(result *RelayResult, state *relayState, duration time.Duration) { + if result == nil { + return + } + result.Duration = duration + if state == nil { + return + } + result.RequestModel = state.requestModel + result.Usage = state.usage + result.RequestID = state.lastResponseID + result.TerminalEventType = state.terminalEventType + result.FirstTokenMs = state.firstTokenMs +} + +func isDisconnectError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) { + return true + } + switch coderws.CloseStatus(err) { + case coderws.StatusNormalClosure, coderws.StatusGoingAway, coderws.StatusNoStatusRcvd, coderws.StatusAbnormalClosure: + return true + } + message := strings.ToLower(strings.TrimSpace(err.Error())) + if message == "" { + return false + } + return strings.Contains(message, "failed to read frame header: eof") || + strings.Contains(message, "unexpected eof") || + strings.Contains(message, "use of closed network connection") || + strings.Contains(message, "connection reset by peer") || + strings.Contains(message, "broken pipe") +} + +func isTerminalEvent(eventType string) bool { + switch eventType { + case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled": + return true + default: + return false + } +} + +func shouldParseUsage(eventType string) bool { + switch eventType { + case "response.completed", "response.done", "response.failed": + return true + default: + return false + } +} + +func isTokenEvent(eventType string) bool { + if eventType == "" { + return false + } + 
switch eventType { + case "response.created", "response.in_progress", "response.output_item.added", "response.output_item.done": + return false + } + if strings.Contains(eventType, ".delta") { + return true + } + if strings.HasPrefix(eventType, "response.output_text") { + return true + } + if strings.HasPrefix(eventType, "response.output") { + return true + } + return eventType == "response.completed" || eventType == "response.done" +} + +func minDuration(a, b time.Duration) time.Duration { + if a <= 0 { + return b + } + if b <= 0 { + return a + } + if a < b { + return a + } + return b +} + +func waitRelayExit(exitCh <-chan relayExitSignal, timeout time.Duration) (relayExitSignal, bool) { + if timeout <= 0 { + timeout = 200 * time.Millisecond + } + select { + case sig := <-exitCh: + return sig, true + case <-time.After(timeout): + return relayExitSignal{}, false + } +} diff --git a/backend/internal/service/openai_ws_v2/passthrough_relay_internal_test.go b/backend/internal/service/openai_ws_v2/passthrough_relay_internal_test.go new file mode 100644 index 00000000..123e10ce --- /dev/null +++ b/backend/internal/service/openai_ws_v2/passthrough_relay_internal_test.go @@ -0,0 +1,432 @@ +package openai_ws_v2 + +import ( + "context" + "errors" + "io" + "net" + "sync/atomic" + "testing" + "time" + + coderws "github.com/coder/websocket" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestRunEntry_DelegatesRelay(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_entry","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + }, true) + + result, relayExit := RunEntry(EntryInput{ + Ctx: context.Background(), + ClientConn: clientConn, + UpstreamConn: upstreamConn, + FirstClientMessage: 
[]byte(`{"type":"response.create","model":"gpt-4o","input":[]}`), + }) + require.Nil(t, relayExit) + require.Equal(t, "resp_entry", result.RequestID) +} + +func TestRunClientToUpstream_ErrorPaths(t *testing.T) { + t.Parallel() + + t.Run("read client eof", func(t *testing.T) { + t.Parallel() + + exitCh := make(chan relayExitSignal, 1) + runClientToUpstream( + context.Background(), + newPassthroughTestFrameConn(nil, true), + func(_ coderws.MessageType, _ []byte) error { return nil }, + func() {}, + nil, + nil, + exitCh, + ) + sig := <-exitCh + require.Equal(t, "read_client", sig.stage) + require.True(t, sig.graceful) + }) + + t.Run("write upstream failed", func(t *testing.T) { + t.Parallel() + + exitCh := make(chan relayExitSignal, 1) + runClientToUpstream( + context.Background(), + newPassthroughTestFrameConn([]passthroughTestFrame{ + {msgType: coderws.MessageText, payload: []byte(`{"x":1}`)}, + }, true), + func(_ coderws.MessageType, _ []byte) error { return errors.New("boom") }, + func() {}, + nil, + nil, + exitCh, + ) + sig := <-exitCh + require.Equal(t, "write_upstream", sig.stage) + require.False(t, sig.graceful) + }) + + t.Run("forwarded counter and trace callback", func(t *testing.T) { + t.Parallel() + + exitCh := make(chan relayExitSignal, 1) + forwarded := &atomic.Int64{} + traces := make([]RelayTraceEvent, 0, 2) + runClientToUpstream( + context.Background(), + newPassthroughTestFrameConn([]passthroughTestFrame{ + {msgType: coderws.MessageText, payload: []byte(`{"x":1}`)}, + }, true), + func(_ coderws.MessageType, _ []byte) error { return nil }, + func() {}, + forwarded, + func(event RelayTraceEvent) { + traces = append(traces, event) + }, + exitCh, + ) + sig := <-exitCh + require.Equal(t, "read_client", sig.stage) + require.Equal(t, int64(1), forwarded.Load()) + require.NotEmpty(t, traces) + }) +} + +func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) { + t.Parallel() + + t.Run("read upstream eof", func(t *testing.T) { + t.Parallel() + + exitCh := 
make(chan relayExitSignal, 1) + drop := &atomic.Bool{} + drop.Store(false) + runUpstreamToClient( + context.Background(), + newPassthroughTestFrameConn(nil, true), + func(_ coderws.MessageType, _ []byte) error { return nil }, + time.Now(), + time.Now, + &relayState{}, + nil, + nil, + drop, + nil, + nil, + func() {}, + nil, + exitCh, + ) + sig := <-exitCh + require.Equal(t, "read_upstream", sig.stage) + require.True(t, sig.graceful) + }) + + t.Run("write client failed", func(t *testing.T) { + t.Parallel() + + exitCh := make(chan relayExitSignal, 1) + drop := &atomic.Bool{} + drop.Store(false) + runUpstreamToClient( + context.Background(), + newPassthroughTestFrameConn([]passthroughTestFrame{ + {msgType: coderws.MessageText, payload: []byte(`{"type":"response.output_text.delta","delta":"x"}`)}, + }, true), + func(_ coderws.MessageType, _ []byte) error { return errors.New("write failed") }, + time.Now(), + time.Now, + &relayState{}, + nil, + nil, + drop, + nil, + nil, + func() {}, + nil, + exitCh, + ) + sig := <-exitCh + require.Equal(t, "write_client", sig.stage) + }) + + t.Run("drop downstream and stop on terminal", func(t *testing.T) { + t.Parallel() + + exitCh := make(chan relayExitSignal, 1) + drop := &atomic.Bool{} + drop.Store(true) + dropped := &atomic.Int64{} + runUpstreamToClient( + context.Background(), + newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_drop","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + }, true), + func(_ coderws.MessageType, _ []byte) error { return nil }, + time.Now(), + time.Now, + &relayState{}, + nil, + nil, + drop, + nil, + dropped, + func() {}, + nil, + exitCh, + ) + sig := <-exitCh + require.Equal(t, "drain_terminal", sig.stage) + require.True(t, sig.graceful) + require.Equal(t, int64(1), dropped.Load()) + }) +} + +func TestRunIdleWatchdog_NoTimeoutWhenDisabled(t *testing.T) { + t.Parallel() + + exitCh := 
make(chan relayExitSignal, 1) + lastActivity := &atomic.Int64{} + lastActivity.Store(time.Now().UnixNano()) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go runIdleWatchdog(ctx, time.Now, 0, lastActivity, nil, exitCh) + select { + case <-exitCh: + t.Fatal("unexpected idle timeout signal") + case <-time.After(200 * time.Millisecond): + } +} + +func TestHelperFunctionsCoverage(t *testing.T) { + t.Parallel() + + require.Equal(t, "text", relayMessageTypeString(coderws.MessageText)) + require.Equal(t, "binary", relayMessageTypeString(coderws.MessageBinary)) + require.Contains(t, relayMessageTypeString(coderws.MessageType(99)), "unknown(") + + require.Equal(t, "", relayErrorString(nil)) + require.Equal(t, "x", relayErrorString(errors.New("x"))) + + require.True(t, isDisconnectError(io.EOF)) + require.True(t, isDisconnectError(net.ErrClosed)) + require.True(t, isDisconnectError(context.Canceled)) + require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusGoingAway})) + require.True(t, isDisconnectError(errors.New("broken pipe"))) + require.False(t, isDisconnectError(errors.New("unrelated"))) + + require.True(t, isTokenEvent("response.output_text.delta")) + require.True(t, isTokenEvent("response.output_audio.delta")) + require.True(t, isTokenEvent("response.completed")) + require.False(t, isTokenEvent("")) + require.False(t, isTokenEvent("response.created")) + + require.Equal(t, 2*time.Second, minDuration(2*time.Second, 5*time.Second)) + require.Equal(t, 2*time.Second, minDuration(5*time.Second, 2*time.Second)) + require.Equal(t, 5*time.Second, minDuration(0, 5*time.Second)) + require.Equal(t, 2*time.Second, minDuration(2*time.Second, 0)) + + ch := make(chan relayExitSignal, 1) + ch <- relayExitSignal{stage: "ok"} + sig, ok := waitRelayExit(ch, 10*time.Millisecond) + require.True(t, ok) + require.Equal(t, "ok", sig.stage) + ch <- relayExitSignal{stage: "ok2"} + sig, ok = waitRelayExit(ch, 0) + require.True(t, ok) + 
require.Equal(t, "ok2", sig.stage) + _, ok = waitRelayExit(ch, 10*time.Millisecond) + require.False(t, ok) + + n, ok := parseUsageIntField(gjson.Get(`{"n":3}`, "n"), true) + require.True(t, ok) + require.Equal(t, 3, n) + _, ok = parseUsageIntField(gjson.Get(`{"n":"x"}`, "n"), true) + require.False(t, ok) + n, ok = parseUsageIntField(gjson.Result{}, false) + require.True(t, ok) + require.Equal(t, 0, n) + _, ok = parseUsageIntField(gjson.Result{}, true) + require.False(t, ok) +} + +func TestParseUsageAndEnrichCoverage(t *testing.T) { + t.Parallel() + + state := &relayState{} + parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":"bad"}}}`), "response.completed", nil) + require.Equal(t, 0, state.usage.InputTokens) + + parseUsageAndAccumulate( + state, + []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":9,"output_tokens":"bad","input_tokens_details":{"cached_tokens":2}}}}`), + "response.completed", + nil, + ) + require.Equal(t, 0, state.usage.InputTokens, "部分字段解析失败时不应累加 usage") + require.Equal(t, 0, state.usage.OutputTokens) + require.Equal(t, 0, state.usage.CacheReadInputTokens) + + parseUsageAndAccumulate( + state, + []byte(`{"type":"response.completed","response":{"usage":{"input_tokens_details":{"cached_tokens":2}}}}`), + "response.completed", + nil, + ) + require.Equal(t, 0, state.usage.InputTokens, "必填 usage 字段缺失时不应累加 usage") + require.Equal(t, 0, state.usage.OutputTokens) + require.Equal(t, 0, state.usage.CacheReadInputTokens) + + parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1,"input_tokens_details":{"cached_tokens":1}}}}`), "response.completed", nil) + require.Equal(t, 2, state.usage.InputTokens) + require.Equal(t, 1, state.usage.OutputTokens) + require.Equal(t, 1, state.usage.CacheReadInputTokens) + + result := &RelayResult{} + enrichResult(result, state, 5*time.Millisecond) + require.Equal(t, 
state.usage.InputTokens, result.Usage.InputTokens) + require.Equal(t, 5*time.Millisecond, result.Duration) + parseUsageAndAccumulate(state, []byte(`{"type":"response.in_progress","response":{"usage":{"input_tokens":9}}}`), "response.in_progress", nil) + require.Equal(t, 2, state.usage.InputTokens) + enrichResult(nil, state, 0) +} + +func TestEmitTurnCompleteCoverage(t *testing.T) { + t.Parallel() + + // 非 terminal 事件不应触发。 + called := 0 + emitTurnComplete(func(turn RelayTurnResult) { + called++ + }, &relayState{requestModel: "gpt-5"}, observedUpstreamEvent{ + terminal: false, + eventType: "response.output_text.delta", + responseID: "resp_ignored", + usage: Usage{InputTokens: 1}, + }) + require.Equal(t, 0, called) + + // 缺少 response_id 时不应触发。 + emitTurnComplete(func(turn RelayTurnResult) { + called++ + }, &relayState{requestModel: "gpt-5"}, observedUpstreamEvent{ + terminal: true, + eventType: "response.completed", + }) + require.Equal(t, 0, called) + + // terminal 且 response_id 存在,应该触发;state=nil 时 model 为空串。 + var got RelayTurnResult + emitTurnComplete(func(turn RelayTurnResult) { + called++ + got = turn + }, nil, observedUpstreamEvent{ + terminal: true, + eventType: "response.completed", + responseID: "resp_emit", + usage: Usage{InputTokens: 2, OutputTokens: 3}, + }) + require.Equal(t, 1, called) + require.Equal(t, "resp_emit", got.RequestID) + require.Equal(t, "response.completed", got.TerminalEventType) + require.Equal(t, 2, got.Usage.InputTokens) + require.Equal(t, 3, got.Usage.OutputTokens) + require.Equal(t, "", got.RequestModel) +} + +func TestIsDisconnectErrorCoverage_CloseStatusesAndMessageBranches(t *testing.T) { + t.Parallel() + + require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusNormalClosure})) + require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusNoStatusRcvd})) + require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusAbnormalClosure})) + require.True(t, 
isDisconnectError(errors.New("connection reset by peer"))) + require.False(t, isDisconnectError(errors.New(" "))) +} + +func TestIsTokenEventCoverageBranches(t *testing.T) { + t.Parallel() + + require.False(t, isTokenEvent("response.in_progress")) + require.False(t, isTokenEvent("response.output_item.added")) + require.True(t, isTokenEvent("response.output_audio.delta")) + require.True(t, isTokenEvent("response.output")) + require.True(t, isTokenEvent("response.done")) +} + +func TestRelayTurnTimingHelpersCoverage(t *testing.T) { + t.Parallel() + + now := time.Unix(100, 0) + // nil state + require.Nil(t, openAIWSRelayGetOrInitTurnTiming(nil, "resp_nil", now)) + _, ok := openAIWSRelayDeleteTurnTiming(nil, "resp_nil") + require.False(t, ok) + + state := &relayState{} + timing := openAIWSRelayGetOrInitTurnTiming(state, "resp_a", now) + require.NotNil(t, timing) + require.Equal(t, now, timing.startAt) + + // 再次获取返回同一条 timing + timing2 := openAIWSRelayGetOrInitTurnTiming(state, "resp_a", now.Add(5*time.Second)) + require.NotNil(t, timing2) + require.Equal(t, now, timing2.startAt) + + // 删除存在键 + deleted, ok := openAIWSRelayDeleteTurnTiming(state, "resp_a") + require.True(t, ok) + require.Equal(t, now, deleted.startAt) + + // 删除不存在键 + _, ok = openAIWSRelayDeleteTurnTiming(state, "resp_a") + require.False(t, ok) +} + +func TestObserveUpstreamMessage_ResponseIDFallbackPolicy(t *testing.T) { + t.Parallel() + + state := &relayState{requestModel: "gpt-5"} + startAt := time.Unix(0, 0) + now := startAt + nowFn := func() time.Time { + now = now.Add(5 * time.Millisecond) + return now + } + + // 非 terminal:仅有顶层 id,不应把 event id 当成 response_id。 + observed := observeUpstreamMessage( + state, + []byte(`{"type":"response.output_text.delta","id":"evt_123","delta":"hi"}`), + startAt, + nowFn, + nil, + ) + require.False(t, observed.terminal) + require.Equal(t, "", observed.responseID) + + // terminal:允许兜底用顶层 id(用于兼容少数字段变体)。 + observed = observeUpstreamMessage( + state, + 
[]byte(`{"type":"response.completed","id":"resp_fallback","response":{"usage":{"input_tokens":1,"output_tokens":1}}}`), + startAt, + nowFn, + nil, + ) + require.True(t, observed.terminal) + require.Equal(t, "resp_fallback", observed.responseID) +} diff --git a/backend/internal/service/openai_ws_v2/passthrough_relay_test.go b/backend/internal/service/openai_ws_v2/passthrough_relay_test.go new file mode 100644 index 00000000..ff9b7311 --- /dev/null +++ b/backend/internal/service/openai_ws_v2/passthrough_relay_test.go @@ -0,0 +1,752 @@ +package openai_ws_v2 + +import ( + "context" + "errors" + "io" + "sync" + "sync/atomic" + "testing" + "time" + + coderws "github.com/coder/websocket" + "github.com/stretchr/testify/require" +) + +type passthroughTestFrame struct { + msgType coderws.MessageType + payload []byte +} + +type passthroughTestFrameConn struct { + mu sync.Mutex + writes []passthroughTestFrame + readCh chan passthroughTestFrame + once sync.Once +} + +type delayedReadFrameConn struct { + base FrameConn + firstDelay time.Duration + once sync.Once +} + +type closeSpyFrameConn struct { + closeCalls atomic.Int32 +} + +func newPassthroughTestFrameConn(frames []passthroughTestFrame, autoClose bool) *passthroughTestFrameConn { + c := &passthroughTestFrameConn{ + readCh: make(chan passthroughTestFrame, len(frames)+1), + } + for _, frame := range frames { + copied := passthroughTestFrame{msgType: frame.msgType, payload: append([]byte(nil), frame.payload...)} + c.readCh <- copied + } + if autoClose { + close(c.readCh) + } + return c +} + +func (c *passthroughTestFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { + if ctx == nil { + ctx = context.Background() + } + select { + case <-ctx.Done(): + return coderws.MessageText, nil, ctx.Err() + case frame, ok := <-c.readCh: + if !ok { + return coderws.MessageText, nil, io.EOF + } + return frame.msgType, append([]byte(nil), frame.payload...), nil + } +} + +func (c *passthroughTestFrameConn) 
WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error { + if ctx == nil { + ctx = context.Background() + } + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + c.mu.Lock() + defer c.mu.Unlock() + c.writes = append(c.writes, passthroughTestFrame{msgType: msgType, payload: append([]byte(nil), payload...)}) + return nil +} + +func (c *passthroughTestFrameConn) Close() error { + c.once.Do(func() { + defer func() { _ = recover() }() + close(c.readCh) + }) + return nil +} + +func (c *passthroughTestFrameConn) Writes() []passthroughTestFrame { + c.mu.Lock() + defer c.mu.Unlock() + out := make([]passthroughTestFrame, len(c.writes)) + copy(out, c.writes) + return out +} + +func (c *delayedReadFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { + if c == nil || c.base == nil { + return coderws.MessageText, nil, io.EOF + } + c.once.Do(func() { + if c.firstDelay > 0 { + timer := time.NewTimer(c.firstDelay) + defer timer.Stop() + select { + case <-ctx.Done(): + case <-timer.C: + } + } + }) + return c.base.ReadFrame(ctx) +} + +func (c *delayedReadFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error { + if c == nil || c.base == nil { + return io.EOF + } + return c.base.WriteFrame(ctx, msgType, payload) +} + +func (c *delayedReadFrameConn) Close() error { + if c == nil || c.base == nil { + return nil + } + return c.base.Close() +} + +func (c *closeSpyFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { + if ctx == nil { + ctx = context.Background() + } + <-ctx.Done() + return coderws.MessageText, nil, ctx.Err() +} + +func (c *closeSpyFrameConn) WriteFrame(ctx context.Context, _ coderws.MessageType, _ []byte) error { + if ctx == nil { + ctx = context.Background() + } + select { + case <-ctx.Done(): + return ctx.Err() + default: + return nil + } +} + +func (c *closeSpyFrameConn) Close() error { + if c != nil { + c.closeCalls.Add(1) + } + 
return nil +} + +func (c *closeSpyFrameConn) CloseCalls() int32 { + if c == nil { + return 0 + } + return c.closeCalls.Load() +} + +func TestRelay_BasicRelayAndUsage(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_123","usage":{"input_tokens":7,"output_tokens":3,"input_tokens_details":{"cached_tokens":2}}}}`), + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[{"type":"input_text","text":"hello"}]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + require.Nil(t, relayExit) + require.Equal(t, "gpt-5.3-codex", result.RequestModel) + require.Equal(t, "resp_123", result.RequestID) + require.Equal(t, "response.completed", result.TerminalEventType) + require.Equal(t, 7, result.Usage.InputTokens) + require.Equal(t, 3, result.Usage.OutputTokens) + require.Equal(t, 2, result.Usage.CacheReadInputTokens) + require.NotNil(t, result.FirstTokenMs) + require.Equal(t, int64(1), result.ClientToUpstreamFrames) + require.Equal(t, int64(1), result.UpstreamToClientFrames) + require.Equal(t, int64(0), result.DroppedDownstreamFrames) + + upstreamWrites := upstreamConn.Writes() + require.Len(t, upstreamWrites, 1) + require.Equal(t, coderws.MessageText, upstreamWrites[0].msgType) + require.JSONEq(t, string(firstPayload), string(upstreamWrites[0].payload)) + + clientWrites := clientConn.Writes() + require.Len(t, clientWrites, 1) + require.Equal(t, coderws.MessageText, clientWrites[0].msgType) + require.JSONEq(t, `{"type":"response.completed","response":{"id":"resp_123","usage":{"input_tokens":7,"output_tokens":3,"input_tokens_details":{"cached_tokens":2}}}}`, string(clientWrites[0].payload)) 
+} + +func TestRelay_FunctionCallOutputBytesPreserved(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_func","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[{"type":"function_call_output","call_id":"call_abc123","output":"{\"ok\":true}"}]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + _, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + require.Nil(t, relayExit) + + upstreamWrites := upstreamConn.Writes() + require.Len(t, upstreamWrites, 1) + require.Equal(t, coderws.MessageText, upstreamWrites[0].msgType) + require.Equal(t, firstPayload, upstreamWrites[0].payload) +} + +func TestRelay_UpstreamDisconnect(t *testing.T) { + t.Parallel() + + // 上游立即关闭(EOF),客户端不发送额外帧 + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn(nil, true) // 立即 close -> EOF + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + // 上游 EOF 属于 disconnect,标记为 graceful + require.Nil(t, relayExit, "上游 EOF 应被视为 graceful disconnect") + require.Equal(t, "gpt-4o", result.RequestModel) +} + +func TestRelay_ClientDisconnect(t *testing.T) { + t.Parallel() + + // 客户端立即关闭(EOF),上游阻塞读取直到 context 取消 + clientConn := newPassthroughTestFrameConn(nil, true) // 立即 close -> EOF + upstreamConn := newPassthroughTestFrameConn(nil, false) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := 
context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + require.NotNil(t, relayExit, "客户端 EOF 应返回可观测的中断状态") + require.Equal(t, "client_disconnected", relayExit.Stage) + require.Equal(t, "gpt-4o", result.RequestModel) +} + +func TestRelay_ClientDisconnect_DrainCapturesLateUsage(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, true) + upstreamBase := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_drain","usage":{"input_tokens":6,"output_tokens":4,"input_tokens_details":{"cached_tokens":1}}}}`), + }, + }, true) + upstreamConn := &delayedReadFrameConn{ + base: upstreamBase, + firstDelay: 80 * time.Millisecond, + } + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{ + UpstreamDrainTimeout: 400 * time.Millisecond, + }) + require.NotNil(t, relayExit) + require.Equal(t, "client_disconnected", relayExit.Stage) + require.Equal(t, "resp_drain", result.RequestID) + require.Equal(t, "response.completed", result.TerminalEventType) + require.Equal(t, 6, result.Usage.InputTokens) + require.Equal(t, 4, result.Usage.OutputTokens) + require.Equal(t, 1, result.Usage.CacheReadInputTokens) + require.Equal(t, int64(1), result.ClientToUpstreamFrames) + require.Equal(t, int64(0), result.UpstreamToClientFrames) + require.Equal(t, int64(1), result.DroppedDownstreamFrames) +} + +func TestRelay_IdleTimeout(t *testing.T) { + t.Parallel() + + // 客户端和上游都不发送帧,idle timeout 应触发 + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn(nil, false) + + firstPayload := 
[]byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // 使用快进时间来加速 idle timeout + now := time.Now() + callCount := 0 + nowFn := func() time.Time { + callCount++ + // 前几次调用返回正常时间(初始化阶段),之后快进 + if callCount <= 5 { + return now + } + return now.Add(time.Hour) // 快进到超时 + } + + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{ + IdleTimeout: 2 * time.Second, + Now: nowFn, + }) + require.NotNil(t, relayExit, "应因 idle timeout 退出") + require.Equal(t, "idle_timeout", relayExit.Stage) + require.Equal(t, "gpt-4o", result.RequestModel) +} + +func TestRelay_IdleTimeoutDoesNotCloseClientOnError(t *testing.T) { + t.Parallel() + + clientConn := &closeSpyFrameConn{} + upstreamConn := &closeSpyFrameConn{} + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + now := time.Now() + callCount := 0 + nowFn := func() time.Time { + callCount++ + if callCount <= 5 { + return now + } + return now.Add(time.Hour) + } + + _, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{ + IdleTimeout: 2 * time.Second, + Now: nowFn, + }) + require.NotNil(t, relayExit, "应因 idle timeout 退出") + require.Equal(t, "idle_timeout", relayExit.Stage) + require.Zero(t, clientConn.CloseCalls(), "错误路径不应提前关闭客户端连接,交给上层决定 close code") + require.GreaterOrEqual(t, upstreamConn.CloseCalls(), int32(1)) +} + +func TestRelay_NilConnections(t *testing.T) { + t.Parallel() + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx := context.Background() + + t.Run("nil client conn", func(t *testing.T) { + upstreamConn := newPassthroughTestFrameConn(nil, true) + _, relayExit := Relay(ctx, nil, upstreamConn, firstPayload, RelayOptions{}) + require.NotNil(t, relayExit) + require.Equal(t, "relay_init", 
relayExit.Stage) + require.Contains(t, relayExit.Err.Error(), "nil") + }) + + t.Run("nil upstream conn", func(t *testing.T) { + clientConn := newPassthroughTestFrameConn(nil, true) + _, relayExit := Relay(ctx, clientConn, nil, firstPayload, RelayOptions{}) + require.NotNil(t, relayExit) + require.Equal(t, "relay_init", relayExit.Stage) + require.Contains(t, relayExit.Err.Error(), "nil") + }) +} + +func TestRelay_MultipleUpstreamMessages(t *testing.T) { + t.Parallel() + + // 上游发送多个事件(delta + completed),验证多帧中继和 usage 聚合 + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.output_text.delta","delta":"Hello"}`), + }, + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.output_text.delta","delta":" world"}`), + }, + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_multi","usage":{"input_tokens":10,"output_tokens":5,"input_tokens_details":{"cached_tokens":3}}}}`), + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[{"type":"input_text","text":"hi"}]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + require.Nil(t, relayExit) + require.Equal(t, "resp_multi", result.RequestID) + require.Equal(t, "response.completed", result.TerminalEventType) + require.Equal(t, 10, result.Usage.InputTokens) + require.Equal(t, 5, result.Usage.OutputTokens) + require.Equal(t, 3, result.Usage.CacheReadInputTokens) + require.NotNil(t, result.FirstTokenMs) + + // 验证所有 3 个上游帧都转发给了客户端 + clientWrites := clientConn.Writes() + require.Len(t, clientWrites, 3) +} + +func TestRelay_OnTurnComplete_PerTerminalEvent(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + 
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_turn_1","usage":{"input_tokens":2,"output_tokens":1}}}`), + }, + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.failed","response":{"id":"resp_turn_2","usage":{"input_tokens":3,"output_tokens":4}}}`), + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + turns := make([]RelayTurnResult, 0, 2) + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{ + OnTurnComplete: func(turn RelayTurnResult) { + turns = append(turns, turn) + }, + }) + require.Nil(t, relayExit) + require.Len(t, turns, 2) + require.Equal(t, "resp_turn_1", turns[0].RequestID) + require.Equal(t, "response.completed", turns[0].TerminalEventType) + require.Equal(t, 2, turns[0].Usage.InputTokens) + require.Equal(t, 1, turns[0].Usage.OutputTokens) + require.Equal(t, "resp_turn_2", turns[1].RequestID) + require.Equal(t, "response.failed", turns[1].TerminalEventType) + require.Equal(t, 3, turns[1].Usage.InputTokens) + require.Equal(t, 4, turns[1].Usage.OutputTokens) + require.Equal(t, 5, result.Usage.InputTokens) + require.Equal(t, 5, result.Usage.OutputTokens) +} + +func TestRelay_OnTurnComplete_ProvidesTurnMetrics(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.output_text.delta","response_id":"resp_metric","delta":"hi"}`), + }, + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_metric","usage":{"input_tokens":2,"output_tokens":1}}}`), + }, + }, true) + + firstPayload := 
[]byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + base := time.Unix(0, 0) + var nowTick atomic.Int64 + nowFn := func() time.Time { + step := nowTick.Add(1) + return base.Add(time.Duration(step) * 5 * time.Millisecond) + } + + var turn RelayTurnResult + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{ + Now: nowFn, + OnTurnComplete: func(current RelayTurnResult) { + turn = current + }, + }) + require.Nil(t, relayExit) + require.Equal(t, "resp_metric", turn.RequestID) + require.Equal(t, "response.completed", turn.TerminalEventType) + require.NotNil(t, turn.FirstTokenMs) + require.GreaterOrEqual(t, *turn.FirstTokenMs, 0) + require.Greater(t, turn.Duration.Milliseconds(), int64(0)) + require.NotNil(t, result.FirstTokenMs) + require.Greater(t, result.Duration.Milliseconds(), int64(0)) +} + +func TestRelay_BinaryFramePassthrough(t *testing.T) { + t.Parallel() + + // 验证 binary frame 被透传但不进行 usage 解析 + binaryPayload := []byte{0x00, 0x01, 0x02, 0x03} + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageBinary, + payload: binaryPayload, + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + require.Nil(t, relayExit) + // binary frame 不解析 usage + require.Equal(t, 0, result.Usage.InputTokens) + + clientWrites := clientConn.Writes() + require.Len(t, clientWrites, 1) + require.Equal(t, coderws.MessageBinary, clientWrites[0].msgType) + require.Equal(t, binaryPayload, clientWrites[0].payload) +} + +func TestRelay_BinaryJSONFrameSkipsObservation(t *testing.T) { + t.Parallel() + + clientConn := 
newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageBinary, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_binary","usage":{"input_tokens":7,"output_tokens":3}}}`), + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + require.Nil(t, relayExit) + require.Equal(t, 0, result.Usage.InputTokens) + require.Equal(t, "", result.RequestID) + require.Equal(t, "", result.TerminalEventType) + + clientWrites := clientConn.Writes() + require.Len(t, clientWrites, 1) + require.Equal(t, coderws.MessageBinary, clientWrites[0].msgType) +} + +func TestRelay_UpstreamErrorEventPassthroughRaw(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + errorEvent := []byte(`{"type":"error","error":{"type":"invalid_request_error","message":"No tool call found"}}`) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: errorEvent, + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + _, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + require.Nil(t, relayExit) + + clientWrites := clientConn.Writes() + require.Len(t, clientWrites, 1) + require.Equal(t, coderws.MessageText, clientWrites[0].msgType) + require.Equal(t, errorEvent, clientWrites[0].payload) +} + +func TestRelay_PreservesFirstMessageType(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn(nil, true) + + firstPayload := 
[]byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + _, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{ + FirstMessageType: coderws.MessageBinary, + }) + require.Nil(t, relayExit) + + upstreamWrites := upstreamConn.Writes() + require.Len(t, upstreamWrites, 1) + require.Equal(t, coderws.MessageBinary, upstreamWrites[0].msgType) + require.Equal(t, firstPayload, upstreamWrites[0].payload) +} + +func TestRelay_UsageParseFailureDoesNotBlockRelay(t *testing.T) { + baseline := SnapshotMetrics().UsageParseFailureTotal + + // 上游发送无效 JSON(非 usage 格式),不应影响透传 + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_bad","usage":"not_an_object"}}`), + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + require.Nil(t, relayExit) + // usage 解析失败,值为 0 但不影响透传 + require.Equal(t, 0, result.Usage.InputTokens) + require.Equal(t, "response.completed", result.TerminalEventType) + + // 帧仍然被转发 + clientWrites := clientConn.Writes() + require.Len(t, clientWrites, 1) + require.GreaterOrEqual(t, SnapshotMetrics().UsageParseFailureTotal, baseline+1) +} + +func TestRelay_WriteUpstreamFirstMessageFails(t *testing.T) { + t.Parallel() + + // 上游连接立即关闭,首包写入失败 + upstreamConn := newPassthroughTestFrameConn(nil, true) + _ = upstreamConn.Close() + + // 覆盖 WriteFrame 使其返回错误 + errConn := &errorOnWriteFrameConn{} + clientConn := newPassthroughTestFrameConn(nil, false) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := 
context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + _, relayExit := Relay(ctx, clientConn, errConn, firstPayload, RelayOptions{}) + require.NotNil(t, relayExit) + require.Equal(t, "write_upstream", relayExit.Stage) +} + +func TestRelay_ContextCanceled(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn(nil, false) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + + // 立即取消 context + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + // context 取消导致写首包失败 + require.NotNil(t, relayExit) +} + +func TestRelay_TraceEvents_ContainsLifecycleStages(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_trace","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + stages := make([]string, 0, 8) + var stagesMu sync.Mutex + _, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{ + OnTrace: func(event RelayTraceEvent) { + stagesMu.Lock() + stages = append(stages, event.Stage) + stagesMu.Unlock() + }, + }) + require.Nil(t, relayExit) + stagesMu.Lock() + capturedStages := append([]string(nil), stages...) 
+ stagesMu.Unlock() + require.Contains(t, capturedStages, "relay_start") + require.Contains(t, capturedStages, "write_first_message_ok") + require.Contains(t, capturedStages, "first_exit") + require.Contains(t, capturedStages, "relay_complete") +} + +func TestRelay_TraceEvents_IdleTimeout(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn(nil, false) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + now := time.Now() + callCount := 0 + nowFn := func() time.Time { + callCount++ + if callCount <= 5 { + return now + } + return now.Add(time.Hour) + } + + stages := make([]string, 0, 8) + var stagesMu sync.Mutex + _, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{ + IdleTimeout: 2 * time.Second, + Now: nowFn, + OnTrace: func(event RelayTraceEvent) { + stagesMu.Lock() + stages = append(stages, event.Stage) + stagesMu.Unlock() + }, + }) + require.NotNil(t, relayExit) + require.Equal(t, "idle_timeout", relayExit.Stage) + stagesMu.Lock() + capturedStages := append([]string(nil), stages...) 
+ stagesMu.Unlock() + require.Contains(t, capturedStages, "idle_timeout_triggered") + require.Contains(t, capturedStages, "relay_exit") +} + +// errorOnWriteFrameConn 是一个写入总是失败的 FrameConn 实现,用于测试首包写入失败。 +type errorOnWriteFrameConn struct{} + +func (c *errorOnWriteFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { + <-ctx.Done() + return coderws.MessageText, nil, ctx.Err() +} + +func (c *errorOnWriteFrameConn) WriteFrame(_ context.Context, _ coderws.MessageType, _ []byte) error { + return errors.New("write failed: connection refused") +} + +func (c *errorOnWriteFrameConn) Close() error { + return nil +} diff --git a/backend/internal/service/openai_ws_v2_passthrough_adapter.go b/backend/internal/service/openai_ws_v2_passthrough_adapter.go new file mode 100644 index 00000000..cda2e351 --- /dev/null +++ b/backend/internal/service/openai_ws_v2_passthrough_adapter.go @@ -0,0 +1,372 @@ +package service + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/url" + "strings" + "sync/atomic" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2" + coderws "github.com/coder/websocket" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" +) + +type openAIWSClientFrameConn struct { + conn *coderws.Conn +} + +const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2" + +var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil) + +func (c *openAIWSClientFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { + if c == nil || c.conn == nil { + return coderws.MessageText, nil, errOpenAIWSConnClosed + } + if ctx == nil { + ctx = context.Background() + } + return c.conn.Read(ctx) +} + +func (c *openAIWSClientFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error { + if c == nil || c.conn == nil { + return errOpenAIWSConnClosed + } + if 
ctx == nil { + ctx = context.Background() + } + return c.conn.Write(ctx, msgType, payload) +} + +func (c *openAIWSClientFrameConn) Close() error { + if c == nil || c.conn == nil { + return nil + } + _ = c.conn.Close(coderws.StatusNormalClosure, "") + _ = c.conn.CloseNow() + return nil +} + +func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( + ctx context.Context, + c *gin.Context, + clientConn *coderws.Conn, + account *Account, + token string, + firstClientMessage []byte, + hooks *OpenAIWSIngressHooks, + wsDecision OpenAIWSProtocolDecision, +) error { + if s == nil { + return errors.New("service is nil") + } + if clientConn == nil { + return errors.New("client websocket is nil") + } + if account == nil { + return errors.New("account is nil") + } + if strings.TrimSpace(token) == "" { + return errors.New("token is empty") + } + requestModel := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String()) + requestServiceTier := extractOpenAIServiceTierFromBody(firstClientMessage) + requestPreviousResponseID := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "previous_response_id").String()) + logOpenAIWSV2Passthrough( + "relay_start account_id=%d model=%s previous_response_id=%s first_message_type=%s first_message_bytes=%d", + account.ID, + truncateOpenAIWSLogValue(requestModel, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(requestPreviousResponseID, openAIWSIDValueMaxLen), + openaiwsv2RelayMessageTypeName(coderws.MessageText), + len(firstClientMessage), + ) + + wsURL, err := s.buildOpenAIResponsesWSURL(account) + if err != nil { + return fmt.Errorf("build ws url: %w", err) + } + wsHost := "-" + wsPath := "-" + if parsedURL, parseErr := url.Parse(wsURL); parseErr == nil && parsedURL != nil { + wsHost = normalizeOpenAIWSLogValue(parsedURL.Host) + wsPath = normalizeOpenAIWSLogValue(parsedURL.Path) + } + logOpenAIWSV2Passthrough( + "relay_dial_start account_id=%d ws_host=%s ws_path=%s proxy_enabled=%v", + account.ID, + wsHost, + 
wsPath, + account.ProxyID != nil && account.Proxy != nil, + ) + + isCodexCLI := false + if c != nil { + isCodexCLI = openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator")) + } + if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI { + isCodexCLI = true + } + headers, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, "", "", "") + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + dialer := s.getOpenAIWSPassthroughDialer() + if dialer == nil { + return errors.New("openai ws passthrough dialer is nil") + } + + dialCtx, cancelDial := context.WithTimeout(ctx, s.openAIWSDialTimeout()) + defer cancelDial() + upstreamConn, statusCode, handshakeHeaders, err := dialer.Dial(dialCtx, wsURL, headers, proxyURL) + if err != nil { + logOpenAIWSV2Passthrough( + "relay_dial_failed account_id=%d status_code=%d err=%s", + account.ID, + statusCode, + truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), + ) + return s.mapOpenAIWSPassthroughDialError(err, statusCode, handshakeHeaders) + } + defer func() { + _ = upstreamConn.Close() + }() + logOpenAIWSV2Passthrough( + "relay_dial_ok account_id=%d status_code=%d upstream_request_id=%s", + account.ID, + statusCode, + openAIWSHeaderValueForLog(handshakeHeaders, "x-request-id"), + ) + + upstreamFrameConn, ok := upstreamConn.(openaiwsv2.FrameConn) + if !ok { + return errors.New("openai ws passthrough upstream connection does not support frame relay") + } + + completedTurns := atomic.Int32{} + relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{ + Ctx: ctx, + ClientConn: &openAIWSClientFrameConn{conn: clientConn}, + UpstreamConn: upstreamFrameConn, + FirstClientMessage: firstClientMessage, + Options: openaiwsv2.RelayOptions{ + WriteTimeout: s.openAIWSWriteTimeout(), + IdleTimeout: s.openAIWSPassthroughIdleTimeout(), + FirstMessageType: coderws.MessageText, + OnUsageParseFailure: func(eventType string, 
usageRaw string) { + logOpenAIWSV2Passthrough( + "usage_parse_failed event_type=%s usage_raw=%s", + truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(usageRaw, openAIWSLogValueMaxLen), + ) + }, + OnTurnComplete: func(turn openaiwsv2.RelayTurnResult) { + turnNo := int(completedTurns.Add(1)) + turnResult := &OpenAIForwardResult{ + RequestID: turn.RequestID, + Usage: OpenAIUsage{ + InputTokens: turn.Usage.InputTokens, + OutputTokens: turn.Usage.OutputTokens, + CacheCreationInputTokens: turn.Usage.CacheCreationInputTokens, + CacheReadInputTokens: turn.Usage.CacheReadInputTokens, + }, + Model: turn.RequestModel, + ServiceTier: requestServiceTier, + Stream: true, + OpenAIWSMode: true, + ResponseHeaders: cloneHeader(handshakeHeaders), + Duration: turn.Duration, + FirstTokenMs: turn.FirstTokenMs, + } + logOpenAIWSV2Passthrough( + "relay_turn_completed account_id=%d turn=%d request_id=%s terminal_event=%s duration_ms=%d first_token_ms=%d input_tokens=%d output_tokens=%d cache_read_tokens=%d", + account.ID, + turnNo, + truncateOpenAIWSLogValue(turnResult.RequestID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(turn.TerminalEventType, openAIWSLogValueMaxLen), + turnResult.Duration.Milliseconds(), + openAIWSFirstTokenMsForLog(turnResult.FirstTokenMs), + turnResult.Usage.InputTokens, + turnResult.Usage.OutputTokens, + turnResult.Usage.CacheReadInputTokens, + ) + if hooks != nil && hooks.AfterTurn != nil { + hooks.AfterTurn(turnNo, turnResult, nil) + } + }, + OnTrace: func(event openaiwsv2.RelayTraceEvent) { + logOpenAIWSV2Passthrough( + "relay_trace account_id=%d stage=%s direction=%s msg_type=%s bytes=%d graceful=%v wrote_downstream=%v err=%s", + account.ID, + truncateOpenAIWSLogValue(event.Stage, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(event.Direction, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(event.MessageType, openAIWSLogValueMaxLen), + event.PayloadBytes, + event.Graceful, + event.WroteDownstream, + 
truncateOpenAIWSLogValue(event.Error, openAIWSLogValueMaxLen), + ) + }, + }, + }) + + result := &OpenAIForwardResult{ + RequestID: relayResult.RequestID, + Usage: OpenAIUsage{ + InputTokens: relayResult.Usage.InputTokens, + OutputTokens: relayResult.Usage.OutputTokens, + CacheCreationInputTokens: relayResult.Usage.CacheCreationInputTokens, + CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens, + }, + Model: relayResult.RequestModel, + ServiceTier: requestServiceTier, + Stream: true, + OpenAIWSMode: true, + ResponseHeaders: cloneHeader(handshakeHeaders), + Duration: relayResult.Duration, + FirstTokenMs: relayResult.FirstTokenMs, + } + + turnCount := int(completedTurns.Load()) + if relayExit == nil { + logOpenAIWSV2Passthrough( + "relay_completed account_id=%d request_id=%s terminal_event=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d", + account.ID, + truncateOpenAIWSLogValue(result.RequestID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(relayResult.TerminalEventType, openAIWSLogValueMaxLen), + result.Duration.Milliseconds(), + relayResult.ClientToUpstreamFrames, + relayResult.UpstreamToClientFrames, + relayResult.DroppedDownstreamFrames, + turnCount, + ) + // 正常路径按 terminal 事件逐 turn 已回调;仅在零 turn 场景兜底回调一次。 + if turnCount == 0 && hooks != nil && hooks.AfterTurn != nil { + hooks.AfterTurn(1, result, nil) + } + return nil + } + logOpenAIWSV2Passthrough( + "relay_failed account_id=%d stage=%s wrote_downstream=%v err=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d", + account.ID, + truncateOpenAIWSLogValue(relayExit.Stage, openAIWSLogValueMaxLen), + relayExit.WroteDownstream, + truncateOpenAIWSLogValue(relayErrorText(relayExit.Err), openAIWSLogValueMaxLen), + result.Duration.Milliseconds(), + relayResult.ClientToUpstreamFrames, + relayResult.UpstreamToClientFrames, + relayResult.DroppedDownstreamFrames, + turnCount, + ) + + relayErr := relayExit.Err + if relayExit.Stage == "idle_timeout" { + relayErr = 
NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "client websocket idle timeout", + relayErr, + ) + } + turnErr := wrapOpenAIWSIngressTurnError( + relayExit.Stage, + relayErr, + relayExit.WroteDownstream, + ) + if hooks != nil && hooks.AfterTurn != nil { + hooks.AfterTurn(turnCount+1, nil, turnErr) + } + return turnErr +} + +func (s *OpenAIGatewayService) mapOpenAIWSPassthroughDialError( + err error, + statusCode int, + handshakeHeaders http.Header, +) error { + if err == nil { + return nil + } + wrappedErr := err + var dialErr *openAIWSDialError + if !errors.As(err, &dialErr) { + wrappedErr = &openAIWSDialError{ + StatusCode: statusCode, + ResponseHeaders: cloneHeader(handshakeHeaders), + Err: err, + } + } + + if errors.Is(err, context.Canceled) { + return err + } + if errors.Is(err, context.DeadlineExceeded) { + return NewOpenAIWSClientCloseError( + coderws.StatusTryAgainLater, + "upstream websocket connect timeout", + wrappedErr, + ) + } + if statusCode == http.StatusTooManyRequests { + return NewOpenAIWSClientCloseError( + coderws.StatusTryAgainLater, + "upstream websocket is busy, please retry later", + wrappedErr, + ) + } + if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden { + return NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "upstream websocket authentication failed", + wrappedErr, + ) + } + if statusCode >= http.StatusBadRequest && statusCode < http.StatusInternalServerError { + return NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "upstream websocket handshake rejected", + wrappedErr, + ) + } + return fmt.Errorf("openai ws passthrough dial: %w", wrappedErr) +} + +func openaiwsv2RelayMessageTypeName(msgType coderws.MessageType) string { + switch msgType { + case coderws.MessageText: + return "text" + case coderws.MessageBinary: + return "binary" + default: + return fmt.Sprintf("unknown(%d)", msgType) + } +} + +func relayErrorText(err error) string { + if err == nil { + return 
"" + } + return err.Error() +} + +func openAIWSFirstTokenMsForLog(firstTokenMs *int) int { + if firstTokenMs == nil { + return -1 + } + return *firstTokenMs +} + +func logOpenAIWSV2Passthrough(format string, args ...any) { + logger.LegacyPrintf( + "service.openai_ws_v2", + "[OpenAI WS v2 passthrough] %s "+format, + append([]any{openaiWSV2PassthroughModeFields}, args...)..., + ) +} diff --git a/backend/internal/service/ops_alert_evaluator_service.go b/backend/internal/service/ops_alert_evaluator_service.go index 169a5e32..88883180 100644 --- a/backend/internal/service/ops_alert_evaluator_service.go +++ b/backend/internal/service/ops_alert_evaluator_service.go @@ -506,6 +506,48 @@ func (s *OpsAlertEvaluatorService) computeRuleMetric( return float64(countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool { return acc.HasError && acc.TempUnschedulableUntil == nil })), true + case "group_rate_limit_ratio": + if groupID == nil || *groupID <= 0 { + return 0, false + } + if s == nil || s.opsService == nil { + return 0, false + } + availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID) + if err != nil || availability == nil { + return 0, false + } + if availability.Group == nil || availability.Group.TotalAccounts <= 0 { + return 0, true + } + return (float64(availability.Group.RateLimitCount) / float64(availability.Group.TotalAccounts)) * 100, true + case "account_error_ratio": + if s == nil || s.opsService == nil { + return 0, false + } + availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID) + if err != nil || availability == nil { + return 0, false + } + total := int64(len(availability.Accounts)) + if total <= 0 { + return 0, true + } + errorCount := countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool { + return acc.HasError && acc.TempUnschedulableUntil == nil + }) + return (float64(errorCount) / float64(total)) * 100, true + case "overload_account_count": + 
if s == nil || s.opsService == nil { + return 0, false + } + availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID) + if err != nil || availability == nil { + return 0, false + } + return float64(countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool { + return acc.IsOverloaded + })), true } overview, err := s.opsRepo.GetDashboardOverview(ctx, &OpsDashboardFilter{ diff --git a/backend/internal/service/ops_concurrency.go b/backend/internal/service/ops_concurrency.go index 92b37e73..c03108c4 100644 --- a/backend/internal/service/ops_concurrency.go +++ b/backend/internal/service/ops_concurrency.go @@ -64,8 +64,9 @@ func (s *OpsService) getAccountsLoadMapBestEffort(ctx context.Context, accounts if acc.ID <= 0 { continue } - if prev, ok := unique[acc.ID]; !ok || acc.Concurrency > prev { - unique[acc.ID] = acc.Concurrency + lf := acc.EffectiveLoadFactor() + if prev, ok := unique[acc.ID]; !ok || lf > prev { + unique[acc.ID] = lf } } diff --git a/backend/internal/service/ops_dashboard.go b/backend/internal/service/ops_dashboard.go index 31822ba8..6f70c75c 100644 --- a/backend/internal/service/ops_dashboard.go +++ b/backend/internal/service/ops_dashboard.go @@ -31,6 +31,10 @@ func (s *OpsService) GetDashboardOverview(ctx context.Context, filter *OpsDashbo filter.QueryMode = s.resolveOpsQueryMode(ctx, filter.QueryMode) overview, err := s.opsRepo.GetDashboardOverview(ctx, filter) + if err != nil && shouldFallbackOpsPreagg(filter, err) { + rawFilter := cloneOpsFilterWithMode(filter, OpsQueryModeRaw) + overview, err = s.opsRepo.GetDashboardOverview(ctx, rawFilter) + } if err != nil { if errors.Is(err, ErrOpsPreaggregatedNotPopulated) { return nil, infraerrors.Conflict("OPS_PREAGG_NOT_READY", "Pre-aggregated ops metrics are not populated yet") diff --git a/backend/internal/service/ops_errors.go b/backend/internal/service/ops_errors.go index 76b5ce8b..01671c1e 100644 --- a/backend/internal/service/ops_errors.go +++ 
b/backend/internal/service/ops_errors.go @@ -22,7 +22,14 @@ func (s *OpsService) GetErrorTrend(ctx context.Context, filter *OpsDashboardFilt if filter.StartTime.After(filter.EndTime) { return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time") } - return s.opsRepo.GetErrorTrend(ctx, filter, bucketSeconds) + filter.QueryMode = s.resolveOpsQueryMode(ctx, filter.QueryMode) + + result, err := s.opsRepo.GetErrorTrend(ctx, filter, bucketSeconds) + if err != nil && shouldFallbackOpsPreagg(filter, err) { + rawFilter := cloneOpsFilterWithMode(filter, OpsQueryModeRaw) + return s.opsRepo.GetErrorTrend(ctx, rawFilter, bucketSeconds) + } + return result, err } func (s *OpsService) GetErrorDistribution(ctx context.Context, filter *OpsDashboardFilter) (*OpsErrorDistributionResponse, error) { @@ -41,5 +48,12 @@ func (s *OpsService) GetErrorDistribution(ctx context.Context, filter *OpsDashbo if filter.StartTime.After(filter.EndTime) { return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time") } - return s.opsRepo.GetErrorDistribution(ctx, filter) + filter.QueryMode = s.resolveOpsQueryMode(ctx, filter.QueryMode) + + result, err := s.opsRepo.GetErrorDistribution(ctx, filter) + if err != nil && shouldFallbackOpsPreagg(filter, err) { + rawFilter := cloneOpsFilterWithMode(filter, OpsQueryModeRaw) + return s.opsRepo.GetErrorDistribution(ctx, rawFilter) + } + return result, err } diff --git a/backend/internal/service/ops_histograms.go b/backend/internal/service/ops_histograms.go index 9f5b514f..c555dbfc 100644 --- a/backend/internal/service/ops_histograms.go +++ b/backend/internal/service/ops_histograms.go @@ -22,5 +22,12 @@ func (s *OpsService) GetLatencyHistogram(ctx context.Context, filter *OpsDashboa if filter.StartTime.After(filter.EndTime) { return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time") } - return s.opsRepo.GetLatencyHistogram(ctx, filter) + filter.QueryMode = 
s.resolveOpsQueryMode(ctx, filter.QueryMode) + + result, err := s.opsRepo.GetLatencyHistogram(ctx, filter) + if err != nil && shouldFallbackOpsPreagg(filter, err) { + rawFilter := cloneOpsFilterWithMode(filter, OpsQueryModeRaw) + return s.opsRepo.GetLatencyHistogram(ctx, rawFilter) + } + return result, err } diff --git a/backend/internal/service/ops_metrics_collector.go b/backend/internal/service/ops_metrics_collector.go index 30adaae0..6c337071 100644 --- a/backend/internal/service/ops_metrics_collector.go +++ b/backend/internal/service/ops_metrics_collector.go @@ -389,13 +389,9 @@ func (c *OpsMetricsCollector) collectConcurrencyQueueDepth(parentCtx context.Con if acc.ID <= 0 { continue } - maxConc := acc.Concurrency - if maxConc < 0 { - maxConc = 0 - } batch = append(batch, AccountWithConcurrency{ ID: acc.ID, - MaxConcurrency: maxConc, + MaxConcurrency: acc.EffectiveLoadFactor(), }) } if len(batch) == 0 { diff --git a/backend/internal/service/ops_port.go b/backend/internal/service/ops_port.go index f3633eae..0ce9d425 100644 --- a/backend/internal/service/ops_port.go +++ b/backend/internal/service/ops_port.go @@ -7,6 +7,7 @@ import ( type OpsRepository interface { InsertErrorLog(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error) + BatchInsertErrorLogs(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error) ListErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error) GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLogDetail, error) ListRequestDetails(ctx context.Context, filter *OpsRequestDetailFilter) ([]*OpsRequestDetail, int64, error) diff --git a/backend/internal/service/ops_query_mode.go b/backend/internal/service/ops_query_mode.go index e6fa9c1e..fa97f358 100644 --- a/backend/internal/service/ops_query_mode.go +++ b/backend/internal/service/ops_query_mode.go @@ -38,3 +38,18 @@ func (m OpsQueryMode) IsValid() bool { return false } } + +func shouldFallbackOpsPreagg(filter *OpsDashboardFilter, err 
error) bool { + return filter != nil && + filter.QueryMode == OpsQueryModeAuto && + errors.Is(err, ErrOpsPreaggregatedNotPopulated) +} + +func cloneOpsFilterWithMode(filter *OpsDashboardFilter, mode OpsQueryMode) *OpsDashboardFilter { + if filter == nil { + return nil + } + cloned := *filter + cloned.QueryMode = mode + return &cloned +} diff --git a/backend/internal/service/ops_query_mode_test.go b/backend/internal/service/ops_query_mode_test.go new file mode 100644 index 00000000..26c4b730 --- /dev/null +++ b/backend/internal/service/ops_query_mode_test.go @@ -0,0 +1,66 @@ +//go:build unit + +package service + +import ( + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestShouldFallbackOpsPreagg(t *testing.T) { + preaggErr := ErrOpsPreaggregatedNotPopulated + otherErr := errors.New("some other error") + + autoFilter := &OpsDashboardFilter{QueryMode: OpsQueryModeAuto} + rawFilter := &OpsDashboardFilter{QueryMode: OpsQueryModeRaw} + preaggFilter := &OpsDashboardFilter{QueryMode: OpsQueryModePreagg} + + tests := []struct { + name string + filter *OpsDashboardFilter + err error + want bool + }{ + {"auto mode + preagg error => fallback", autoFilter, preaggErr, true}, + {"auto mode + other error => no fallback", autoFilter, otherErr, false}, + {"auto mode + nil error => no fallback", autoFilter, nil, false}, + {"raw mode + preagg error => no fallback", rawFilter, preaggErr, false}, + {"preagg mode + preagg error => no fallback", preaggFilter, preaggErr, false}, + {"nil filter => no fallback", nil, preaggErr, false}, + {"wrapped preagg error => fallback", autoFilter, errors.Join(preaggErr, otherErr), true}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := shouldFallbackOpsPreagg(tc.filter, tc.err) + require.Equal(t, tc.want, got) + }) + } +} + +func TestCloneOpsFilterWithMode(t *testing.T) { + t.Run("nil filter returns nil", func(t *testing.T) { + require.Nil(t, cloneOpsFilterWithMode(nil, 
OpsQueryModeRaw)) + }) + + t.Run("cloned filter has new mode", func(t *testing.T) { + groupID := int64(42) + original := &OpsDashboardFilter{ + StartTime: time.Now(), + EndTime: time.Now().Add(time.Hour), + Platform: "anthropic", + GroupID: &groupID, + QueryMode: OpsQueryModeAuto, + } + + cloned := cloneOpsFilterWithMode(original, OpsQueryModeRaw) + require.Equal(t, OpsQueryModeRaw, cloned.QueryMode) + require.Equal(t, OpsQueryModeAuto, original.QueryMode, "original should not be modified") + require.Equal(t, original.Platform, cloned.Platform) + require.Equal(t, original.StartTime, cloned.StartTime) + require.Equal(t, original.GroupID, cloned.GroupID) + }) +} diff --git a/backend/internal/service/ops_repo_mock_test.go b/backend/internal/service/ops_repo_mock_test.go index e250dea3..c8c66ec6 100644 --- a/backend/internal/service/ops_repo_mock_test.go +++ b/backend/internal/service/ops_repo_mock_test.go @@ -7,6 +7,8 @@ import ( // opsRepoMock is a test-only OpsRepository implementation with optional function hooks. 
type opsRepoMock struct { + InsertErrorLogFn func(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error) + BatchInsertErrorLogsFn func(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error) BatchInsertSystemLogsFn func(ctx context.Context, inputs []*OpsInsertSystemLogInput) (int64, error) ListSystemLogsFn func(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error) DeleteSystemLogsFn func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error) @@ -14,9 +16,19 @@ type opsRepoMock struct { } func (m *opsRepoMock) InsertErrorLog(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error) { + if m.InsertErrorLogFn != nil { + return m.InsertErrorLogFn(ctx, input) + } return 0, nil } +func (m *opsRepoMock) BatchInsertErrorLogs(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error) { + if m.BatchInsertErrorLogsFn != nil { + return m.BatchInsertErrorLogsFn(ctx, inputs) + } + return int64(len(inputs)), nil +} + func (m *opsRepoMock) ListErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error) { return &OpsErrorLogList{Errors: []*OpsErrorLog{}, Page: 1, PageSize: 20}, nil } diff --git a/backend/internal/service/ops_service.go b/backend/internal/service/ops_service.go index 767d1704..29f0aa8b 100644 --- a/backend/internal/service/ops_service.go +++ b/backend/internal/service/ops_service.go @@ -121,14 +121,74 @@ func (s *OpsService) IsMonitoringEnabled(ctx context.Context) bool { } func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogInput, rawRequestBody []byte) error { - if entry == nil { + prepared, ok, err := s.prepareErrorLogInput(ctx, entry, rawRequestBody) + if err != nil { + log.Printf("[Ops] RecordError prepare failed: %v", err) + return err + } + if !ok { return nil } + + if _, err := s.opsRepo.InsertErrorLog(ctx, prepared); err != nil { + // Never bubble up to gateway; best-effort logging. 
+ log.Printf("[Ops] RecordError failed: %v", err) + return err + } + return nil +} + +func (s *OpsService) RecordErrorBatch(ctx context.Context, entries []*OpsInsertErrorLogInput) error { + if len(entries) == 0 { + return nil + } + prepared := make([]*OpsInsertErrorLogInput, 0, len(entries)) + for _, entry := range entries { + item, ok, err := s.prepareErrorLogInput(ctx, entry, nil) + if err != nil { + log.Printf("[Ops] RecordErrorBatch prepare failed: %v", err) + continue + } + if ok { + prepared = append(prepared, item) + } + } + if len(prepared) == 0 { + return nil + } + if len(prepared) == 1 { + _, err := s.opsRepo.InsertErrorLog(ctx, prepared[0]) + if err != nil { + log.Printf("[Ops] RecordErrorBatch single insert failed: %v", err) + } + return err + } + + if _, err := s.opsRepo.BatchInsertErrorLogs(ctx, prepared); err != nil { + log.Printf("[Ops] RecordErrorBatch failed, fallback to single inserts: %v", err) + var firstErr error + for _, entry := range prepared { + if _, insertErr := s.opsRepo.InsertErrorLog(ctx, entry); insertErr != nil { + log.Printf("[Ops] RecordErrorBatch fallback insert failed: %v", insertErr) + if firstErr == nil { + firstErr = insertErr + } + } + } + return firstErr + } + return nil +} + +func (s *OpsService) prepareErrorLogInput(ctx context.Context, entry *OpsInsertErrorLogInput, rawRequestBody []byte) (*OpsInsertErrorLogInput, bool, error) { + if entry == nil { + return nil, false, nil + } if !s.IsMonitoringEnabled(ctx) { - return nil + return nil, false, nil } if s.opsRepo == nil { - return nil + return nil, false, nil } // Ensure timestamps are always populated. @@ -185,85 +245,88 @@ func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogIn } } - // Sanitize + serialize upstream error events list. 
- if len(entry.UpstreamErrors) > 0 { - const maxEvents = 32 - events := entry.UpstreamErrors - if len(events) > maxEvents { - events = events[len(events)-maxEvents:] + if err := sanitizeOpsUpstreamErrors(entry); err != nil { + return nil, false, err + } + + return entry, true, nil +} + +func sanitizeOpsUpstreamErrors(entry *OpsInsertErrorLogInput) error { + if entry == nil || len(entry.UpstreamErrors) == 0 { + return nil + } + + const maxEvents = 32 + events := entry.UpstreamErrors + if len(events) > maxEvents { + events = events[len(events)-maxEvents:] + } + + sanitized := make([]*OpsUpstreamErrorEvent, 0, len(events)) + for _, ev := range events { + if ev == nil { + continue + } + out := *ev + + out.Platform = strings.TrimSpace(out.Platform) + out.UpstreamRequestID = truncateString(strings.TrimSpace(out.UpstreamRequestID), 128) + out.Kind = truncateString(strings.TrimSpace(out.Kind), 64) + + if out.AccountID < 0 { + out.AccountID = 0 + } + if out.UpstreamStatusCode < 0 { + out.UpstreamStatusCode = 0 + } + if out.AtUnixMs < 0 { + out.AtUnixMs = 0 } - sanitized := make([]*OpsUpstreamErrorEvent, 0, len(events)) - for _, ev := range events { - if ev == nil { - continue - } - out := *ev + msg := sanitizeUpstreamErrorMessage(strings.TrimSpace(out.Message)) + msg = truncateString(msg, 2048) + out.Message = msg - out.Platform = strings.TrimSpace(out.Platform) - out.UpstreamRequestID = truncateString(strings.TrimSpace(out.UpstreamRequestID), 128) - out.Kind = truncateString(strings.TrimSpace(out.Kind), 64) + detail := strings.TrimSpace(out.Detail) + if detail != "" { + // Keep upstream detail small; request bodies are not stored here, only upstream error payloads. 
+ sanitizedDetail, _ := sanitizeErrorBodyForStorage(detail, opsMaxStoredErrorBodyBytes) + out.Detail = sanitizedDetail + } else { + out.Detail = "" + } - if out.AccountID < 0 { - out.AccountID = 0 - } - if out.UpstreamStatusCode < 0 { - out.UpstreamStatusCode = 0 - } - if out.AtUnixMs < 0 { - out.AtUnixMs = 0 - } - - msg := sanitizeUpstreamErrorMessage(strings.TrimSpace(out.Message)) - msg = truncateString(msg, 2048) - out.Message = msg - - detail := strings.TrimSpace(out.Detail) - if detail != "" { - // Keep upstream detail small; request bodies are not stored here, only upstream error payloads. - sanitizedDetail, _ := sanitizeErrorBodyForStorage(detail, opsMaxStoredErrorBodyBytes) - out.Detail = sanitizedDetail - } else { - out.Detail = "" - } - - out.UpstreamRequestBody = strings.TrimSpace(out.UpstreamRequestBody) - if out.UpstreamRequestBody != "" { - // Reuse the same sanitization/trimming strategy as request body storage. - // Keep it small so it is safe to persist in ops_error_logs JSON. - sanitized, truncated, _ := sanitizeAndTrimRequestBody([]byte(out.UpstreamRequestBody), 10*1024) - if sanitized != "" { - out.UpstreamRequestBody = sanitized - if truncated { - out.Kind = strings.TrimSpace(out.Kind) - if out.Kind == "" { - out.Kind = "upstream" - } - out.Kind = out.Kind + ":request_body_truncated" + out.UpstreamRequestBody = strings.TrimSpace(out.UpstreamRequestBody) + if out.UpstreamRequestBody != "" { + // Reuse the same sanitization/trimming strategy as request body storage. + // Keep it small so it is safe to persist in ops_error_logs JSON. 
+ sanitizedBody, truncated, _ := sanitizeAndTrimRequestBody([]byte(out.UpstreamRequestBody), 10*1024) + if sanitizedBody != "" { + out.UpstreamRequestBody = sanitizedBody + if truncated { + out.Kind = strings.TrimSpace(out.Kind) + if out.Kind == "" { + out.Kind = "upstream" } - } else { - out.UpstreamRequestBody = "" + out.Kind = out.Kind + ":request_body_truncated" } + } else { + out.UpstreamRequestBody = "" } - - // Drop fully-empty events (can happen if only status code was known). - if out.UpstreamStatusCode == 0 && out.Message == "" && out.Detail == "" { - continue - } - - evCopy := out - sanitized = append(sanitized, &evCopy) } - entry.UpstreamErrorsJSON = marshalOpsUpstreamErrors(sanitized) - entry.UpstreamErrors = nil + // Drop fully-empty events (can happen if only status code was known). + if out.UpstreamStatusCode == 0 && out.Message == "" && out.Detail == "" { + continue + } + + evCopy := out + sanitized = append(sanitized, &evCopy) } - if _, err := s.opsRepo.InsertErrorLog(ctx, entry); err != nil { - // Never bubble up to gateway; best-effort logging. - log.Printf("[Ops] RecordError failed: %v", err) - return err - } + entry.UpstreamErrorsJSON = marshalOpsUpstreamErrors(sanitized) + entry.UpstreamErrors = nil return nil } diff --git a/backend/internal/service/ops_service_batch_test.go b/backend/internal/service/ops_service_batch_test.go new file mode 100644 index 00000000..f3a14d7f --- /dev/null +++ b/backend/internal/service/ops_service_batch_test.go @@ -0,0 +1,103 @@ +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestOpsServiceRecordErrorBatch_SanitizesAndBatches(t *testing.T) { + t.Parallel() + + var captured []*OpsInsertErrorLogInput + repo := &opsRepoMock{ + BatchInsertErrorLogsFn: func(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error) { + captured = append(captured, inputs...) 
+ return int64(len(inputs)), nil + }, + } + svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + + msg := " upstream failed: https://example.com?access_token=secret-value " + detail := `{"authorization":"Bearer secret-token"}` + entries := []*OpsInsertErrorLogInput{ + { + ErrorBody: `{"error":"bad","access_token":"secret"}`, + UpstreamStatusCode: intPtr(-10), + UpstreamErrorMessage: strPtr(msg), + UpstreamErrorDetail: strPtr(detail), + UpstreamErrors: []*OpsUpstreamErrorEvent{ + { + AccountID: -2, + UpstreamStatusCode: 429, + Message: " token leaked ", + Detail: `{"refresh_token":"secret"}`, + UpstreamRequestBody: `{"api_key":"secret","messages":[{"role":"user","content":"hello"}]}`, + }, + }, + }, + { + ErrorPhase: "upstream", + ErrorType: "upstream_error", + CreatedAt: time.Now().UTC(), + }, + } + + require.NoError(t, svc.RecordErrorBatch(context.Background(), entries)) + require.Len(t, captured, 2) + + first := captured[0] + require.Equal(t, "internal", first.ErrorPhase) + require.Equal(t, "api_error", first.ErrorType) + require.Nil(t, first.UpstreamStatusCode) + require.NotNil(t, first.UpstreamErrorMessage) + require.NotContains(t, *first.UpstreamErrorMessage, "secret-value") + require.Contains(t, *first.UpstreamErrorMessage, "access_token=***") + require.NotNil(t, first.UpstreamErrorDetail) + require.NotContains(t, *first.UpstreamErrorDetail, "secret-token") + require.NotContains(t, first.ErrorBody, "secret") + require.Nil(t, first.UpstreamErrors) + require.NotNil(t, first.UpstreamErrorsJSON) + require.NotContains(t, *first.UpstreamErrorsJSON, "secret") + require.Contains(t, *first.UpstreamErrorsJSON, "[REDACTED]") + + second := captured[1] + require.Equal(t, "upstream", second.ErrorPhase) + require.Equal(t, "upstream_error", second.ErrorType) + require.False(t, second.CreatedAt.IsZero()) +} + +func TestOpsServiceRecordErrorBatch_FallsBackToSingleInsert(t *testing.T) { + t.Parallel() + + var ( + batchCalls int + singleCalls int + ) + 
repo := &opsRepoMock{ + BatchInsertErrorLogsFn: func(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error) { + batchCalls++ + return 0, errors.New("batch failed") + }, + InsertErrorLogFn: func(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error) { + singleCalls++ + return int64(singleCalls), nil + }, + } + svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + + err := svc.RecordErrorBatch(context.Background(), []*OpsInsertErrorLogInput{ + {ErrorMessage: "first"}, + {ErrorMessage: "second"}, + }) + require.NoError(t, err) + require.Equal(t, 1, batchCalls) + require.Equal(t, 2, singleCalls) +} + +func strPtr(v string) *string { + return &v +} diff --git a/backend/internal/service/ops_trends.go b/backend/internal/service/ops_trends.go index ec55c6ce..22db72ef 100644 --- a/backend/internal/service/ops_trends.go +++ b/backend/internal/service/ops_trends.go @@ -22,5 +22,13 @@ func (s *OpsService) GetThroughputTrend(ctx context.Context, filter *OpsDashboar if filter.StartTime.After(filter.EndTime) { return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time") } - return s.opsRepo.GetThroughputTrend(ctx, filter, bucketSeconds) + + filter.QueryMode = s.resolveOpsQueryMode(ctx, filter.QueryMode) + + result, err := s.opsRepo.GetThroughputTrend(ctx, filter, bucketSeconds) + if err != nil && shouldFallbackOpsPreagg(filter, err) { + rawFilter := cloneOpsFilterWithMode(filter, OpsQueryModeRaw) + return s.opsRepo.GetThroughputTrend(ctx, rawFilter, bucketSeconds) + } + return result, err } diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go index 41e8b5eb..7ed4e7e4 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -21,18 +21,36 @@ import ( ) var ( - openAIModelDatePattern = regexp.MustCompile(`-\d{8}$`) - openAIModelBasePattern = regexp.MustCompile(`^(gpt-\d+(?:\.\d+)?)(?:-|$)`) + 
openAIModelDatePattern = regexp.MustCompile(`-\d{8}$`) + openAIModelBasePattern = regexp.MustCompile(`^(gpt-\d+(?:\.\d+)?)(?:-|$)`) + openAIGPT54FallbackPricing = &LiteLLMModelPricing{ + InputCostPerToken: 2.5e-06, // $2.5 per MTok + OutputCostPerToken: 1.5e-05, // $15 per MTok + CacheReadInputTokenCost: 2.5e-07, // $0.25 per MTok + LongContextInputTokenThreshold: 272000, + LongContextInputCostMultiplier: 2.0, + LongContextOutputCostMultiplier: 1.5, + LiteLLMProvider: "openai", + Mode: "chat", + SupportsPromptCaching: true, + } ) // LiteLLMModelPricing LiteLLM价格数据结构 // 只保留我们需要的字段,使用指针来处理可能缺失的值 type LiteLLMModelPricing struct { InputCostPerToken float64 `json:"input_cost_per_token"` + InputCostPerTokenPriority float64 `json:"input_cost_per_token_priority"` OutputCostPerToken float64 `json:"output_cost_per_token"` + OutputCostPerTokenPriority float64 `json:"output_cost_per_token_priority"` CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"` CacheCreationInputTokenCostAbove1hr float64 `json:"cache_creation_input_token_cost_above_1hr"` CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"` + CacheReadInputTokenCostPriority float64 `json:"cache_read_input_token_cost_priority"` + LongContextInputTokenThreshold int `json:"long_context_input_token_threshold,omitempty"` + LongContextInputCostMultiplier float64 `json:"long_context_input_cost_multiplier,omitempty"` + LongContextOutputCostMultiplier float64 `json:"long_context_output_cost_multiplier,omitempty"` + SupportsServiceTier bool `json:"supports_service_tier"` LiteLLMProvider string `json:"litellm_provider"` Mode string `json:"mode"` SupportsPromptCaching bool `json:"supports_prompt_caching"` @@ -48,10 +66,14 @@ type PricingRemoteClient interface { // LiteLLMRawEntry 用于解析原始JSON数据 type LiteLLMRawEntry struct { InputCostPerToken *float64 `json:"input_cost_per_token"` + InputCostPerTokenPriority *float64 `json:"input_cost_per_token_priority"` OutputCostPerToken *float64 
`json:"output_cost_per_token"` + OutputCostPerTokenPriority *float64 `json:"output_cost_per_token_priority"` CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"` CacheCreationInputTokenCostAbove1hr *float64 `json:"cache_creation_input_token_cost_above_1hr"` CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"` + CacheReadInputTokenCostPriority *float64 `json:"cache_read_input_token_cost_priority"` + SupportsServiceTier bool `json:"supports_service_tier"` LiteLLMProvider string `json:"litellm_provider"` Mode string `json:"mode"` SupportsPromptCaching bool `json:"supports_prompt_caching"` @@ -310,14 +332,21 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel LiteLLMProvider: entry.LiteLLMProvider, Mode: entry.Mode, SupportsPromptCaching: entry.SupportsPromptCaching, + SupportsServiceTier: entry.SupportsServiceTier, } if entry.InputCostPerToken != nil { pricing.InputCostPerToken = *entry.InputCostPerToken } + if entry.InputCostPerTokenPriority != nil { + pricing.InputCostPerTokenPriority = *entry.InputCostPerTokenPriority + } if entry.OutputCostPerToken != nil { pricing.OutputCostPerToken = *entry.OutputCostPerToken } + if entry.OutputCostPerTokenPriority != nil { + pricing.OutputCostPerTokenPriority = *entry.OutputCostPerTokenPriority + } if entry.CacheCreationInputTokenCost != nil { pricing.CacheCreationInputTokenCost = *entry.CacheCreationInputTokenCost } @@ -327,6 +356,9 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel if entry.CacheReadInputTokenCost != nil { pricing.CacheReadInputTokenCost = *entry.CacheReadInputTokenCost } + if entry.CacheReadInputTokenCostPriority != nil { + pricing.CacheReadInputTokenCostPriority = *entry.CacheReadInputTokenCostPriority + } if entry.OutputCostPerImage != nil { pricing.OutputCostPerImage = *entry.OutputCostPerImage } @@ -660,7 +692,8 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing { // 
2. gpt-5.2-codex -> gpt-5.2(去掉后缀如 -codex, -mini, -max 等) // 3. gpt-5.2-20251222 -> gpt-5.2(去掉日期版本号) // 4. gpt-5.3-codex -> gpt-5.2-codex -// 5. 最终回退到 DefaultTestModel (gpt-5.1-codex) +// 5. gpt-5.4* -> 业务静态兜底价 +// 6. 最终回退到 DefaultTestModel (gpt-5.1-codex) func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing { if strings.HasPrefix(model, "gpt-5.3-codex-spark") { if pricing, ok := s.pricingData["gpt-5.1-codex"]; ok { @@ -690,6 +723,12 @@ func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing { } } + if strings.HasPrefix(model, "gpt-5.4") { + logger.With(zap.String("component", "service.pricing")). + Info(fmt.Sprintf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.4(static)")) + return openAIGPT54FallbackPricing + } + // 最终回退到 DefaultTestModel defaultModel := strings.ToLower(openai.DefaultTestModel) if pricing, ok := s.pricingData[defaultModel]; ok { diff --git a/backend/internal/service/pricing_service_test.go b/backend/internal/service/pricing_service_test.go index 127ff342..775024fd 100644 --- a/backend/internal/service/pricing_service_test.go +++ b/backend/internal/service/pricing_service_test.go @@ -1,11 +1,40 @@ package service import ( + "encoding/json" "testing" "github.com/stretchr/testify/require" ) +func TestParsePricingData_ParsesPriorityAndServiceTierFields(t *testing.T) { + svc := &PricingService{} + body := []byte(`{ + "gpt-5.4": { + "input_cost_per_token": 0.0000025, + "input_cost_per_token_priority": 0.000005, + "output_cost_per_token": 0.000015, + "output_cost_per_token_priority": 0.00003, + "cache_creation_input_token_cost": 0.0000025, + "cache_read_input_token_cost": 0.00000025, + "cache_read_input_token_cost_priority": 0.0000005, + "supports_service_tier": true, + "supports_prompt_caching": true, + "litellm_provider": "openai", + "mode": "chat" + } + }`) + + data, err := svc.parsePricingData(body) + require.NoError(t, err) + pricing := data["gpt-5.4"] + require.NotNil(t, pricing) + 
require.InDelta(t, 5e-6, pricing.InputCostPerTokenPriority, 1e-12) + require.InDelta(t, 3e-5, pricing.OutputCostPerTokenPriority, 1e-12) + require.InDelta(t, 5e-7, pricing.CacheReadInputTokenCostPriority, 1e-12) + require.True(t, pricing.SupportsServiceTier) +} + func TestGetModelPricing_Gpt53CodexSparkUsesGpt51CodexPricing(t *testing.T) { sparkPricing := &LiteLLMModelPricing{InputCostPerToken: 1} gpt53Pricing := &LiteLLMModelPricing{InputCostPerToken: 9} @@ -51,3 +80,81 @@ func TestGetModelPricing_OpenAIFallbackMatchedLoggedAsInfo(t *testing.T) { require.True(t, logSink.ContainsMessageAtLevel("[Pricing] OpenAI fallback matched gpt-5.3-codex -> gpt-5.2-codex", "info")) require.False(t, logSink.ContainsMessageAtLevel("[Pricing] OpenAI fallback matched gpt-5.3-codex -> gpt-5.2-codex", "warn")) } + +func TestGetModelPricing_Gpt54UsesStaticFallbackWhenRemoteMissing(t *testing.T) { + svc := &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "gpt-5.1-codex": &LiteLLMModelPricing{InputCostPerToken: 1.25e-6}, + }, + } + + got := svc.GetModelPricing("gpt-5.4") + require.NotNil(t, got) + require.InDelta(t, 2.5e-6, got.InputCostPerToken, 1e-12) + require.InDelta(t, 1.5e-5, got.OutputCostPerToken, 1e-12) + require.InDelta(t, 2.5e-7, got.CacheReadInputTokenCost, 1e-12) + require.Equal(t, 272000, got.LongContextInputTokenThreshold) + require.InDelta(t, 2.0, got.LongContextInputCostMultiplier, 1e-12) + require.InDelta(t, 1.5, got.LongContextOutputCostMultiplier, 1e-12) +} + +func TestParsePricingData_PreservesPriorityAndServiceTierFields(t *testing.T) { + raw := map[string]any{ + "gpt-5.4": map[string]any{ + "input_cost_per_token": 2.5e-6, + "input_cost_per_token_priority": 5e-6, + "output_cost_per_token": 15e-6, + "output_cost_per_token_priority": 30e-6, + "cache_read_input_token_cost": 0.25e-6, + "cache_read_input_token_cost_priority": 0.5e-6, + "supports_service_tier": true, + "supports_prompt_caching": true, + "litellm_provider": "openai", + "mode": "chat", + 
}, + } + body, err := json.Marshal(raw) + require.NoError(t, err) + + svc := &PricingService{} + pricingMap, err := svc.parsePricingData(body) + require.NoError(t, err) + + pricing := pricingMap["gpt-5.4"] + require.NotNil(t, pricing) + require.InDelta(t, 2.5e-6, pricing.InputCostPerToken, 1e-12) + require.InDelta(t, 5e-6, pricing.InputCostPerTokenPriority, 1e-12) + require.InDelta(t, 15e-6, pricing.OutputCostPerToken, 1e-12) + require.InDelta(t, 30e-6, pricing.OutputCostPerTokenPriority, 1e-12) + require.InDelta(t, 0.25e-6, pricing.CacheReadInputTokenCost, 1e-12) + require.InDelta(t, 0.5e-6, pricing.CacheReadInputTokenCostPriority, 1e-12) + require.True(t, pricing.SupportsServiceTier) +} + +func TestParsePricingData_PreservesServiceTierPriorityFields(t *testing.T) { + svc := &PricingService{} + pricingData, err := svc.parsePricingData([]byte(`{ + "gpt-5.4": { + "input_cost_per_token": 0.0000025, + "input_cost_per_token_priority": 0.000005, + "output_cost_per_token": 0.000015, + "output_cost_per_token_priority": 0.00003, + "cache_read_input_token_cost": 0.00000025, + "cache_read_input_token_cost_priority": 0.0000005, + "supports_service_tier": true, + "litellm_provider": "openai", + "mode": "chat" + } + }`)) + require.NoError(t, err) + + pricing := pricingData["gpt-5.4"] + require.NotNil(t, pricing) + require.InDelta(t, 0.0000025, pricing.InputCostPerToken, 1e-12) + require.InDelta(t, 0.000005, pricing.InputCostPerTokenPriority, 1e-12) + require.InDelta(t, 0.000015, pricing.OutputCostPerToken, 1e-12) + require.InDelta(t, 0.00003, pricing.OutputCostPerTokenPriority, 1e-12) + require.InDelta(t, 0.00000025, pricing.CacheReadInputTokenCost, 1e-12) + require.InDelta(t, 0.0000005, pricing.CacheReadInputTokenCostPriority, 1e-12) + require.True(t, pricing.SupportsServiceTier) +} diff --git a/backend/internal/service/prompts/codex_cli_instructions.md b/backend/internal/service/prompts/codex_cli_instructions.md deleted file mode 100644 index 4886c7ef..00000000 --- 
a/backend/internal/service/prompts/codex_cli_instructions.md +++ /dev/null @@ -1,275 +0,0 @@ -You are a coding agent running in the Codex CLI, a terminal-based coding assistant. Codex CLI is an open source project led by OpenAI. You are expected to be precise, safe, and helpful. - -Your capabilities: - -- Receive user prompts and other context provided by the harness, such as files in the workspace. -- Communicate with the user by streaming thinking & responses, and by making & updating plans. -- Emit function calls to run terminal commands and apply patches. Depending on how this specific run is configured, you can request that these function calls be escalated to the user for approval before running. More on this in the "Sandbox and approvals" section. - -Within this context, Codex refers to the open-source agentic coding interface (not the old Codex language model built by OpenAI). - -# How you work - -## Personality - -Your default personality and tone is concise, direct, and friendly. You communicate efficiently, always keeping the user clearly informed about ongoing actions without unnecessary detail. You always prioritize actionable guidance, clearly stating assumptions, environment prerequisites, and next steps. Unless explicitly asked, you avoid excessively verbose explanations about your work. - -# AGENTS.md spec -- Repos often contain AGENTS.md files. These files can appear anywhere within the repository. -- These files are a way for humans to give you (the agent) instructions or tips for working within the container. -- Some examples might be: coding conventions, info about how code is organized, or instructions for how to run or test code. -- Instructions in AGENTS.md files: - - The scope of an AGENTS.md file is the entire directory tree rooted at the folder that contains it. - - For every file you touch in the final patch, you must obey instructions in any AGENTS.md file whose scope includes that file. 
- - Instructions about code style, structure, naming, etc. apply only to code within the AGENTS.md file's scope, unless the file states otherwise. - - More-deeply-nested AGENTS.md files take precedence in the case of conflicting instructions. - - Direct system/developer/user instructions (as part of a prompt) take precedence over AGENTS.md instructions. -- The contents of the AGENTS.md file at the root of the repo and any directories from the CWD up to the root are included with the developer message and don't need to be re-read. When working in a subdirectory of CWD, or a directory outside the CWD, check for any AGENTS.md files that may be applicable. - -## Responsiveness - -### Preamble messages - -Before making tool calls, send a brief preamble to the user explaining what you’re about to do. When sending preamble messages, follow these principles and examples: - -- **Logically group related actions**: if you’re about to run several related commands, describe them together in one preamble rather than sending a separate note for each. -- **Keep it concise**: be no more than 1-2 sentences, focused on immediate, tangible next steps. (8–12 words for quick updates). -- **Build on prior context**: if this is not your first tool call, use the preamble message to connect the dots with what’s been done so far and create a sense of momentum and clarity for the user to understand your next actions. -- **Keep your tone light, friendly and curious**: add small touches of personality in preambles feel collaborative and engaging. -- **Exception**: Avoid adding a preamble for every trivial read (e.g., `cat` a single file) unless it’s part of a larger grouped action. - -**Examples:** - -- “I’ve explored the repo; now checking the API route definitions.” -- “Next, I’ll patch the config and update the related tests.” -- “I’m about to scaffold the CLI commands and helper functions.” -- “Ok cool, so I’ve wrapped my head around the repo. 
Now digging into the API routes.” -- “Config’s looking tidy. Next up is patching helpers to keep things in sync.” -- “Finished poking at the DB gateway. I will now chase down error handling.” -- “Alright, build pipeline order is interesting. Checking how it reports failures.” -- “Spotted a clever caching util; now hunting where it gets used.” - -## Planning - -You have access to an `update_plan` tool which tracks steps and progress and renders them to the user. Using the tool helps demonstrate that you've understood the task and convey how you're approaching it. Plans can help to make complex, ambiguous, or multi-phase work clearer and more collaborative for the user. A good plan should break the task into meaningful, logically ordered steps that are easy to verify as you go. - -Note that plans are not for padding out simple work with filler steps or stating the obvious. The content of your plan should not involve doing anything that you aren't capable of doing (i.e. don't try to test things that you can't test). Do not use plans for simple or single-step queries that you can just do or answer immediately. - -Do not repeat the full contents of the plan after an `update_plan` call — the harness already displays it. Instead, summarize the change made and highlight any important context or next step. - -Before running a command, consider whether or not you have completed the previous step, and make sure to mark it as completed before moving on to the next step. It may be the case that you complete all steps in your plan after a single pass of implementation. If this is the case, you can simply mark all the planned steps as completed. Sometimes, you may need to change plans in the middle of a task: call `update_plan` with the updated plan and make sure to provide an `explanation` of the rationale when doing so. - -Use a plan when: - -- The task is non-trivial and will require multiple actions over a long time horizon. 
-- There are logical phases or dependencies where sequencing matters. -- The work has ambiguity that benefits from outlining high-level goals. -- You want intermediate checkpoints for feedback and validation. -- When the user asked you to do more than one thing in a single prompt -- The user has asked you to use the plan tool (aka "TODOs") -- You generate additional steps while working, and plan to do them before yielding to the user - -### Examples - -**High-quality plans** - -Example 1: - -1. Add CLI entry with file args -2. Parse Markdown via CommonMark library -3. Apply semantic HTML template -4. Handle code blocks, images, links -5. Add error handling for invalid files - -Example 2: - -1. Define CSS variables for colors -2. Add toggle with localStorage state -3. Refactor components to use variables -4. Verify all views for readability -5. Add smooth theme-change transition - -Example 3: - -1. Set up Node.js + WebSocket server -2. Add join/leave broadcast events -3. Implement messaging with timestamps -4. Add usernames + mention highlighting -5. Persist messages in lightweight DB -6. Add typing indicators + unread count - -**Low-quality plans** - -Example 1: - -1. Create CLI tool -2. Add Markdown parser -3. Convert to HTML - -Example 2: - -1. Add dark mode toggle -2. Save preference -3. Make styles look good - -Example 3: - -1. Create single-file HTML game -2. Run quick sanity check -3. Summarize usage instructions - -If you need to write a plan, only write high quality plans, not low quality ones. - -## Task execution - -You are a coding agent. Please keep going until the query is completely resolved, before ending your turn and yielding back to the user. Only terminate your turn when you are sure that the problem is solved. Autonomously resolve the query to the best of your ability, using the tools available to you, before coming back to the user. Do NOT guess or make up an answer. 
- -You MUST adhere to the following criteria when solving queries: - -- Working on the repo(s) in the current environment is allowed, even if they are proprietary. -- Analyzing code for vulnerabilities is allowed. -- Showing user code and tool call details is allowed. -- Use the `apply_patch` tool to edit files (NEVER try `applypatch` or `apply-patch`, only `apply_patch`): {"command":["apply_patch","*** Begin Patch\\n*** Update File: path/to/file.py\\n@@ def example():\\n- pass\\n+ return 123\\n*** End Patch"]} - -If completing the user's task requires writing or modifying files, your code and final answer should follow these coding guidelines, though user instructions (i.e. AGENTS.md) may override these guidelines: - -- Fix the problem at the root cause rather than applying surface-level patches, when possible. -- Avoid unneeded complexity in your solution. -- Do not attempt to fix unrelated bugs or broken tests. It is not your responsibility to fix them. (You may mention them to the user in your final message though.) -- Update documentation as necessary. -- Keep changes consistent with the style of the existing codebase. Changes should be minimal and focused on the task. -- Use `git log` and `git blame` to search the history of the codebase if additional context is required. -- NEVER add copyright or license headers unless specifically requested. -- Do not waste tokens by re-reading files after calling `apply_patch` on them. The tool call will fail if it didn't work. The same goes for making folders, deleting folders, etc. -- Do not `git commit` your changes or create new git branches unless explicitly requested. -- Do not add inline comments within code unless explicitly requested. -- Do not use one-letter variable names unless explicitly requested. -- NEVER output inline citations like "【F:README.md†L5-L14】" in your outputs. The CLI is not able to render these so they will just be broken in the UI. 
Instead, if you output valid filepaths, users will be able to click on them to open the files in their editor. - -## Validating your work - -If the codebase has tests or the ability to build or run, consider using them to verify that your work is complete. - -When testing, your philosophy should be to start as specific as possible to the code you changed so that you can catch issues efficiently, then make your way to broader tests as you build confidence. If there's no test for the code you changed, and if the adjacent patterns in the codebases show that there's a logical place for you to add a test, you may do so. However, do not add tests to codebases with no tests. - -Similarly, once you're confident in correctness, you can suggest or use formatting commands to ensure that your code is well formatted. If there are issues you can iterate up to 3 times to get formatting right, but if you still can't manage it's better to save the user time and present them a correct solution where you call out the formatting in your final message. If the codebase does not have a formatter configured, do not add one. - -For all of testing, running, building, and formatting, do not attempt to fix unrelated bugs. It is not your responsibility to fix them. (You may mention them to the user in your final message though.) - -Be mindful of whether to run validation commands proactively. In the absence of behavioral guidance: - -- When running in non-interactive approval modes like **never** or **on-failure**, proactively run tests, lint and do whatever you need to ensure you've completed the task. -- When working in interactive approval modes like **untrusted**, or **on-request**, hold off on running tests or lint commands until the user is ready for you to finalize your output, because these commands take time to run and slow down iteration. Instead suggest what you want to do next, and let the user confirm first. 
-- When working on test-related tasks, such as adding tests, fixing tests, or reproducing a bug to verify behavior, you may proactively run tests regardless of approval mode. Use your judgement to decide whether this is a test-related task. - -## Ambition vs. precision - -For tasks that have no prior context (i.e. the user is starting something brand new), you should feel free to be ambitious and demonstrate creativity with your implementation. - -If you're operating in an existing codebase, you should make sure you do exactly what the user asks with surgical precision. Treat the surrounding codebase with respect, and don't overstep (i.e. changing filenames or variables unnecessarily). You should balance being sufficiently ambitious and proactive when completing tasks of this nature. - -You should use judicious initiative to decide on the right level of detail and complexity to deliver based on the user's needs. This means showing good judgment that you're capable of doing the right extras without gold-plating. This might be demonstrated by high-value, creative touches when scope of the task is vague; while being surgical and targeted when scope is tightly specified. - -## Sharing progress updates - -For especially longer tasks that you work on (i.e. requiring many tool calls, or a plan with multiple steps), you should provide progress updates back to the user at reasonable intervals. These updates should be structured as a concise sentence or two (no more than 8-10 words long) recapping progress so far in plain language: this update demonstrates your understanding of what needs to be done, progress so far (i.e. files explores, subtasks complete), and where you're going next. - -Before doing large chunks of work that may incur latency as experienced by the user (i.e. writing a new file), you should send a concise message to the user with an update indicating what you're about to do to ensure they know what you're spending time on. 
Don't start editing or writing large files before informing the user what you are doing and why. - -The messages you send before tool calls should describe what is immediately about to be done next in very concise language. If there was previous work done, this preamble message should also include a note about the work done so far to bring the user along. - -## Presenting your work and final message - -Your final message should read naturally, like an update from a concise teammate. For casual conversation, brainstorming tasks, or quick questions from the user, respond in a friendly, conversational tone. You should ask questions, suggest ideas, and adapt to the user’s style. If you've finished a large amount of work, when describing what you've done to the user, you should follow the final answer formatting guidelines to communicate substantive changes. You don't need to add structured formatting for one-word answers, greetings, or purely conversational exchanges. - -You can skip heavy formatting for single, simple actions or confirmations. In these cases, respond in plain sentences with any relevant next step or quick option. Reserve multi-section structured responses for results that need grouping or explanation. - -The user is working on the same computer as you, and has access to your work. As such there's no need to show the full contents of large files you have already written unless the user explicitly asks for them. Similarly, if you've created or modified files using `apply_patch`, there's no need to tell users to "save the file" or "copy the code into a file"—just reference the file path. - -If there's something that you think you could help with as a logical next step, concisely ask the user if they want you to do so. Good examples of this are running tests, committing changes, or building out the next logical component. 
If there’s something that you couldn't do (even with approval) but that the user might want to do (such as verifying changes by running the app), include those instructions succinctly. - -Brevity is very important as a default. You should be very concise (i.e. no more than 10 lines), but can relax this requirement for tasks where additional detail and comprehensiveness is important for the user's understanding. - -### Final answer structure and style guidelines - -You are producing plain text that will later be styled by the CLI. Follow these rules exactly. Formatting should make results easy to scan, but not feel mechanical. Use judgment to decide how much structure adds value. - -**Section Headers** - -- Use only when they improve clarity — they are not mandatory for every answer. -- Choose descriptive names that fit the content -- Keep headers short (1–3 words) and in `**Title Case**`. Always start headers with `**` and end with `**` -- Leave no blank line before the first bullet under a header. -- Section headers should only be used where they genuinely improve scanability; avoid fragmenting the answer. - -**Bullets** - -- Use `-` followed by a space for every bullet. -- Merge related points when possible; avoid a bullet for every trivial detail. -- Keep bullets to one line unless breaking for clarity is unavoidable. -- Group into short lists (4–6 bullets) ordered by importance. -- Use consistent keyword phrasing and formatting across sections. - -**Monospace** - -- Wrap all commands, file paths, env vars, and code identifiers in backticks (`` `...` ``). -- Apply to inline examples and to bullet keywords if the keyword itself is a literal file/command. -- Never mix monospace and bold markers; choose one based on whether it’s a keyword (`**`) or inline code/path (`` ` ``). 
- -**File References** -When referencing files in your response, make sure to include the relevant start line and always follow the below rules: - * Use inline code to make file paths clickable. - * Each reference should have a stand alone path. Even if it's the same file. - * Accepted: absolute, workspace‑relative, a/ or b/ diff prefixes, or bare filename/suffix. - * Line/column (1‑based, optional): :line[:column] or #Lline[Ccolumn] (column defaults to 1). - * Do not use URIs like file://, vscode://, or https://. - * Do not provide range of lines - * Examples: src/app.ts, src/app.ts:42, b/server/index.js#L10, C:\repo\project\main.rs:12:5 - -**Structure** - -- Place related bullets together; don’t mix unrelated concepts in the same section. -- Order sections from general → specific → supporting info. -- For subsections (e.g., “Binaries” under “Rust Workspace”), introduce with a bolded keyword bullet, then list items under it. -- Match structure to complexity: - - Multi-part or detailed results → use clear headers and grouped bullets. - - Simple results → minimal headers, possibly just a short list or paragraph. - -**Tone** - -- Keep the voice collaborative and natural, like a coding partner handing off work. -- Be concise and factual — no filler or conversational commentary and avoid unnecessary repetition -- Use present tense and active voice (e.g., “Runs tests” not “This will run tests”). -- Keep descriptions self-contained; don’t refer to “above” or “below”. -- Use parallel structure in lists for consistency. - -**Don’t** - -- Don’t use literal words “bold” or “monospace” in the content. -- Don’t nest bullets or create deep hierarchies. -- Don’t output ANSI escape codes directly — the CLI renderer applies them. -- Don’t cram unrelated keywords into a single bullet; split for clarity. -- Don’t let keyword lists run long — wrap or reformat for scanability. - -Generally, ensure your final answers adapt their shape and depth to the request. 
For example, answers to code explanations should have a precise, structured explanation with code references that answer the question directly. For tasks with a simple implementation, lead with the outcome and supplement only with what’s needed for clarity. Larger changes can be presented as a logical walkthrough of your approach, grouping related steps, explaining rationale where it adds value, and highlighting next actions to accelerate the user. Your answers should provide the right level of detail while being easily scannable. - -For casual greetings, acknowledgements, or other one-off conversational messages that are not delivering substantive information or structured results, respond naturally without section headers or bullet formatting. - -# Tool Guidelines - -## Shell commands - -When using the shell, you must adhere to the following guidelines: - -- When searching for text or files, prefer using `rg` or `rg --files` respectively because `rg` is much faster than alternatives like `grep`. (If the `rg` command is not found, then use alternatives.) -- Do not use python scripts to attempt to output larger chunks of a file. - -## `update_plan` - -A tool named `update_plan` is available to you. You can use it to keep an up‑to‑date, step‑by‑step plan for the task. - -To create a new plan, call `update_plan` with a short list of 1‑sentence steps (no more than 5-7 words each) with a `status` for each step (`pending`, `in_progress`, or `completed`). - -When steps have been completed, use `update_plan` to mark each finished step as `completed` and the next step you are working on as `in_progress`. There should always be exactly one `in_progress` step until everything is done. You can mark multiple items as complete in a single `update_plan` call. - -If all steps are complete, ensure you call `update_plan` to mark all steps as `completed`. 
diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 96e30db2..5df2d639 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -28,6 +28,17 @@ type RateLimitService struct { usageCache map[int64]*geminiUsageCacheEntry } +// SuccessfulTestRecoveryResult 表示测试成功后恢复了哪些运行时状态。 +type SuccessfulTestRecoveryResult struct { + ClearedError bool + ClearedRateLimit bool +} + +// AccountRecoveryOptions 控制账号恢复时的附加行为。 +type AccountRecoveryOptions struct { + InvalidateToken bool +} + type geminiUsageCacheEntry struct { windowStart time.Time cachedAt time.Time @@ -87,6 +98,9 @@ func (s *RateLimitService) CheckErrorPolicy(ctx context.Context, account *Accoun slog.Info("account_error_code_skipped", "account_id", account.ID, "status_code", statusCode) return ErrorPolicySkipped } + if account.IsPoolMode() { + return ErrorPolicySkipped + } if s.tryTempUnschedulable(ctx, account, statusCode, responseBody) { return ErrorPolicyTempUnscheduled } @@ -96,9 +110,16 @@ func (s *RateLimitService) CheckErrorPolicy(ctx context.Context, account *Accoun // HandleUpstreamError 处理上游错误响应,标记账号状态 // 返回是否应该停止该账号的调度 func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) { + customErrorCodesEnabled := account.IsCustomErrorCodesEnabled() + + // 池模式默认不标记本地账号状态;仅当用户显式配置自定义错误码时按本地策略处理。 + if account.IsPoolMode() && !customErrorCodesEnabled { + slog.Info("pool_mode_error_skipped", "account_id", account.ID, "status_code", statusCode) + return false + } + // apikey 类型账号:检查自定义错误码配置 // 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载) - customErrorCodesEnabled := account.IsCustomErrorCodesEnabled() if !account.ShouldHandleErrorCode(statusCode) { slog.Info("account_error_code_skipped", "account_id", account.ID, "status_code", statusCode) return false @@ -615,6 +636,7 @@ func (s *RateLimitService) 
handleCustomErrorCode(ctx context.Context, account *A func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header, responseBody []byte) { // 1. OpenAI 平台:优先尝试解析 x-codex-* 响应头(用于 rate_limit_exceeded) if account.Platform == PlatformOpenAI { + s.persistOpenAICodexSnapshot(ctx, account, headers) if resetAt := s.calculateOpenAI429ResetTime(headers); resetAt != nil { if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil { slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) @@ -878,6 +900,23 @@ func pickSooner(a, b *time.Time) *time.Time { } } +func (s *RateLimitService) persistOpenAICodexSnapshot(ctx context.Context, account *Account, headers http.Header) { + if s == nil || s.accountRepo == nil || account == nil || headers == nil { + return + } + snapshot := ParseCodexRateLimitHeaders(headers) + if snapshot == nil { + return + } + updates := buildCodexUsageExtraUpdates(snapshot, time.Now()) + if len(updates) == 0 { + return + } + if err := s.accountRepo.UpdateExtra(ctx, account.ID, updates); err != nil { + slog.Warn("openai_codex_snapshot_persist_failed", "account_id", account.ID, "error", err) + } +} + // parseOpenAIRateLimitResetTime 解析 OpenAI 格式的 429 响应,返回重置时间的 Unix 时间戳 // OpenAI 的 usage_limit_reached 错误格式: // @@ -970,12 +1009,27 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc windowStart = &start windowEnd = &end slog.Info("account_session_window_initialized", "account_id", account.ID, "window_start", start, "window_end", end, "status", status) + // 窗口重置时清除旧的 utilization,避免残留上个窗口的数据 + _ = s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{ + "session_window_utilization": nil, + }) } if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil { slog.Warn("session_window_update_failed", "account_id", account.ID, "error", err) } + // 存储真实的 utilization 值(0-1 小数),供 estimateSetupTokenUsage 使用 + if utilStr 
:= headers.Get("anthropic-ratelimit-unified-5h-utilization"); utilStr != "" { + if util, err := strconv.ParseFloat(utilStr, 64); err == nil { + if err := s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{ + "session_window_utilization": util, + }); err != nil { + slog.Warn("session_window_utilization_update_failed", "account_id", account.ID, "error", err) + } + } + } + // 如果状态为allowed且之前有限流,说明窗口已重置,清除限流状态 if status == "allowed" && account.IsRateLimited() { if err := s.ClearRateLimit(ctx, account.ID); err != nil { @@ -1007,6 +1061,42 @@ func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) return nil } +// RecoverAccountState 按需恢复账号的可恢复运行时状态。 +func (s *RateLimitService) RecoverAccountState(ctx context.Context, accountID int64, options AccountRecoveryOptions) (*SuccessfulTestRecoveryResult, error) { + account, err := s.accountRepo.GetByID(ctx, accountID) + if err != nil { + return nil, err + } + + result := &SuccessfulTestRecoveryResult{} + if account.Status == StatusError { + if err := s.accountRepo.ClearError(ctx, accountID); err != nil { + return nil, err + } + result.ClearedError = true + if options.InvalidateToken && s.tokenCacheInvalidator != nil && account.IsOAuth() { + if invalidateErr := s.tokenCacheInvalidator.InvalidateToken(ctx, account); invalidateErr != nil { + slog.Warn("recover_account_state_invalidate_token_failed", "account_id", accountID, "error", invalidateErr) + } + } + } + + if hasRecoverableRuntimeState(account) { + if err := s.ClearRateLimit(ctx, accountID); err != nil { + return nil, err + } + result.ClearedRateLimit = true + } + + return result, nil +} + +// RecoverAccountAfterSuccessfulTest 将一次成功测试视为正常请求, +// 按需恢复 error / rate-limit / overload / temp-unsched / model-rate-limit 等运行时状态。 +func (s *RateLimitService) RecoverAccountAfterSuccessfulTest(ctx context.Context, accountID int64) (*SuccessfulTestRecoveryResult, error) { + return s.RecoverAccountState(ctx, accountID, AccountRecoveryOptions{}) +} + func (s 
*RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID int64) error { if err := s.accountRepo.ClearTempUnschedulable(ctx, accountID); err != nil { return err @@ -1023,6 +1113,36 @@ func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID return nil } +func hasRecoverableRuntimeState(account *Account) bool { + if account == nil { + return false + } + if account.RateLimitedAt != nil || account.RateLimitResetAt != nil || account.OverloadUntil != nil || account.TempUnschedulableUntil != nil { + return true + } + if len(account.Extra) == 0 { + return false + } + return hasNonEmptyMapValue(account.Extra, "model_rate_limits") || hasNonEmptyMapValue(account.Extra, "antigravity_quota_scopes") +} + +func hasNonEmptyMapValue(extra map[string]any, key string) bool { + raw, ok := extra[key] + if !ok || raw == nil { + return false + } + switch typed := raw.(type) { + case map[string]any: + return len(typed) > 0 + case map[string]string: + return len(typed) > 0 + case []any: + return len(typed) > 0 + default: + return true + } +} + func (s *RateLimitService) GetTempUnschedStatus(ctx context.Context, accountID int64) (*TempUnschedState, error) { now := time.Now().Unix() if s.tempUnschedCache != nil { @@ -1091,6 +1211,22 @@ func (s *RateLimitService) tryTempUnschedulable(ctx context.Context, account *Ac if !account.IsTempUnschedulableEnabled() { return false } + // 401 首次命中可临时不可调度(给 token 刷新窗口); + // 若历史上已因 401 进入过临时不可调度,则本次应升级为 error(返回 false 交由默认错误逻辑处理)。 + if statusCode == http.StatusUnauthorized { + reason := account.TempUnschedulableReason + // 缓存可能没有 reason,从 DB 回退读取 + if reason == "" { + if dbAcc, err := s.accountRepo.GetByID(ctx, account.ID); err == nil && dbAcc != nil { + reason = dbAcc.TempUnschedulableReason + } + } + if wasTempUnschedByStatusCode(reason, statusCode) { + slog.Info("401_escalated_to_error", "account_id", account.ID, + "reason", "previous temp-unschedulable was also 401") + return false + } + } rules := 
account.GetTempUnschedulableRules() if len(rules) == 0 { return false @@ -1122,6 +1258,22 @@ func (s *RateLimitService) tryTempUnschedulable(ctx context.Context, account *Ac return false } +func wasTempUnschedByStatusCode(reason string, statusCode int) bool { + if statusCode <= 0 { + return false + } + reason = strings.TrimSpace(reason) + if reason == "" { + return false + } + + var state TempUnschedState + if err := json.Unmarshal([]byte(reason), &state); err != nil { + return false + } + return state.StatusCode == statusCode +} + func matchTempUnschedKeyword(bodyLower string, keywords []string) string { if bodyLower == "" { return "" diff --git a/backend/internal/service/ratelimit_service_401_db_fallback_test.go b/backend/internal/service/ratelimit_service_401_db_fallback_test.go new file mode 100644 index 00000000..e1611425 --- /dev/null +++ b/backend/internal/service/ratelimit_service_401_db_fallback_test.go @@ -0,0 +1,119 @@ +//go:build unit + +package service + +import ( + "context" + "net/http" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +// dbFallbackRepoStub extends errorPolicyRepoStub with a configurable DB account +// returned by GetByID, simulating cache miss + DB fallback. +type dbFallbackRepoStub struct { + errorPolicyRepoStub + dbAccount *Account // returned by GetByID when non-nil +} + +func (r *dbFallbackRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) { + if r.dbAccount != nil && r.dbAccount.ID == id { + return r.dbAccount, nil + } + return nil, nil // not found, no error +} + +func TestCheckErrorPolicy_401_DBFallback_Escalates(t *testing.T) { + // Scenario: cache account has empty TempUnschedulableReason (cache miss), + // but DB account has a previous 401 record → should escalate to ErrorPolicyNone. 
+ repo := &dbFallbackRepoStub{ + dbAccount: &Account{ + ID: 20, + TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`, + }, + } + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + + account := &Account{ + ID: 20, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + TempUnschedulableReason: "", // cache miss — reason is empty + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(401), + "keywords": []any{"unauthorized"}, + "duration_minutes": float64(10), + }, + }, + }, + } + + result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`)) + require.Equal(t, ErrorPolicyNone, result, "401 with DB fallback showing previous 401 should escalate to ErrorPolicyNone") +} + +func TestCheckErrorPolicy_401_DBFallback_NoDBRecord_FirstHit(t *testing.T) { + // Scenario: cache account has empty TempUnschedulableReason, + // DB also has no previous 401 record → should NOT escalate (first hit → temp unscheduled). 
+ repo := &dbFallbackRepoStub{ + dbAccount: &Account{ + ID: 21, + TempUnschedulableReason: "", // DB also empty + }, + } + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + + account := &Account{ + ID: 21, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + TempUnschedulableReason: "", + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(401), + "keywords": []any{"unauthorized"}, + "duration_minutes": float64(10), + }, + }, + }, + } + + result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`)) + require.Equal(t, ErrorPolicyTempUnscheduled, result, "401 first hit with no DB record should temp-unschedule") +} + +func TestCheckErrorPolicy_401_DBFallback_DBError_FirstHit(t *testing.T) { + // Scenario: cache account has empty TempUnschedulableReason, + // DB lookup returns nil (not found) → should treat as first hit → temp unscheduled. 
+ repo := &dbFallbackRepoStub{ + dbAccount: nil, // GetByID returns nil, nil + } + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + + account := &Account{ + ID: 22, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + TempUnschedulableReason: "", + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(401), + "keywords": []any{"unauthorized"}, + "duration_minutes": float64(10), + }, + }, + }, + } + + result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`)) + require.Equal(t, ErrorPolicyTempUnscheduled, result, "401 first hit with DB not found should temp-unschedule") +} diff --git a/backend/internal/service/ratelimit_service_clear_test.go b/backend/internal/service/ratelimit_service_clear_test.go index f48151ed..1d7a02fc 100644 --- a/backend/internal/service/ratelimit_service_clear_test.go +++ b/backend/internal/service/ratelimit_service_clear_test.go @@ -6,6 +6,7 @@ import ( "context" "errors" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/stretchr/testify/require" @@ -13,16 +14,34 @@ import ( type rateLimitClearRepoStub struct { mockAccountRepoForGemini + getByIDAccount *Account + getByIDErr error + getByIDCalls int + clearErrorCalls int clearRateLimitCalls int clearAntigravityCalls int clearModelRateLimitCalls int clearTempUnschedCalls int + clearErrorErr error clearRateLimitErr error clearAntigravityErr error clearModelRateLimitErr error clearTempUnschedulableErr error } +func (r *rateLimitClearRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) { + r.getByIDCalls++ + if r.getByIDErr != nil { + return nil, r.getByIDErr + } + return r.getByIDAccount, nil +} + +func (r *rateLimitClearRepoStub) ClearError(ctx context.Context, id int64) error { + r.clearErrorCalls++ + return r.clearErrorErr +} + func (r *rateLimitClearRepoStub) ClearRateLimit(ctx 
context.Context, id int64) error { r.clearRateLimitCalls++ return r.clearRateLimitErr @@ -48,6 +67,11 @@ type tempUnschedCacheRecorder struct { deleteErr error } +type recoverTokenInvalidatorStub struct { + accounts []*Account + err error +} + func (c *tempUnschedCacheRecorder) SetTempUnsched(ctx context.Context, accountID int64, state *TempUnschedState) error { return nil } @@ -61,6 +85,11 @@ func (c *tempUnschedCacheRecorder) DeleteTempUnsched(ctx context.Context, accoun return c.deleteErr } +func (s *recoverTokenInvalidatorStub) InvalidateToken(ctx context.Context, account *Account) error { + s.accounts = append(s.accounts, account) + return s.err +} + func TestRateLimitService_ClearRateLimit_AlsoClearsTempUnschedulable(t *testing.T) { repo := &rateLimitClearRepoStub{} cache := &tempUnschedCacheRecorder{} @@ -170,3 +199,108 @@ func TestRateLimitService_ClearRateLimit_WithoutTempUnschedCache(t *testing.T) { require.Equal(t, 1, repo.clearModelRateLimitCalls) require.Equal(t, 1, repo.clearTempUnschedCalls) } + +func TestRateLimitService_RecoverAccountAfterSuccessfulTest_ClearsErrorAndRateLimitRelatedState(t *testing.T) { + now := time.Now() + repo := &rateLimitClearRepoStub{ + getByIDAccount: &Account{ + ID: 42, + Status: StatusError, + RateLimitedAt: &now, + TempUnschedulableUntil: &now, + Extra: map[string]any{ + "model_rate_limits": map[string]any{ + "claude-sonnet-4-5": map[string]any{ + "rate_limit_reset_at": now.Format(time.RFC3339), + }, + }, + "antigravity_quota_scopes": map[string]any{"gemini": true}, + }, + }, + } + cache := &tempUnschedCacheRecorder{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache) + + result, err := svc.RecoverAccountAfterSuccessfulTest(context.Background(), 42) + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.ClearedError) + require.True(t, result.ClearedRateLimit) + + require.Equal(t, 1, repo.getByIDCalls) + require.Equal(t, 1, repo.clearErrorCalls) + require.Equal(t, 1, 
repo.clearRateLimitCalls) + require.Equal(t, 1, repo.clearAntigravityCalls) + require.Equal(t, 1, repo.clearModelRateLimitCalls) + require.Equal(t, 1, repo.clearTempUnschedCalls) + require.Equal(t, []int64{42}, cache.deletedIDs) +} + +func TestRateLimitService_RecoverAccountAfterSuccessfulTest_NoRecoverableStateIsNoop(t *testing.T) { + repo := &rateLimitClearRepoStub{ + getByIDAccount: &Account{ + ID: 7, + Status: StatusActive, + Schedulable: true, + Extra: map[string]any{}, + }, + } + cache := &tempUnschedCacheRecorder{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache) + + result, err := svc.RecoverAccountAfterSuccessfulTest(context.Background(), 7) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.ClearedError) + require.False(t, result.ClearedRateLimit) + + require.Equal(t, 1, repo.getByIDCalls) + require.Equal(t, 0, repo.clearErrorCalls) + require.Equal(t, 0, repo.clearRateLimitCalls) + require.Equal(t, 0, repo.clearAntigravityCalls) + require.Equal(t, 0, repo.clearModelRateLimitCalls) + require.Equal(t, 0, repo.clearTempUnschedCalls) + require.Empty(t, cache.deletedIDs) +} + +func TestRateLimitService_RecoverAccountAfterSuccessfulTest_ClearErrorFailed(t *testing.T) { + repo := &rateLimitClearRepoStub{ + getByIDAccount: &Account{ + ID: 9, + Status: StatusError, + }, + clearErrorErr: errors.New("clear error failed"), + } + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + + result, err := svc.RecoverAccountAfterSuccessfulTest(context.Background(), 9) + require.Error(t, err) + require.Nil(t, result) + require.Equal(t, 1, repo.getByIDCalls) + require.Equal(t, 1, repo.clearErrorCalls) + require.Equal(t, 0, repo.clearRateLimitCalls) +} + +func TestRateLimitService_RecoverAccountState_InvalidatesOAuthTokenOnErrorRecovery(t *testing.T) { + repo := &rateLimitClearRepoStub{ + getByIDAccount: &Account{ + ID: 21, + Type: AccountTypeOAuth, + Status: StatusError, + }, + } + invalidator := 
&recoverTokenInvalidatorStub{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + svc.SetTokenCacheInvalidator(invalidator) + + result, err := svc.RecoverAccountState(context.Background(), 21, AccountRecoveryOptions{ + InvalidateToken: true, + }) + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.ClearedError) + require.False(t, result.ClearedRateLimit) + require.Equal(t, 1, repo.clearErrorCalls) + require.Len(t, invalidator.accounts, 1) + require.Equal(t, int64(21), invalidator.accounts[0].ID) +} diff --git a/backend/internal/service/ratelimit_service_openai_test.go b/backend/internal/service/ratelimit_service_openai_test.go index 00902068..89c754c8 100644 --- a/backend/internal/service/ratelimit_service_openai_test.go +++ b/backend/internal/service/ratelimit_service_openai_test.go @@ -1,6 +1,9 @@ +//go:build unit + package service import ( + "context" "net/http" "testing" "time" @@ -141,6 +144,51 @@ func TestCalculateOpenAI429ResetTime_ReversedWindowOrder(t *testing.T) { } } +type openAI429SnapshotRepo struct { + mockAccountRepoForGemini + rateLimitedID int64 + updatedExtra map[string]any +} + +func (r *openAI429SnapshotRepo) SetRateLimited(_ context.Context, id int64, _ time.Time) error { + r.rateLimitedID = id + return nil +} + +func (r *openAI429SnapshotRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error { + r.updatedExtra = updates + return nil +} + +func TestHandle429_OpenAIPersistsCodexSnapshotImmediately(t *testing.T) { + repo := &openAI429SnapshotRepo{} + svc := NewRateLimitService(repo, nil, nil, nil, nil) + account := &Account{ID: 123, Platform: PlatformOpenAI, Type: AccountTypeOAuth} + + headers := http.Header{} + headers.Set("x-codex-primary-used-percent", "100") + headers.Set("x-codex-primary-reset-after-seconds", "604800") + headers.Set("x-codex-primary-window-minutes", "10080") + headers.Set("x-codex-secondary-used-percent", "100") + 
headers.Set("x-codex-secondary-reset-after-seconds", "18000") + headers.Set("x-codex-secondary-window-minutes", "300") + + svc.handle429(context.Background(), account, headers, nil) + + if repo.rateLimitedID != account.ID { + t.Fatalf("rateLimitedID = %d, want %d", repo.rateLimitedID, account.ID) + } + if len(repo.updatedExtra) == 0 { + t.Fatal("expected codex snapshot to be persisted on 429") + } + if got := repo.updatedExtra["codex_5h_used_percent"]; got != 100.0 { + t.Fatalf("codex_5h_used_percent = %v, want 100", got) + } + if got := repo.updatedExtra["codex_7d_used_percent"]; got != 100.0 { + t.Fatalf("codex_7d_used_percent = %v, want 100", got) + } +} + func TestNormalizedCodexLimits(t *testing.T) { // Test the Normalize() method directly pUsed := 100.0 diff --git a/backend/internal/service/registration_email_policy.go b/backend/internal/service/registration_email_policy.go new file mode 100644 index 00000000..875668c7 --- /dev/null +++ b/backend/internal/service/registration_email_policy.go @@ -0,0 +1,123 @@ +package service + +import ( + "encoding/json" + "fmt" + "regexp" + "strings" +) + +var registrationEmailDomainPattern = regexp.MustCompile( + `^[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?(?:\.[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?)+$`, +) + +// RegistrationEmailSuffix extracts normalized suffix in "@domain" form. +func RegistrationEmailSuffix(email string) string { + _, domain, ok := splitEmailForPolicy(email) + if !ok { + return "" + } + return "@" + domain +} + +// IsRegistrationEmailSuffixAllowed checks whether an email is allowed by suffix whitelist. +// Empty whitelist means allow all. 
+func IsRegistrationEmailSuffixAllowed(email string, whitelist []string) bool { + if len(whitelist) == 0 { + return true + } + suffix := RegistrationEmailSuffix(email) + if suffix == "" { + return false + } + for _, allowed := range whitelist { + if suffix == allowed { + return true + } + } + return false +} + +// NormalizeRegistrationEmailSuffixWhitelist normalizes and validates suffix whitelist items. +func NormalizeRegistrationEmailSuffixWhitelist(raw []string) ([]string, error) { + return normalizeRegistrationEmailSuffixWhitelist(raw, true) +} + +// ParseRegistrationEmailSuffixWhitelist parses persisted JSON into normalized suffixes. +// Invalid entries are ignored to keep old misconfigurations from breaking runtime reads. +func ParseRegistrationEmailSuffixWhitelist(raw string) []string { + raw = strings.TrimSpace(raw) + if raw == "" { + return []string{} + } + var items []string + if err := json.Unmarshal([]byte(raw), &items); err != nil { + return []string{} + } + normalized, _ := normalizeRegistrationEmailSuffixWhitelist(items, false) + if len(normalized) == 0 { + return []string{} + } + return normalized +} + +func normalizeRegistrationEmailSuffixWhitelist(raw []string, strict bool) ([]string, error) { + if len(raw) == 0 { + return nil, nil + } + + seen := make(map[string]struct{}, len(raw)) + out := make([]string, 0, len(raw)) + for _, item := range raw { + normalized, err := normalizeRegistrationEmailSuffix(item) + if err != nil { + if strict { + return nil, err + } + continue + } + if normalized == "" { + continue + } + if _, ok := seen[normalized]; ok { + continue + } + seen[normalized] = struct{}{} + out = append(out, normalized) + } + + if len(out) == 0 { + return nil, nil + } + return out, nil +} + +func normalizeRegistrationEmailSuffix(raw string) (string, error) { + value := strings.ToLower(strings.TrimSpace(raw)) + if value == "" { + return "", nil + } + + domain := value + if strings.Contains(value, "@") { + if !strings.HasPrefix(value, "@") || 
strings.Count(value, "@") != 1 { + return "", fmt.Errorf("invalid email suffix: %q", raw) + } + domain = strings.TrimPrefix(value, "@") + } + + if domain == "" || strings.Contains(domain, "@") || !registrationEmailDomainPattern.MatchString(domain) { + return "", fmt.Errorf("invalid email suffix: %q", raw) + } + + return "@" + domain, nil +} + +func splitEmailForPolicy(raw string) (local string, domain string, ok bool) { + email := strings.ToLower(strings.TrimSpace(raw)) + local, domain, found := strings.Cut(email, "@") + if !found || local == "" || domain == "" || strings.Contains(domain, "@") { + return "", "", false + } + return local, domain, true +} diff --git a/backend/internal/service/registration_email_policy_test.go b/backend/internal/service/registration_email_policy_test.go new file mode 100644 index 00000000..f0c46642 --- /dev/null +++ b/backend/internal/service/registration_email_policy_test.go @@ -0,0 +1,31 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNormalizeRegistrationEmailSuffixWhitelist(t *testing.T) { + got, err := NormalizeRegistrationEmailSuffixWhitelist([]string{"example.com", "@EXAMPLE.COM", " @foo.bar "}) + require.NoError(t, err) + require.Equal(t, []string{"@example.com", "@foo.bar"}, got) +} + +func TestNormalizeRegistrationEmailSuffixWhitelist_Invalid(t *testing.T) { + _, err := NormalizeRegistrationEmailSuffixWhitelist([]string{"@invalid_domain"}) + require.Error(t, err) +} + +func TestParseRegistrationEmailSuffixWhitelist(t *testing.T) { + got := ParseRegistrationEmailSuffixWhitelist(`["example.com","@foo.bar","@invalid_domain"]`) + require.Equal(t, []string{"@example.com", "@foo.bar"}, got) +} + +func TestIsRegistrationEmailSuffixAllowed(t *testing.T) { + require.True(t, IsRegistrationEmailSuffixAllowed("user@example.com", []string{"@example.com"})) + require.False(t, IsRegistrationEmailSuffixAllowed("user@sub.example.com", []string{"@example.com"})) + 
require.True(t, IsRegistrationEmailSuffixAllowed("user@any.com", []string{})) +} diff --git a/backend/internal/service/scheduled_test_port.go b/backend/internal/service/scheduled_test_port.go new file mode 100644 index 00000000..1c0fdf21 --- /dev/null +++ b/backend/internal/service/scheduled_test_port.go @@ -0,0 +1,52 @@ +package service + +import ( + "context" + "time" +) + +// ScheduledTestPlan represents a scheduled test plan domain model. +type ScheduledTestPlan struct { + ID int64 `json:"id"` + AccountID int64 `json:"account_id"` + ModelID string `json:"model_id"` + CronExpression string `json:"cron_expression"` + Enabled bool `json:"enabled"` + MaxResults int `json:"max_results"` + AutoRecover bool `json:"auto_recover"` + LastRunAt *time.Time `json:"last_run_at"` + NextRunAt *time.Time `json:"next_run_at"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// ScheduledTestResult represents a single test execution result. +type ScheduledTestResult struct { + ID int64 `json:"id"` + PlanID int64 `json:"plan_id"` + Status string `json:"status"` + ResponseText string `json:"response_text"` + ErrorMessage string `json:"error_message"` + LatencyMs int64 `json:"latency_ms"` + StartedAt time.Time `json:"started_at"` + FinishedAt time.Time `json:"finished_at"` + CreatedAt time.Time `json:"created_at"` +} + +// ScheduledTestPlanRepository defines the data access interface for test plans. 
+type ScheduledTestPlanRepository interface { + Create(ctx context.Context, plan *ScheduledTestPlan) (*ScheduledTestPlan, error) + GetByID(ctx context.Context, id int64) (*ScheduledTestPlan, error) + ListByAccountID(ctx context.Context, accountID int64) ([]*ScheduledTestPlan, error) + ListDue(ctx context.Context, now time.Time) ([]*ScheduledTestPlan, error) + Update(ctx context.Context, plan *ScheduledTestPlan) (*ScheduledTestPlan, error) + Delete(ctx context.Context, id int64) error + UpdateAfterRun(ctx context.Context, id int64, lastRunAt time.Time, nextRunAt time.Time) error +} + +// ScheduledTestResultRepository defines the data access interface for test results. +type ScheduledTestResultRepository interface { + Create(ctx context.Context, result *ScheduledTestResult) (*ScheduledTestResult, error) + ListByPlanID(ctx context.Context, planID int64, limit int) ([]*ScheduledTestResult, error) + PruneOldResults(ctx context.Context, planID int64, keepCount int) error +} diff --git a/backend/internal/service/scheduled_test_runner_service.go b/backend/internal/service/scheduled_test_runner_service.go new file mode 100644 index 00000000..f4d35f69 --- /dev/null +++ b/backend/internal/service/scheduled_test_runner_service.go @@ -0,0 +1,170 @@ +package service + +import ( + "context" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/robfig/cron/v3" +) + +const scheduledTestDefaultMaxWorkers = 10 + +// ScheduledTestRunnerService periodically scans due test plans and executes them. +type ScheduledTestRunnerService struct { + planRepo ScheduledTestPlanRepository + scheduledSvc *ScheduledTestService + accountTestSvc *AccountTestService + rateLimitSvc *RateLimitService + cfg *config.Config + + cron *cron.Cron + startOnce sync.Once + stopOnce sync.Once +} + +// NewScheduledTestRunnerService creates a new runner. 
+func NewScheduledTestRunnerService( + planRepo ScheduledTestPlanRepository, + scheduledSvc *ScheduledTestService, + accountTestSvc *AccountTestService, + rateLimitSvc *RateLimitService, + cfg *config.Config, +) *ScheduledTestRunnerService { + return &ScheduledTestRunnerService{ + planRepo: planRepo, + scheduledSvc: scheduledSvc, + accountTestSvc: accountTestSvc, + rateLimitSvc: rateLimitSvc, + cfg: cfg, + } +} + +// Start begins the cron ticker (every minute). +func (s *ScheduledTestRunnerService) Start() { + if s == nil { + return + } + s.startOnce.Do(func() { + loc := time.Local + if s.cfg != nil { + if parsed, err := time.LoadLocation(s.cfg.Timezone); err == nil && parsed != nil { + loc = parsed + } + } + + c := cron.New(cron.WithParser(scheduledTestCronParser), cron.WithLocation(loc)) + _, err := c.AddFunc("* * * * *", func() { s.runScheduled() }) + if err != nil { + logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] not started (invalid schedule): %v", err) + return + } + s.cron = c + s.cron.Start() + logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] started (tick=every minute)") + }) +} + +// Stop gracefully shuts down the cron scheduler. +func (s *ScheduledTestRunnerService) Stop() { + if s == nil { + return + } + s.stopOnce.Do(func() { + if s.cron != nil { + ctx := s.cron.Stop() + select { + case <-ctx.Done(): + case <-time.After(3 * time.Second): + logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] cron stop timed out") + } + } + }) +} + +func (s *ScheduledTestRunnerService) runScheduled() { + // Delay 10s so execution lands at ~:10 of each minute instead of :00. 
+ time.Sleep(10 * time.Second) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + + now := time.Now() + plans, err := s.planRepo.ListDue(ctx, now) + if err != nil { + logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] ListDue error: %v", err) + return + } + if len(plans) == 0 { + return + } + + logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] found %d due plans", len(plans)) + + sem := make(chan struct{}, scheduledTestDefaultMaxWorkers) + var wg sync.WaitGroup + + for _, plan := range plans { + sem <- struct{}{} + wg.Add(1) + go func(p *ScheduledTestPlan) { + defer wg.Done() + defer func() { <-sem }() + s.runOnePlan(ctx, p) + }(plan) + } + + wg.Wait() +} + +func (s *ScheduledTestRunnerService) runOnePlan(ctx context.Context, plan *ScheduledTestPlan) { + result, err := s.accountTestSvc.RunTestBackground(ctx, plan.AccountID, plan.ModelID) + if err != nil { + logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d RunTestBackground error: %v", plan.ID, err) + return + } + + if err := s.scheduledSvc.SaveResult(ctx, plan.ID, plan.MaxResults, result); err != nil { + logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d SaveResult error: %v", plan.ID, err) + } + + // Auto-recover account if test succeeded and auto_recover is enabled. 
+ if result.Status == "success" && plan.AutoRecover { + s.tryRecoverAccount(ctx, plan.AccountID, plan.ID) + } + + nextRun, err := computeNextRun(plan.CronExpression, time.Now()) + if err != nil { + logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d computeNextRun error: %v", plan.ID, err) + return + } + + if err := s.planRepo.UpdateAfterRun(ctx, plan.ID, time.Now(), nextRun); err != nil { + logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d UpdateAfterRun error: %v", plan.ID, err) + } +} + +// tryRecoverAccount attempts to recover an account from recoverable runtime state. +func (s *ScheduledTestRunnerService) tryRecoverAccount(ctx context.Context, accountID int64, planID int64) { + if s.rateLimitSvc == nil { + return + } + + recovery, err := s.rateLimitSvc.RecoverAccountAfterSuccessfulTest(ctx, accountID) + if err != nil { + logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d auto-recover failed: %v", planID, err) + return + } + if recovery == nil { + return + } + + if recovery.ClearedError { + logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d auto-recover: account=%d recovered from error status", planID, accountID) + } + if recovery.ClearedRateLimit { + logger.LegacyPrintf("service.scheduled_test_runner", "[ScheduledTestRunner] plan=%d auto-recover: account=%d cleared rate-limit/runtime state", planID, accountID) + } +} diff --git a/backend/internal/service/scheduled_test_service.go b/backend/internal/service/scheduled_test_service.go new file mode 100644 index 00000000..c9bb3b6a --- /dev/null +++ b/backend/internal/service/scheduled_test_service.go @@ -0,0 +1,94 @@ +package service + +import ( + "context" + "fmt" + "time" + + "github.com/robfig/cron/v3" +) + +var scheduledTestCronParser = cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow) + +// ScheduledTestService provides CRUD operations for scheduled test 
plans and results. +type ScheduledTestService struct { + planRepo ScheduledTestPlanRepository + resultRepo ScheduledTestResultRepository +} + +// NewScheduledTestService creates a new ScheduledTestService. +func NewScheduledTestService( + planRepo ScheduledTestPlanRepository, + resultRepo ScheduledTestResultRepository, +) *ScheduledTestService { + return &ScheduledTestService{ + planRepo: planRepo, + resultRepo: resultRepo, + } +} + +// CreatePlan validates the cron expression, computes next_run_at, and persists the plan. +func (s *ScheduledTestService) CreatePlan(ctx context.Context, plan *ScheduledTestPlan) (*ScheduledTestPlan, error) { + nextRun, err := computeNextRun(plan.CronExpression, time.Now()) + if err != nil { + return nil, fmt.Errorf("invalid cron expression: %w", err) + } + plan.NextRunAt = &nextRun + + if plan.MaxResults <= 0 { + plan.MaxResults = 50 + } + + return s.planRepo.Create(ctx, plan) +} + +// GetPlan retrieves a plan by ID. +func (s *ScheduledTestService) GetPlan(ctx context.Context, id int64) (*ScheduledTestPlan, error) { + return s.planRepo.GetByID(ctx, id) +} + +// ListPlansByAccount returns all plans for a given account. +func (s *ScheduledTestService) ListPlansByAccount(ctx context.Context, accountID int64) ([]*ScheduledTestPlan, error) { + return s.planRepo.ListByAccountID(ctx, accountID) +} + +// UpdatePlan validates cron and updates the plan. +func (s *ScheduledTestService) UpdatePlan(ctx context.Context, plan *ScheduledTestPlan) (*ScheduledTestPlan, error) { + nextRun, err := computeNextRun(plan.CronExpression, time.Now()) + if err != nil { + return nil, fmt.Errorf("invalid cron expression: %w", err) + } + plan.NextRunAt = &nextRun + + return s.planRepo.Update(ctx, plan) +} + +// DeletePlan removes a plan and its results (via CASCADE). +func (s *ScheduledTestService) DeletePlan(ctx context.Context, id int64) error { + return s.planRepo.Delete(ctx, id) +} + +// ListResults returns the most recent results for a plan. 
+func (s *ScheduledTestService) ListResults(ctx context.Context, planID int64, limit int) ([]*ScheduledTestResult, error) { + if limit <= 0 { + limit = 50 + } + return s.resultRepo.ListByPlanID(ctx, planID, limit) +} + +// SaveResult inserts a result and prunes old entries beyond maxResults. +func (s *ScheduledTestService) SaveResult(ctx context.Context, planID int64, maxResults int, result *ScheduledTestResult) error { + result.PlanID = planID + if _, err := s.resultRepo.Create(ctx, result); err != nil { + return err + } + return s.resultRepo.PruneOldResults(ctx, planID, maxResults) +} + +func computeNextRun(cronExpr string, from time.Time) (time.Time, error) { + sched, err := scheduledTestCronParser.Parse(cronExpr) + if err != nil { + return time.Time{}, err + } + return sched.Next(from), nil +} diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index a94cfdfb..b77867de 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -108,6 +108,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings keys := []string{ SettingKeyRegistrationEnabled, SettingKeyEmailVerifyEnabled, + SettingKeyRegistrationEmailSuffixWhitelist, SettingKeyPromoCodeEnabled, SettingKeyPasswordResetEnabled, SettingKeyInvitationCodeEnabled, @@ -144,29 +145,33 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings // Password reset requires email verification to be enabled emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true" passwordResetEnabled := emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true" + registrationEmailSuffixWhitelist := ParseRegistrationEmailSuffixWhitelist( + settings[SettingKeyRegistrationEmailSuffixWhitelist], + ) return &PublicSettings{ - RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", - EmailVerifyEnabled: emailVerifyEnabled, - PromoCodeEnabled: 
settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用 - PasswordResetEnabled: passwordResetEnabled, - InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true", - TotpEnabled: settings[SettingKeyTotpEnabled] == "true", - TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true", - TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey], - SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"), - SiteLogo: settings[SettingKeySiteLogo], - SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"), - APIBaseURL: settings[SettingKeyAPIBaseURL], - ContactInfo: settings[SettingKeyContactInfo], - DocURL: settings[SettingKeyDocURL], - HomeContent: settings[SettingKeyHomeContent], - HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true", - PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true", - PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), - SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", - CustomMenuItems: settings[SettingKeyCustomMenuItems], - LinuxDoOAuthEnabled: linuxDoEnabled, + RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", + EmailVerifyEnabled: emailVerifyEnabled, + RegistrationEmailSuffixWhitelist: registrationEmailSuffixWhitelist, + PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用 + PasswordResetEnabled: passwordResetEnabled, + InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true", + TotpEnabled: settings[SettingKeyTotpEnabled] == "true", + TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true", + TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey], + SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"), + SiteLogo: settings[SettingKeySiteLogo], + SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"), 
+ APIBaseURL: settings[SettingKeyAPIBaseURL], + ContactInfo: settings[SettingKeyContactInfo], + DocURL: settings[SettingKeyDocURL], + HomeContent: settings[SettingKeyHomeContent], + HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true", + PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true", + PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), + SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", + CustomMenuItems: settings[SettingKeyCustomMenuItems], + LinuxDoOAuthEnabled: linuxDoEnabled, }, nil } @@ -196,51 +201,53 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any // Return a struct that matches the frontend's expected format return &struct { - RegistrationEnabled bool `json:"registration_enabled"` - EmailVerifyEnabled bool `json:"email_verify_enabled"` - PromoCodeEnabled bool `json:"promo_code_enabled"` - PasswordResetEnabled bool `json:"password_reset_enabled"` - InvitationCodeEnabled bool `json:"invitation_code_enabled"` - TotpEnabled bool `json:"totp_enabled"` - TurnstileEnabled bool `json:"turnstile_enabled"` - TurnstileSiteKey string `json:"turnstile_site_key,omitempty"` - SiteName string `json:"site_name"` - SiteLogo string `json:"site_logo,omitempty"` - SiteSubtitle string `json:"site_subtitle,omitempty"` - APIBaseURL string `json:"api_base_url,omitempty"` - ContactInfo string `json:"contact_info,omitempty"` - DocURL string `json:"doc_url,omitempty"` - HomeContent string `json:"home_content,omitempty"` - HideCcsImportButton bool `json:"hide_ccs_import_button"` - PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` - PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"` - SoraClientEnabled bool `json:"sora_client_enabled"` - CustomMenuItems json.RawMessage `json:"custom_menu_items"` - LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` - Version string 
`json:"version,omitempty"` + RegistrationEnabled bool `json:"registration_enabled"` + EmailVerifyEnabled bool `json:"email_verify_enabled"` + RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"` + PromoCodeEnabled bool `json:"promo_code_enabled"` + PasswordResetEnabled bool `json:"password_reset_enabled"` + InvitationCodeEnabled bool `json:"invitation_code_enabled"` + TotpEnabled bool `json:"totp_enabled"` + TurnstileEnabled bool `json:"turnstile_enabled"` + TurnstileSiteKey string `json:"turnstile_site_key,omitempty"` + SiteName string `json:"site_name"` + SiteLogo string `json:"site_logo,omitempty"` + SiteSubtitle string `json:"site_subtitle,omitempty"` + APIBaseURL string `json:"api_base_url,omitempty"` + ContactInfo string `json:"contact_info,omitempty"` + DocURL string `json:"doc_url,omitempty"` + HomeContent string `json:"home_content,omitempty"` + HideCcsImportButton bool `json:"hide_ccs_import_button"` + PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` + PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"` + SoraClientEnabled bool `json:"sora_client_enabled"` + CustomMenuItems json.RawMessage `json:"custom_menu_items"` + LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` + Version string `json:"version,omitempty"` }{ - RegistrationEnabled: settings.RegistrationEnabled, - EmailVerifyEnabled: settings.EmailVerifyEnabled, - PromoCodeEnabled: settings.PromoCodeEnabled, - PasswordResetEnabled: settings.PasswordResetEnabled, - InvitationCodeEnabled: settings.InvitationCodeEnabled, - TotpEnabled: settings.TotpEnabled, - TurnstileEnabled: settings.TurnstileEnabled, - TurnstileSiteKey: settings.TurnstileSiteKey, - SiteName: settings.SiteName, - SiteLogo: settings.SiteLogo, - SiteSubtitle: settings.SiteSubtitle, - APIBaseURL: settings.APIBaseURL, - ContactInfo: settings.ContactInfo, - DocURL: settings.DocURL, - HomeContent: settings.HomeContent, - HideCcsImportButton: 
settings.HideCcsImportButton, - PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, - PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, - SoraClientEnabled: settings.SoraClientEnabled, - CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems), - LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, - Version: s.version, + RegistrationEnabled: settings.RegistrationEnabled, + EmailVerifyEnabled: settings.EmailVerifyEnabled, + RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist, + PromoCodeEnabled: settings.PromoCodeEnabled, + PasswordResetEnabled: settings.PasswordResetEnabled, + InvitationCodeEnabled: settings.InvitationCodeEnabled, + TotpEnabled: settings.TotpEnabled, + TurnstileEnabled: settings.TurnstileEnabled, + TurnstileSiteKey: settings.TurnstileSiteKey, + SiteName: settings.SiteName, + SiteLogo: settings.SiteLogo, + SiteSubtitle: settings.SiteSubtitle, + APIBaseURL: settings.APIBaseURL, + ContactInfo: settings.ContactInfo, + DocURL: settings.DocURL, + HomeContent: settings.HomeContent, + HideCcsImportButton: settings.HideCcsImportButton, + PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, + PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, + SoraClientEnabled: settings.SoraClientEnabled, + CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems), + LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, + Version: s.version, }, nil } @@ -356,12 +363,25 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet if err := s.validateDefaultSubscriptionGroups(ctx, settings.DefaultSubscriptions); err != nil { return err } + normalizedWhitelist, err := NormalizeRegistrationEmailSuffixWhitelist(settings.RegistrationEmailSuffixWhitelist) + if err != nil { + return infraerrors.BadRequest("INVALID_REGISTRATION_EMAIL_SUFFIX_WHITELIST", err.Error()) + } + if normalizedWhitelist == nil { + normalizedWhitelist = []string{} + } + 
settings.RegistrationEmailSuffixWhitelist = normalizedWhitelist updates := make(map[string]string) // 注册设置 updates[SettingKeyRegistrationEnabled] = strconv.FormatBool(settings.RegistrationEnabled) updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled) + registrationEmailSuffixWhitelistJSON, err := json.Marshal(settings.RegistrationEmailSuffixWhitelist) + if err != nil { + return fmt.Errorf("marshal registration email suffix whitelist: %w", err) + } + updates[SettingKeyRegistrationEmailSuffixWhitelist] = string(registrationEmailSuffixWhitelistJSON) updates[SettingKeyPromoCodeEnabled] = strconv.FormatBool(settings.PromoCodeEnabled) updates[SettingKeyPasswordResetEnabled] = strconv.FormatBool(settings.PasswordResetEnabled) updates[SettingKeyInvitationCodeEnabled] = strconv.FormatBool(settings.InvitationCodeEnabled) @@ -514,6 +534,15 @@ func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool { return value == "true" } +// GetRegistrationEmailSuffixWhitelist returns normalized registration email suffix whitelist. 
+func (s *SettingService) GetRegistrationEmailSuffixWhitelist(ctx context.Context) []string { + value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEmailSuffixWhitelist) + if err != nil { + return []string{} + } + return ParseRegistrationEmailSuffixWhitelist(value) +} + // IsPromoCodeEnabled 检查是否启用优惠码功能 func (s *SettingService) IsPromoCodeEnabled(ctx context.Context) bool { value, err := s.settingRepo.GetValue(ctx, SettingKeyPromoCodeEnabled) @@ -617,20 +646,21 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { // 初始化默认设置 defaults := map[string]string{ - SettingKeyRegistrationEnabled: "true", - SettingKeyEmailVerifyEnabled: "false", - SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能 - SettingKeySiteName: "Sub2API", - SettingKeySiteLogo: "", - SettingKeyPurchaseSubscriptionEnabled: "false", - SettingKeyPurchaseSubscriptionURL: "", - SettingKeySoraClientEnabled: "false", - SettingKeyCustomMenuItems: "[]", - SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), - SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), - SettingKeyDefaultSubscriptions: "[]", - SettingKeySMTPPort: "587", - SettingKeySMTPUseTLS: "false", + SettingKeyRegistrationEnabled: "true", + SettingKeyEmailVerifyEnabled: "false", + SettingKeyRegistrationEmailSuffixWhitelist: "[]", + SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能 + SettingKeySiteName: "Sub2API", + SettingKeySiteLogo: "", + SettingKeyPurchaseSubscriptionEnabled: "false", + SettingKeyPurchaseSubscriptionURL: "", + SettingKeySoraClientEnabled: "false", + SettingKeyCustomMenuItems: "[]", + SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), + SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), + SettingKeyDefaultSubscriptions: "[]", + SettingKeySMTPPort: "587", + SettingKeySMTPUseTLS: "false", // Model fallback defaults SettingKeyEnableModelFallback: "false", 
SettingKeyFallbackModelAnthropic: "claude-3-5-sonnet-20241022", @@ -661,33 +691,34 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { func (s *SettingService) parseSettings(settings map[string]string) *SystemSettings { emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true" result := &SystemSettings{ - RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", - EmailVerifyEnabled: emailVerifyEnabled, - PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用 - PasswordResetEnabled: emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true", - InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true", - TotpEnabled: settings[SettingKeyTotpEnabled] == "true", - SMTPHost: settings[SettingKeySMTPHost], - SMTPUsername: settings[SettingKeySMTPUsername], - SMTPFrom: settings[SettingKeySMTPFrom], - SMTPFromName: settings[SettingKeySMTPFromName], - SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true", - SMTPPasswordConfigured: settings[SettingKeySMTPPassword] != "", - TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true", - TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey], - TurnstileSecretKeyConfigured: settings[SettingKeyTurnstileSecretKey] != "", - SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"), - SiteLogo: settings[SettingKeySiteLogo], - SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"), - APIBaseURL: settings[SettingKeyAPIBaseURL], - ContactInfo: settings[SettingKeyContactInfo], - DocURL: settings[SettingKeyDocURL], - HomeContent: settings[SettingKeyHomeContent], - HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true", - PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true", - PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), - SoraClientEnabled: 
settings[SettingKeySoraClientEnabled] == "true", - CustomMenuItems: settings[SettingKeyCustomMenuItems], + RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", + EmailVerifyEnabled: emailVerifyEnabled, + RegistrationEmailSuffixWhitelist: ParseRegistrationEmailSuffixWhitelist(settings[SettingKeyRegistrationEmailSuffixWhitelist]), + PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用 + PasswordResetEnabled: emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true", + InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true", + TotpEnabled: settings[SettingKeyTotpEnabled] == "true", + SMTPHost: settings[SettingKeySMTPHost], + SMTPUsername: settings[SettingKeySMTPUsername], + SMTPFrom: settings[SettingKeySMTPFrom], + SMTPFromName: settings[SettingKeySMTPFromName], + SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true", + SMTPPasswordConfigured: settings[SettingKeySMTPPassword] != "", + TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true", + TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey], + TurnstileSecretKeyConfigured: settings[SettingKeyTurnstileSecretKey] != "", + SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"), + SiteLogo: settings[SettingKeySiteLogo], + SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"), + APIBaseURL: settings[SettingKeyAPIBaseURL], + ContactInfo: settings[SettingKeyContactInfo], + DocURL: settings[SettingKeyDocURL], + HomeContent: settings[SettingKeyHomeContent], + HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true", + PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true", + PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), + SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", + CustomMenuItems: settings[SettingKeyCustomMenuItems], } // 解析整数类型 @@ -1163,6 +1194,113 @@ 
func (s *SettingService) GetMinClaudeCodeVersion(ctx context.Context) string { return ver } +// GetRectifierSettings 获取请求整流器配置 +func (s *SettingService) GetRectifierSettings(ctx context.Context) (*RectifierSettings, error) { + value, err := s.settingRepo.GetValue(ctx, SettingKeyRectifierSettings) + if err != nil { + if errors.Is(err, ErrSettingNotFound) { + return DefaultRectifierSettings(), nil + } + return nil, fmt.Errorf("get rectifier settings: %w", err) + } + if value == "" { + return DefaultRectifierSettings(), nil + } + + var settings RectifierSettings + if err := json.Unmarshal([]byte(value), &settings); err != nil { + return DefaultRectifierSettings(), nil + } + + return &settings, nil +} + +// SetRectifierSettings 设置请求整流器配置 +func (s *SettingService) SetRectifierSettings(ctx context.Context, settings *RectifierSettings) error { + if settings == nil { + return fmt.Errorf("settings cannot be nil") + } + + data, err := json.Marshal(settings) + if err != nil { + return fmt.Errorf("marshal rectifier settings: %w", err) + } + + return s.settingRepo.Set(ctx, SettingKeyRectifierSettings, string(data)) +} + +// IsSignatureRectifierEnabled 判断签名整流是否启用(总开关 && 签名子开关) +func (s *SettingService) IsSignatureRectifierEnabled(ctx context.Context) bool { + settings, err := s.GetRectifierSettings(ctx) + if err != nil { + return true // fail-open: 查询失败时默认启用 + } + return settings.Enabled && settings.ThinkingSignatureEnabled +} + +// IsBudgetRectifierEnabled 判断 Budget 整流是否启用(总开关 && Budget 子开关) +func (s *SettingService) IsBudgetRectifierEnabled(ctx context.Context) bool { + settings, err := s.GetRectifierSettings(ctx) + if err != nil { + return true // fail-open: 查询失败时默认启用 + } + return settings.Enabled && settings.ThinkingBudgetEnabled +} + +// GetBetaPolicySettings 获取 Beta 策略配置 +func (s *SettingService) GetBetaPolicySettings(ctx context.Context) (*BetaPolicySettings, error) { + value, err := s.settingRepo.GetValue(ctx, SettingKeyBetaPolicySettings) + if err != nil { + if 
errors.Is(err, ErrSettingNotFound) { + return DefaultBetaPolicySettings(), nil + } + return nil, fmt.Errorf("get beta policy settings: %w", err) + } + if value == "" { + return DefaultBetaPolicySettings(), nil + } + + var settings BetaPolicySettings + if err := json.Unmarshal([]byte(value), &settings); err != nil { + return DefaultBetaPolicySettings(), nil + } + + return &settings, nil +} + +// SetBetaPolicySettings 设置 Beta 策略配置 +func (s *SettingService) SetBetaPolicySettings(ctx context.Context, settings *BetaPolicySettings) error { + if settings == nil { + return fmt.Errorf("settings cannot be nil") + } + + validActions := map[string]bool{ + BetaPolicyActionPass: true, BetaPolicyActionFilter: true, BetaPolicyActionBlock: true, + } + validScopes := map[string]bool{ + BetaPolicyScopeAll: true, BetaPolicyScopeOAuth: true, BetaPolicyScopeAPIKey: true, + } + + for i, rule := range settings.Rules { + if rule.BetaToken == "" { + return fmt.Errorf("rule[%d]: beta_token cannot be empty", i) + } + if !validActions[rule.Action] { + return fmt.Errorf("rule[%d]: invalid action %q", i, rule.Action) + } + if !validScopes[rule.Scope] { + return fmt.Errorf("rule[%d]: invalid scope %q", i, rule.Scope) + } + } + + data, err := json.Marshal(settings) + if err != nil { + return fmt.Errorf("marshal beta policy settings: %w", err) + } + + return s.settingRepo.Set(ctx, SettingKeyBetaPolicySettings, string(data)) +} + // SetStreamTimeoutSettings 设置流超时处理配置 func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings *StreamTimeoutSettings) error { if settings == nil { diff --git a/backend/internal/service/setting_service_public_test.go b/backend/internal/service/setting_service_public_test.go new file mode 100644 index 00000000..b511cd29 --- /dev/null +++ b/backend/internal/service/setting_service_public_test.go @@ -0,0 +1,64 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + 
"github.com/stretchr/testify/require" +) + +type settingPublicRepoStub struct { + values map[string]string +} + +func (s *settingPublicRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + panic("unexpected Get call") +} + +func (s *settingPublicRepoStub) GetValue(ctx context.Context, key string) (string, error) { + panic("unexpected GetValue call") +} + +func (s *settingPublicRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *settingPublicRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (s *settingPublicRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + panic("unexpected SetMultiple call") +} + +func (s *settingPublicRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *settingPublicRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + +func TestSettingService_GetPublicSettings_ExposesRegistrationEmailSuffixWhitelist(t *testing.T) { + repo := &settingPublicRepoStub{ + values: map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyEmailVerifyEnabled: "true", + SettingKeyRegistrationEmailSuffixWhitelist: `["@EXAMPLE.com"," @foo.bar ","@invalid_domain",""]`, + }, + } + svc := NewSettingService(repo, &config.Config{}) + + settings, err := svc.GetPublicSettings(context.Background()) + require.NoError(t, err) + require.Equal(t, []string{"@example.com", "@foo.bar"}, settings.RegistrationEmailSuffixWhitelist) +} diff --git a/backend/internal/service/setting_service_update_test.go b/backend/internal/service/setting_service_update_test.go index ec64511f..1de08611 100644 --- a/backend/internal/service/setting_service_update_test.go +++ 
b/backend/internal/service/setting_service_update_test.go @@ -172,6 +172,28 @@ func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsDuplicateGrou require.Nil(t, repo.updates) } +func TestSettingService_UpdateSettings_RegistrationEmailSuffixWhitelist_Normalized(t *testing.T) { + repo := &settingUpdateRepoStub{} + svc := NewSettingService(repo, &config.Config{}) + + err := svc.UpdateSettings(context.Background(), &SystemSettings{ + RegistrationEmailSuffixWhitelist: []string{"example.com", "@EXAMPLE.com", " @foo.bar "}, + }) + require.NoError(t, err) + require.Equal(t, `["@example.com","@foo.bar"]`, repo.updates[SettingKeyRegistrationEmailSuffixWhitelist]) +} + +func TestSettingService_UpdateSettings_RegistrationEmailSuffixWhitelist_Invalid(t *testing.T) { + repo := &settingUpdateRepoStub{} + svc := NewSettingService(repo, &config.Config{}) + + err := svc.UpdateSettings(context.Background(), &SystemSettings{ + RegistrationEmailSuffixWhitelist: []string{"@invalid_domain"}, + }) + require.Error(t, err) + require.Equal(t, "INVALID_REGISTRATION_EMAIL_SUFFIX_WHITELIST", infraerrors.Reason(err)) +} + func TestParseDefaultSubscriptions_NormalizesValues(t *testing.T) { got := parseDefaultSubscriptions(`[{"group_id":11,"validity_days":30},{"group_id":11,"validity_days":60},{"group_id":0,"validity_days":10},{"group_id":12,"validity_days":99999}]`) require.Equal(t, []DefaultSubscriptionSetting{ diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index ebb7693a..8734e28a 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -1,12 +1,13 @@ package service type SystemSettings struct { - RegistrationEnabled bool - EmailVerifyEnabled bool - PromoCodeEnabled bool - PasswordResetEnabled bool - InvitationCodeEnabled bool - TotpEnabled bool // TOTP 双因素认证 + RegistrationEnabled bool + EmailVerifyEnabled bool + RegistrationEmailSuffixWhitelist []string + PromoCodeEnabled bool + 
PasswordResetEnabled bool + InvitationCodeEnabled bool + TotpEnabled bool // TOTP 双因素认证 SMTPHost string SMTPPort int @@ -76,22 +77,23 @@ type DefaultSubscriptionSetting struct { } type PublicSettings struct { - RegistrationEnabled bool - EmailVerifyEnabled bool - PromoCodeEnabled bool - PasswordResetEnabled bool - InvitationCodeEnabled bool - TotpEnabled bool // TOTP 双因素认证 - TurnstileEnabled bool - TurnstileSiteKey string - SiteName string - SiteLogo string - SiteSubtitle string - APIBaseURL string - ContactInfo string - DocURL string - HomeContent string - HideCcsImportButton bool + RegistrationEnabled bool + EmailVerifyEnabled bool + RegistrationEmailSuffixWhitelist []string + PromoCodeEnabled bool + PasswordResetEnabled bool + InvitationCodeEnabled bool + TotpEnabled bool // TOTP 双因素认证 + TurnstileEnabled bool + TurnstileSiteKey string + SiteName string + SiteLogo string + SiteSubtitle string + APIBaseURL string + ContactInfo string + DocURL string + HomeContent string + HideCcsImportButton bool PurchaseSubscriptionEnabled bool PurchaseSubscriptionURL string @@ -173,3 +175,61 @@ func DefaultStreamTimeoutSettings() *StreamTimeoutSettings { ThresholdWindowMinutes: 10, } } + +// RectifierSettings 请求整流器配置 +type RectifierSettings struct { + Enabled bool `json:"enabled"` // 总开关 + ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"` // Thinking 签名整流 + ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"` // Thinking Budget 整流 +} + +// DefaultRectifierSettings 返回默认的整流器配置(全部启用) +func DefaultRectifierSettings() *RectifierSettings { + return &RectifierSettings{ + Enabled: true, + ThinkingSignatureEnabled: true, + ThinkingBudgetEnabled: true, + } +} + +// Beta Policy 策略常量 +const ( + BetaPolicyActionPass = "pass" // 透传,不做任何处理 + BetaPolicyActionFilter = "filter" // 过滤,从 beta header 中移除该 token + BetaPolicyActionBlock = "block" // 拦截,直接返回错误 + + BetaPolicyScopeAll = "all" // 所有账号类型 + BetaPolicyScopeOAuth = "oauth" // 仅 OAuth 账号 + BetaPolicyScopeAPIKey = 
"apikey" // 仅 API Key 账号 +) + +// BetaPolicyRule 单条 Beta 策略规则 +type BetaPolicyRule struct { + BetaToken string `json:"beta_token"` // beta token 值 + Action string `json:"action"` // "pass" | "filter" | "block" + Scope string `json:"scope"` // "all" | "oauth" | "apikey" + ErrorMessage string `json:"error_message,omitempty"` // 自定义错误消息 (action=block 时生效) +} + +// BetaPolicySettings Beta 策略配置 +type BetaPolicySettings struct { + Rules []BetaPolicyRule `json:"rules"` +} + +// DefaultBetaPolicySettings 返回默认的 Beta 策略配置 +func DefaultBetaPolicySettings() *BetaPolicySettings { + return &BetaPolicySettings{ + Rules: []BetaPolicyRule{ + { + BetaToken: "fast-mode-2026-02-01", + Action: BetaPolicyActionFilter, + Scope: BetaPolicyScopeAll, + }, + { + BetaToken: "context-1m-2025-08-07", + Action: BetaPolicyActionFilter, + Scope: BetaPolicyScopeAll, + }, + }, + } +} diff --git a/backend/internal/service/subscription_calculate_progress_test.go b/backend/internal/service/subscription_calculate_progress_test.go index 22018bcd..53e5c568 100644 --- a/backend/internal/service/subscription_calculate_progress_test.go +++ b/backend/internal/service/subscription_calculate_progress_test.go @@ -34,7 +34,7 @@ func TestCalculateProgress_BasicFields(t *testing.T) { assert.Equal(t, int64(100), progress.ID) assert.Equal(t, "Premium", progress.GroupName) assert.Equal(t, sub.ExpiresAt, progress.ExpiresAt) - assert.Equal(t, 29, progress.ExpiresInDays) // 约 30 天 + assert.True(t, progress.ExpiresInDays == 29 || progress.ExpiresInDays == 30, "ExpiresInDays should be 29 or 30, got %d", progress.ExpiresInDays) assert.Nil(t, progress.Daily, "无日限额时 Daily 应为 nil") assert.Nil(t, progress.Weekly, "无周限额时 Weekly 应为 nil") assert.Nil(t, progress.Monthly, "无月限额时 Monthly 应为 nil") diff --git a/backend/internal/service/subscription_reset_quota_test.go b/backend/internal/service/subscription_reset_quota_test.go new file mode 100644 index 00000000..36aa177f --- /dev/null +++ 
b/backend/internal/service/subscription_reset_quota_test.go @@ -0,0 +1,166 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// resetQuotaUserSubRepoStub 支持 GetByID、ResetDailyUsage、ResetWeeklyUsage, +// 其余方法继承 userSubRepoNoop(panic)。 +type resetQuotaUserSubRepoStub struct { + userSubRepoNoop + + sub *UserSubscription + + resetDailyCalled bool + resetWeeklyCalled bool + resetDailyErr error + resetWeeklyErr error +} + +func (r *resetQuotaUserSubRepoStub) GetByID(_ context.Context, id int64) (*UserSubscription, error) { + if r.sub == nil || r.sub.ID != id { + return nil, ErrSubscriptionNotFound + } + cp := *r.sub + return &cp, nil +} + +func (r *resetQuotaUserSubRepoStub) ResetDailyUsage(_ context.Context, _ int64, windowStart time.Time) error { + r.resetDailyCalled = true + if r.resetDailyErr == nil && r.sub != nil { + r.sub.DailyUsageUSD = 0 + r.sub.DailyWindowStart = &windowStart + } + return r.resetDailyErr +} + +func (r *resetQuotaUserSubRepoStub) ResetWeeklyUsage(_ context.Context, _ int64, _ time.Time) error { + r.resetWeeklyCalled = true + return r.resetWeeklyErr +} + +func newResetQuotaSvc(stub *resetQuotaUserSubRepoStub) *SubscriptionService { + return NewSubscriptionService(groupRepoNoop{}, stub, nil, nil, nil) +} + +func TestAdminResetQuota_ResetBoth(t *testing.T) { + stub := &resetQuotaUserSubRepoStub{ + sub: &UserSubscription{ID: 1, UserID: 10, GroupID: 20}, + } + svc := newResetQuotaSvc(stub) + + result, err := svc.AdminResetQuota(context.Background(), 1, true, true) + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, stub.resetDailyCalled, "应调用 ResetDailyUsage") + require.True(t, stub.resetWeeklyCalled, "应调用 ResetWeeklyUsage") +} + +func TestAdminResetQuota_ResetDailyOnly(t *testing.T) { + stub := &resetQuotaUserSubRepoStub{ + sub: &UserSubscription{ID: 2, UserID: 10, GroupID: 20}, + } + svc := newResetQuotaSvc(stub) + + result, err := 
svc.AdminResetQuota(context.Background(), 2, true, false) + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, stub.resetDailyCalled, "应调用 ResetDailyUsage") + require.False(t, stub.resetWeeklyCalled, "不应调用 ResetWeeklyUsage") +} + +func TestAdminResetQuota_ResetWeeklyOnly(t *testing.T) { + stub := &resetQuotaUserSubRepoStub{ + sub: &UserSubscription{ID: 3, UserID: 10, GroupID: 20}, + } + svc := newResetQuotaSvc(stub) + + result, err := svc.AdminResetQuota(context.Background(), 3, false, true) + + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, stub.resetDailyCalled, "不应调用 ResetDailyUsage") + require.True(t, stub.resetWeeklyCalled, "应调用 ResetWeeklyUsage") +} + +func TestAdminResetQuota_BothFalseReturnsError(t *testing.T) { + stub := &resetQuotaUserSubRepoStub{ + sub: &UserSubscription{ID: 7, UserID: 10, GroupID: 20}, + } + svc := newResetQuotaSvc(stub) + + _, err := svc.AdminResetQuota(context.Background(), 7, false, false) + + require.ErrorIs(t, err, ErrInvalidInput) + require.False(t, stub.resetDailyCalled) + require.False(t, stub.resetWeeklyCalled) +} + +func TestAdminResetQuota_SubscriptionNotFound(t *testing.T) { + stub := &resetQuotaUserSubRepoStub{sub: nil} + svc := newResetQuotaSvc(stub) + + _, err := svc.AdminResetQuota(context.Background(), 999, true, true) + + require.ErrorIs(t, err, ErrSubscriptionNotFound) + require.False(t, stub.resetDailyCalled) + require.False(t, stub.resetWeeklyCalled) +} + +func TestAdminResetQuota_ResetDailyUsageError(t *testing.T) { + dbErr := errors.New("db error") + stub := &resetQuotaUserSubRepoStub{ + sub: &UserSubscription{ID: 4, UserID: 10, GroupID: 20}, + resetDailyErr: dbErr, + } + svc := newResetQuotaSvc(stub) + + _, err := svc.AdminResetQuota(context.Background(), 4, true, true) + + require.ErrorIs(t, err, dbErr) + require.True(t, stub.resetDailyCalled) + require.False(t, stub.resetWeeklyCalled, "daily 失败后不应继续调用 weekly") +} + +func TestAdminResetQuota_ResetWeeklyUsageError(t 
*testing.T) { + dbErr := errors.New("db error") + stub := &resetQuotaUserSubRepoStub{ + sub: &UserSubscription{ID: 5, UserID: 10, GroupID: 20}, + resetWeeklyErr: dbErr, + } + svc := newResetQuotaSvc(stub) + + _, err := svc.AdminResetQuota(context.Background(), 5, false, true) + + require.ErrorIs(t, err, dbErr) + require.True(t, stub.resetWeeklyCalled) +} + +func TestAdminResetQuota_ReturnsRefreshedSub(t *testing.T) { + stub := &resetQuotaUserSubRepoStub{ + sub: &UserSubscription{ + ID: 6, + UserID: 10, + GroupID: 20, + DailyUsageUSD: 99.9, + }, + } + + svc := newResetQuotaSvc(stub) + result, err := svc.AdminResetQuota(context.Background(), 6, true, false) + + require.NoError(t, err) + // ResetDailyUsage stub 会将 sub.DailyUsageUSD 归零, + // 服务应返回第二次 GetByID 的刷新值而非初始的 99.9 + require.Equal(t, float64(0), result.DailyUsageUSD, "返回的订阅应反映已归零的用量") + require.True(t, stub.resetDailyCalled) +} diff --git a/backend/internal/service/subscription_service.go b/backend/internal/service/subscription_service.go index 57e04266..55f029fa 100644 --- a/backend/internal/service/subscription_service.go +++ b/backend/internal/service/subscription_service.go @@ -31,6 +31,7 @@ var ( ErrSubscriptionAlreadyExists = infraerrors.Conflict("SUBSCRIPTION_ALREADY_EXISTS", "subscription already exists for this user and group") ErrSubscriptionAssignConflict = infraerrors.Conflict("SUBSCRIPTION_ASSIGN_CONFLICT", "subscription exists but request conflicts with existing assignment semantics") ErrGroupNotSubscriptionType = infraerrors.BadRequest("GROUP_NOT_SUBSCRIPTION_TYPE", "group is not a subscription type") + ErrInvalidInput = infraerrors.BadRequest("INVALID_INPUT", "at least one of resetDaily or resetWeekly must be true") ErrDailyLimitExceeded = infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily usage limit exceeded") ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded") ErrMonthlyLimitExceeded = 
infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded") @@ -695,6 +696,36 @@ func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *U return s.userSubRepo.ActivateWindows(ctx, sub.ID, windowStart) } +// AdminResetQuota manually resets the daily and/or weekly usage windows. +// Uses startOfDay(now) as the new window start, matching automatic resets. +func (s *SubscriptionService) AdminResetQuota(ctx context.Context, subscriptionID int64, resetDaily, resetWeekly bool) (*UserSubscription, error) { + if !resetDaily && !resetWeekly { + return nil, ErrInvalidInput + } + sub, err := s.userSubRepo.GetByID(ctx, subscriptionID) + if err != nil { + return nil, err + } + windowStart := startOfDay(time.Now()) + if resetDaily { + if err := s.userSubRepo.ResetDailyUsage(ctx, sub.ID, windowStart); err != nil { + return nil, err + } + } + if resetWeekly { + if err := s.userSubRepo.ResetWeeklyUsage(ctx, sub.ID, windowStart); err != nil { + return nil, err + } + } + // Invalidate caches, same as CheckAndResetWindows + s.InvalidateSubCache(sub.UserID, sub.GroupID) + if s.billingCacheService != nil { + _ = s.billingCacheService.InvalidateSubscription(ctx, sub.UserID, sub.GroupID) + } + // Return the refreshed subscription from DB + return s.userSubRepo.GetByID(ctx, subscriptionID) +} + // CheckAndResetWindows 检查并重置过期的窗口 func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *UserSubscription) error { // 使用当天零点作为新窗口起始时间 diff --git a/backend/internal/service/usage_log.go b/backend/internal/service/usage_log.go index c1a95541..a7464956 100644 --- a/backend/internal/service/usage_log.go +++ b/backend/internal/service/usage_log.go @@ -98,6 +98,8 @@ type UsageLog struct { AccountID int64 RequestID string Model string + // ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex". 
+ ServiceTier *string // ReasoningEffort is the request's reasoning effort level (OpenAI Responses API), // e.g. "low" / "medium" / "high" / "xhigh". Nil means not provided / not applicable. ReasoningEffort *string diff --git a/backend/internal/service/user_group_rate.go b/backend/internal/service/user_group_rate.go index 9eb5f067..9908546e 100644 --- a/backend/internal/service/user_group_rate.go +++ b/backend/internal/service/user_group_rate.go @@ -2,6 +2,13 @@ package service import "context" +// UserGroupRateEntry 分组下用户专属倍率条目 +type UserGroupRateEntry struct { + UserID int64 `json:"user_id"` + UserEmail string `json:"user_email"` + RateMultiplier float64 `json:"rate_multiplier"` +} + // UserGroupRateRepository 用户专属分组倍率仓储接口 // 允许管理员为特定用户设置分组的专属计费倍率,覆盖分组默认倍率 type UserGroupRateRepository interface { @@ -13,6 +20,9 @@ type UserGroupRateRepository interface { // 如果未设置专属倍率,返回 nil GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) + // GetByGroupID 获取指定分组下所有用户的专属倍率 + GetByGroupID(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error) + // SyncUserGroupRates 同步用户的分组专属倍率 // rates: map[groupID]*rateMultiplier,nil 表示删除该分组的专属倍率 SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error diff --git a/backend/internal/service/user_group_rate_resolver.go b/backend/internal/service/user_group_rate_resolver.go new file mode 100644 index 00000000..7f0ffb0f --- /dev/null +++ b/backend/internal/service/user_group_rate_resolver.go @@ -0,0 +1,103 @@ +package service + +import ( + "context" + "fmt" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + gocache "github.com/patrickmn/go-cache" + "golang.org/x/sync/singleflight" +) + +type userGroupRateResolver struct { + repo UserGroupRateRepository + cache *gocache.Cache + cacheTTL time.Duration + sf *singleflight.Group + logComponent string +} + +func newUserGroupRateResolver(repo UserGroupRateRepository, cache *gocache.Cache, cacheTTL time.Duration, sf 
*singleflight.Group, logComponent string) *userGroupRateResolver { + if cacheTTL <= 0 { + cacheTTL = defaultUserGroupRateCacheTTL + } + if cache == nil { + cache = gocache.New(cacheTTL, time.Minute) + } + if logComponent == "" { + logComponent = "service.gateway" + } + if sf == nil { + sf = &singleflight.Group{} + } + + return &userGroupRateResolver{ + repo: repo, + cache: cache, + cacheTTL: cacheTTL, + sf: sf, + logComponent: logComponent, + } +} + +func (r *userGroupRateResolver) Resolve(ctx context.Context, userID, groupID int64, groupDefaultMultiplier float64) float64 { + if r == nil || userID <= 0 || groupID <= 0 { + return groupDefaultMultiplier + } + + key := fmt.Sprintf("%d:%d", userID, groupID) + if r.cache != nil { + if cached, ok := r.cache.Get(key); ok { + if multiplier, castOK := cached.(float64); castOK { + userGroupRateCacheHitTotal.Add(1) + return multiplier + } + } + } + if r.repo == nil { + return groupDefaultMultiplier + } + userGroupRateCacheMissTotal.Add(1) + + value, err, shared := r.sf.Do(key, func() (any, error) { + if r.cache != nil { + if cached, ok := r.cache.Get(key); ok { + if multiplier, castOK := cached.(float64); castOK { + userGroupRateCacheHitTotal.Add(1) + return multiplier, nil + } + } + } + + userGroupRateCacheLoadTotal.Add(1) + userRate, repoErr := r.repo.GetByUserAndGroup(ctx, userID, groupID) + if repoErr != nil { + return nil, repoErr + } + + multiplier := groupDefaultMultiplier + if userRate != nil { + multiplier = *userRate + } + if r.cache != nil { + r.cache.Set(key, multiplier, r.cacheTTL) + } + return multiplier, nil + }) + if shared { + userGroupRateCacheSFSharedTotal.Add(1) + } + if err != nil { + userGroupRateCacheFallbackTotal.Add(1) + logger.LegacyPrintf(r.logComponent, "get user group rate failed, fallback to group default: user=%d group=%d err=%v", userID, groupID, err) + return groupDefaultMultiplier + } + + multiplier, ok := value.(float64) + if !ok { + userGroupRateCacheFallbackTotal.Add(1) + return 
groupDefaultMultiplier + } + return multiplier +} diff --git a/backend/internal/service/user_group_rate_resolver_test.go b/backend/internal/service/user_group_rate_resolver_test.go new file mode 100644 index 00000000..064ef7ba --- /dev/null +++ b/backend/internal/service/user_group_rate_resolver_test.go @@ -0,0 +1,83 @@ +package service + +import ( + "context" + "testing" + "time" + + gocache "github.com/patrickmn/go-cache" + "github.com/stretchr/testify/require" +) + +type userGroupRateResolverRepoStub struct { + UserGroupRateRepository + + rate *float64 + err error + calls int +} + +func (s *userGroupRateResolverRepoStub) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) { + s.calls++ + if s.err != nil { + return nil, s.err + } + return s.rate, nil +} + +func TestNewUserGroupRateResolver_Defaults(t *testing.T) { + resolver := newUserGroupRateResolver(nil, nil, 0, nil, "") + + require.NotNil(t, resolver) + require.NotNil(t, resolver.cache) + require.Equal(t, defaultUserGroupRateCacheTTL, resolver.cacheTTL) + require.NotNil(t, resolver.sf) + require.Equal(t, "service.gateway", resolver.logComponent) +} + +func TestUserGroupRateResolverResolve_FallbackForNilResolverAndInvalidIDs(t *testing.T) { + var nilResolver *userGroupRateResolver + require.Equal(t, 1.4, nilResolver.Resolve(context.Background(), 101, 202, 1.4)) + + resolver := newUserGroupRateResolver(nil, nil, time.Second, nil, "service.test") + require.Equal(t, 1.4, resolver.Resolve(context.Background(), 0, 202, 1.4)) + require.Equal(t, 1.4, resolver.Resolve(context.Background(), 101, 0, 1.4)) +} + +func TestUserGroupRateResolverResolve_InvalidCacheEntryLoadsRepoAndCaches(t *testing.T) { + resetGatewayHotpathStatsForTest() + + rate := 1.7 + repo := &userGroupRateResolverRepoStub{rate: &rate} + cache := gocache.New(time.Minute, time.Minute) + cache.Set("101:202", "bad-cache", time.Minute) + resolver := newUserGroupRateResolver(repo, cache, time.Minute, nil, "service.test") + + got 
:= resolver.Resolve(context.Background(), 101, 202, 1.2) + require.Equal(t, rate, got) + require.Equal(t, 1, repo.calls) + + cached, ok := cache.Get("101:202") + require.True(t, ok) + require.Equal(t, rate, cached) + + hit, miss, load, _, fallback := GatewayUserGroupRateCacheStats() + require.Equal(t, int64(0), hit) + require.Equal(t, int64(1), miss) + require.Equal(t, int64(1), load) + require.Equal(t, int64(0), fallback) +} + +func TestGatewayServiceGetUserGroupRateMultiplier_FallbacksAndUsesExistingResolver(t *testing.T) { + var nilSvc *GatewayService + require.Equal(t, 1.3, nilSvc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.3)) + + rate := 1.9 + repo := &userGroupRateResolverRepoStub{rate: &rate} + resolver := newUserGroupRateResolver(repo, nil, time.Minute, nil, "service.gateway") + svc := &GatewayService{userGroupRateResolver: resolver} + + got := svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.2) + require.Equal(t, rate, got) + require.Equal(t, 1, repo.calls) +} diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index b5553935..49ba3645 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -22,6 +22,10 @@ type UserListFilters struct { Role string // User role filter Search string // Search in email, username Attributes map[int64]string // Custom attribute filters: attributeID -> value + // IncludeSubscriptions controls whether ListWithFilters should load active subscriptions. + // For large datasets this can be expensive; admin list pages should enable it on demand. + // nil means not specified (default: load subscriptions for backward compatibility). 
+ IncludeSubscriptions *bool } type UserRepository interface { diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 920ab1cc..7457b77e 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -105,6 +105,9 @@ func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWh // ProvideConcurrencyService creates ConcurrencyService and starts slot cleanup worker. func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountRepository, cfg *config.Config) *ConcurrencyService { svc := NewConcurrencyService(cache) + if err := svc.CleanupStaleProcessSlots(context.Background()); err != nil { + logger.LegacyPrintf("service.concurrency", "Warning: startup cleanup stale process slots failed: %v", err) + } if cfg != nil { svc.StartSlotCleanupWorker(accountRepo, cfg.Gateway.Scheduling.SlotCleanupInterval) } @@ -274,6 +277,27 @@ func ProvideIdempotencyCleanupService(repo IdempotencyRepository, cfg *config.Co return svc } +// ProvideScheduledTestService creates ScheduledTestService. +func ProvideScheduledTestService( + planRepo ScheduledTestPlanRepository, + resultRepo ScheduledTestResultRepository, +) *ScheduledTestService { + return NewScheduledTestService(planRepo, resultRepo) +} + +// ProvideScheduledTestRunnerService creates and starts ScheduledTestRunnerService. +func ProvideScheduledTestRunnerService( + planRepo ScheduledTestPlanRepository, + scheduledSvc *ScheduledTestService, + accountTestSvc *AccountTestService, + rateLimitSvc *RateLimitService, + cfg *config.Config, +) *ScheduledTestRunnerService { + svc := NewScheduledTestRunnerService(planRepo, scheduledSvc, accountTestSvc, rateLimitSvc, cfg) + svc.Start() + return svc +} + // ProvideOpsScheduledReportService creates and starts OpsScheduledReportService. 
func ProvideOpsScheduledReportService( opsService *OpsService, @@ -380,4 +404,6 @@ var ProviderSet = wire.NewSet( ProvideIdempotencyCoordinator, ProvideSystemOperationLockService, ProvideIdempotencyCleanupService, + ProvideScheduledTestService, + ProvideScheduledTestRunnerService, ) diff --git a/backend/internal/setup/setup.go b/backend/internal/setup/setup.go index 83c32db3..de3b765a 100644 --- a/backend/internal/setup/setup.go +++ b/backend/internal/setup/setup.go @@ -12,6 +12,7 @@ import ( "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/service" @@ -23,10 +24,19 @@ import ( // Config paths const ( - ConfigFileName = "config.yaml" - InstallLockFile = ".installed" + ConfigFileName = "config.yaml" + InstallLockFile = ".installed" + defaultUserConcurrency = 5 + simpleModeAdminConcurrency = 30 ) +func setupDefaultAdminConcurrency() int { + if strings.EqualFold(strings.TrimSpace(os.Getenv("RUN_MODE")), config.RunModeSimple) { + return simpleModeAdminConcurrency + } + return defaultUserConcurrency +} + // GetDataDir returns the data directory for storing config and lock files. 
// Priority: DATA_DIR env > /app/data (if exists and writable) > current directory func GetDataDir() string { @@ -390,7 +400,7 @@ func createAdminUser(cfg *SetupConfig) (bool, string, error) { Role: service.RoleAdmin, Status: service.StatusActive, Balance: 0, - Concurrency: 5, + Concurrency: setupDefaultAdminConcurrency(), CreatedAt: time.Now(), UpdatedAt: time.Now(), } @@ -462,7 +472,7 @@ func writeConfigFile(cfg *SetupConfig) error { APIKeyPrefix string `yaml:"api_key_prefix"` RateMultiplier float64 `yaml:"rate_multiplier"` }{ - UserConcurrency: 5, + UserConcurrency: defaultUserConcurrency, UserBalance: 0, APIKeyPrefix: "sk-", RateMultiplier: 1.0, diff --git a/backend/internal/setup/setup_test.go b/backend/internal/setup/setup_test.go index 69655e92..a01dd00c 100644 --- a/backend/internal/setup/setup_test.go +++ b/backend/internal/setup/setup_test.go @@ -1,6 +1,10 @@ package setup -import "testing" +import ( + "os" + "strings" + "testing" +) func TestDecideAdminBootstrap(t *testing.T) { t.Parallel() @@ -49,3 +53,37 @@ func TestDecideAdminBootstrap(t *testing.T) { }) } } + +func TestSetupDefaultAdminConcurrency(t *testing.T) { + t.Run("simple mode admin uses higher concurrency", func(t *testing.T) { + t.Setenv("RUN_MODE", "simple") + if got := setupDefaultAdminConcurrency(); got != simpleModeAdminConcurrency { + t.Fatalf("setupDefaultAdminConcurrency()=%d, want %d", got, simpleModeAdminConcurrency) + } + }) + + t.Run("standard mode keeps existing default", func(t *testing.T) { + t.Setenv("RUN_MODE", "standard") + if got := setupDefaultAdminConcurrency(); got != defaultUserConcurrency { + t.Fatalf("setupDefaultAdminConcurrency()=%d, want %d", got, defaultUserConcurrency) + } + }) +} + +func TestWriteConfigFileKeepsDefaultUserConcurrency(t *testing.T) { + t.Setenv("RUN_MODE", "simple") + t.Setenv("DATA_DIR", t.TempDir()) + + if err := writeConfigFile(&SetupConfig{}); err != nil { + t.Fatalf("writeConfigFile() error = %v", err) + } + + data, err := 
os.ReadFile(GetConfigFilePath()) + if err != nil { + t.Fatalf("ReadFile() error = %v", err) + } + + if !strings.Contains(string(data), "user_concurrency: 5") { + t.Fatalf("config missing default user concurrency, got:\n%s", string(data)) + } +} diff --git a/backend/internal/testutil/stubs.go b/backend/internal/testutil/stubs.go index 217a5f56..bc572e11 100644 --- a/backend/internal/testutil/stubs.go +++ b/backend/internal/testutil/stubs.go @@ -76,6 +76,9 @@ func (c StubConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, acco func (c StubConcurrencyCache) CleanupExpiredAccountSlots(_ context.Context, _ int64) error { return nil } +func (c StubConcurrencyCache) CleanupStaleProcessSlots(_ context.Context, _ string) error { + return nil +} // ============================================================ // StubGatewayCache — service.GatewayCache 的空实现 diff --git a/backend/internal/web/embed_on.go b/backend/internal/web/embed_on.go index f7ba5c9e..41ce4d48 100644 --- a/backend/internal/web/embed_on.go +++ b/backend/internal/web/embed_on.go @@ -83,14 +83,7 @@ func (s *FrontendServer) Middleware() gin.HandlerFunc { path := c.Request.URL.Path // Skip API routes - if strings.HasPrefix(path, "/api/") || - strings.HasPrefix(path, "/v1/") || - strings.HasPrefix(path, "/v1beta/") || - strings.HasPrefix(path, "/sora/") || - strings.HasPrefix(path, "/antigravity/") || - strings.HasPrefix(path, "/setup/") || - path == "/health" || - path == "/responses" { + if shouldBypassEmbeddedFrontend(path) { c.Next() return } @@ -207,14 +200,7 @@ func ServeEmbeddedFrontend() gin.HandlerFunc { return func(c *gin.Context) { path := c.Request.URL.Path - if strings.HasPrefix(path, "/api/") || - strings.HasPrefix(path, "/v1/") || - strings.HasPrefix(path, "/v1beta/") || - strings.HasPrefix(path, "/sora/") || - strings.HasPrefix(path, "/antigravity/") || - strings.HasPrefix(path, "/setup/") || - path == "/health" || - path == "/responses" { + if shouldBypassEmbeddedFrontend(path) { 
c.Next() return } @@ -235,6 +221,19 @@ func ServeEmbeddedFrontend() gin.HandlerFunc { } } +func shouldBypassEmbeddedFrontend(path string) bool { + trimmed := strings.TrimSpace(path) + return strings.HasPrefix(trimmed, "/api/") || + strings.HasPrefix(trimmed, "/v1/") || + strings.HasPrefix(trimmed, "/v1beta/") || + strings.HasPrefix(trimmed, "/sora/") || + strings.HasPrefix(trimmed, "/antigravity/") || + strings.HasPrefix(trimmed, "/setup/") || + trimmed == "/health" || + trimmed == "/responses" || + strings.HasPrefix(trimmed, "/responses/") +} + func serveIndexHTML(c *gin.Context, fsys fs.FS) { file, err := fsys.Open("index.html") if err != nil { diff --git a/backend/internal/web/embed_test.go b/backend/internal/web/embed_test.go index e2cbcf15..f270b624 100644 --- a/backend/internal/web/embed_test.go +++ b/backend/internal/web/embed_test.go @@ -367,6 +367,7 @@ func TestFrontendServer_Middleware(t *testing.T) { "/setup/init", "/health", "/responses", + "/responses/compact", } for _, path := range apiPaths { @@ -388,6 +389,32 @@ func TestFrontendServer_Middleware(t *testing.T) { } }) + t.Run("skips_responses_compact_post_routes", func(t *testing.T) { + provider := &mockSettingsProvider{ + settings: map[string]string{"test": "value"}, + } + + server, err := NewFrontendServer(provider) + require.NoError(t, err) + + router := gin.New() + router.Use(server.Middleware()) + nextCalled := false + router.POST("/responses/compact", func(c *gin.Context) { + nextCalled = true + c.String(http.StatusOK, `{"ok":true}`) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/responses/compact", strings.NewReader(`{"model":"gpt-5"}`)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + assert.True(t, nextCalled, "next handler should be called for compact API route") + assert.Equal(t, http.StatusOK, w.Code) + assert.JSONEq(t, `{"ok":true}`, w.Body.String()) + }) + t.Run("serves_index_for_spa_routes", func(t *testing.T) { 
provider := &mockSettingsProvider{ settings: map[string]string{"test": "value"}, @@ -543,6 +570,7 @@ func TestServeEmbeddedFrontend(t *testing.T) { "/setup/init", "/health", "/responses", + "/responses/compact", } for _, path := range apiPaths { diff --git a/backend/migrations/056_add_sonnet46_to_model_mapping.sql b/backend/migrations/056_add_sonnet46_to_model_mapping.sql new file mode 100644 index 00000000..aa7657d7 --- /dev/null +++ b/backend/migrations/056_add_sonnet46_to_model_mapping.sql @@ -0,0 +1,42 @@ +-- Add claude-sonnet-4-6 to model_mapping for all Antigravity accounts +-- +-- Background: +-- Antigravity now supports claude-sonnet-4-6 +-- +-- Strategy: +-- Directly overwrite the entire model_mapping with updated mappings +-- This ensures consistency with DefaultAntigravityModelMapping in constants.go + +UPDATE accounts +SET credentials = jsonb_set( + credentials, + '{model_mapping}', + '{ + "claude-opus-4-6-thinking": "claude-opus-4-6-thinking", + "claude-opus-4-6": "claude-opus-4-6-thinking", + "claude-opus-4-5-thinking": "claude-opus-4-6-thinking", + "claude-opus-4-5-20251101": "claude-opus-4-6-thinking", + "claude-sonnet-4-6": "claude-sonnet-4-6", + "claude-sonnet-4-5": "claude-sonnet-4-5", + "claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking", + "claude-sonnet-4-5-20250929": "claude-sonnet-4-5", + "claude-haiku-4-5": "claude-sonnet-4-5", + "claude-haiku-4-5-20251001": "claude-sonnet-4-5", + "gemini-2.5-flash": "gemini-2.5-flash", + "gemini-2.5-flash-lite": "gemini-2.5-flash-lite", + "gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking", + "gemini-2.5-pro": "gemini-2.5-pro", + "gemini-3-flash": "gemini-3-flash", + "gemini-3-pro-high": "gemini-3-pro-high", + "gemini-3-pro-low": "gemini-3-pro-low", + "gemini-3-pro-image": "gemini-3-pro-image", + "gemini-3-flash-preview": "gemini-3-flash", + "gemini-3-pro-preview": "gemini-3-pro-high", + "gemini-3-pro-image-preview": "gemini-3-pro-image", + "gpt-oss-120b-medium": "gpt-oss-120b-medium", + 
"tab_flash_lite_preview": "tab_flash_lite_preview" + }'::jsonb +) +WHERE platform = 'antigravity' + AND deleted_at IS NULL + AND credentials->'model_mapping' IS NOT NULL; diff --git a/backend/migrations/057_add_gemini31_pro_to_model_mapping.sql b/backend/migrations/057_add_gemini31_pro_to_model_mapping.sql new file mode 100644 index 00000000..6305e717 --- /dev/null +++ b/backend/migrations/057_add_gemini31_pro_to_model_mapping.sql @@ -0,0 +1,45 @@ +-- Add gemini-3.1-pro-high, gemini-3.1-pro-low, gemini-3.1-pro-preview to model_mapping +-- +-- Background: +-- Antigravity now supports gemini-3.1-pro-high and gemini-3.1-pro-low +-- +-- Strategy: +-- Directly overwrite the entire model_mapping with updated mappings +-- This ensures consistency with DefaultAntigravityModelMapping in constants.go + +UPDATE accounts +SET credentials = jsonb_set( + credentials, + '{model_mapping}', + '{ + "claude-opus-4-6-thinking": "claude-opus-4-6-thinking", + "claude-opus-4-6": "claude-opus-4-6-thinking", + "claude-opus-4-5-thinking": "claude-opus-4-6-thinking", + "claude-opus-4-5-20251101": "claude-opus-4-6-thinking", + "claude-sonnet-4-6": "claude-sonnet-4-6", + "claude-sonnet-4-5": "claude-sonnet-4-5", + "claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking", + "claude-sonnet-4-5-20250929": "claude-sonnet-4-5", + "claude-haiku-4-5": "claude-sonnet-4-5", + "claude-haiku-4-5-20251001": "claude-sonnet-4-5", + "gemini-2.5-flash": "gemini-2.5-flash", + "gemini-2.5-flash-lite": "gemini-2.5-flash-lite", + "gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking", + "gemini-2.5-pro": "gemini-2.5-pro", + "gemini-3-flash": "gemini-3-flash", + "gemini-3-pro-high": "gemini-3-pro-high", + "gemini-3-pro-low": "gemini-3-pro-low", + "gemini-3-pro-image": "gemini-3-pro-image", + "gemini-3-flash-preview": "gemini-3-flash", + "gemini-3-pro-preview": "gemini-3-pro-high", + "gemini-3-pro-image-preview": "gemini-3-pro-image", + "gemini-3.1-pro-high": "gemini-3.1-pro-high", + "gemini-3.1-pro-low": 
"gemini-3.1-pro-low", + "gemini-3.1-pro-preview": "gemini-3.1-pro-high", + "gpt-oss-120b-medium": "gpt-oss-120b-medium", + "tab_flash_lite_preview": "tab_flash_lite_preview" + }'::jsonb +) +WHERE platform = 'antigravity' + AND deleted_at IS NULL + AND credentials->'model_mapping' IS NOT NULL; diff --git a/backend/migrations/060_add_group_simulate_claude_max.sql b/backend/migrations/060_add_group_simulate_claude_max.sql new file mode 100644 index 00000000..55662dfd --- /dev/null +++ b/backend/migrations/060_add_group_simulate_claude_max.sql @@ -0,0 +1,3 @@ +ALTER TABLE groups + ADD COLUMN IF NOT EXISTS simulate_claude_max_enabled BOOLEAN NOT NULL DEFAULT FALSE; + diff --git a/backend/migrations/065_add_search_trgm_indexes.sql b/backend/migrations/065_add_search_trgm_indexes.sql new file mode 100644 index 00000000..f5efb5da --- /dev/null +++ b/backend/migrations/065_add_search_trgm_indexes.sql @@ -0,0 +1,33 @@ +-- Improve admin fuzzy-search performance on large datasets. +-- Best effort: +-- 1) try enabling pg_trgm +-- 2) only create trigram indexes when extension is available +DO $$ +BEGIN + BEGIN + CREATE EXTENSION IF NOT EXISTS pg_trgm; + EXCEPTION + WHEN OTHERS THEN + RAISE NOTICE 'pg_trgm extension not created: %', SQLERRM; + END; + + IF EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pg_trgm') THEN + EXECUTE 'CREATE INDEX IF NOT EXISTS idx_users_email_trgm + ON users USING gin (email gin_trgm_ops)'; + EXECUTE 'CREATE INDEX IF NOT EXISTS idx_users_username_trgm + ON users USING gin (username gin_trgm_ops)'; + EXECUTE 'CREATE INDEX IF NOT EXISTS idx_users_notes_trgm + ON users USING gin (notes gin_trgm_ops)'; + + EXECUTE 'CREATE INDEX IF NOT EXISTS idx_accounts_name_trgm + ON accounts USING gin (name gin_trgm_ops)'; + + EXECUTE 'CREATE INDEX IF NOT EXISTS idx_api_keys_key_trgm + ON api_keys USING gin ("key" gin_trgm_ops)'; + EXECUTE 'CREATE INDEX IF NOT EXISTS idx_api_keys_name_trgm + ON api_keys USING gin (name gin_trgm_ops)'; + ELSE + RAISE NOTICE 'skip 
trigram indexes because pg_trgm is unavailable'; + END IF; +END +$$; diff --git a/backend/migrations/066_add_scheduled_test_tables.sql b/backend/migrations/066_add_scheduled_test_tables.sql new file mode 100644 index 00000000..a9f839c0 --- /dev/null +++ b/backend/migrations/066_add_scheduled_test_tables.sql @@ -0,0 +1,30 @@ +-- 066_add_scheduled_test_tables.sql +-- Scheduled account test plans and results + +CREATE TABLE IF NOT EXISTS scheduled_test_plans ( + id BIGSERIAL PRIMARY KEY, + account_id BIGINT NOT NULL REFERENCES accounts(id) ON DELETE CASCADE, + model_id VARCHAR(100) NOT NULL DEFAULT '', + cron_expression VARCHAR(100) NOT NULL DEFAULT '*/30 * * * *', + enabled BOOLEAN NOT NULL DEFAULT true, + max_results INT NOT NULL DEFAULT 50, + last_run_at TIMESTAMPTZ, + next_run_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); +CREATE INDEX IF NOT EXISTS idx_stp_account_id ON scheduled_test_plans(account_id); +CREATE INDEX IF NOT EXISTS idx_stp_enabled_next_run ON scheduled_test_plans(enabled, next_run_at) WHERE enabled = true; + +CREATE TABLE IF NOT EXISTS scheduled_test_results ( + id BIGSERIAL PRIMARY KEY, + plan_id BIGINT NOT NULL REFERENCES scheduled_test_plans(id) ON DELETE CASCADE, + status VARCHAR(20) NOT NULL DEFAULT 'success', + response_text TEXT NOT NULL DEFAULT '', + error_message TEXT NOT NULL DEFAULT '', + latency_ms BIGINT NOT NULL DEFAULT 0, + started_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + finished_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); +CREATE INDEX IF NOT EXISTS idx_str_plan_created ON scheduled_test_results(plan_id, created_at DESC); diff --git a/backend/migrations/067_add_account_load_factor.sql b/backend/migrations/067_add_account_load_factor.sql new file mode 100644 index 00000000..6805e8c2 --- /dev/null +++ b/backend/migrations/067_add_account_load_factor.sql @@ -0,0 +1 @@ +ALTER TABLE accounts ADD COLUMN IF NOT EXISTS 
load_factor INTEGER; diff --git a/backend/migrations/068_add_announcement_notify_mode.sql b/backend/migrations/068_add_announcement_notify_mode.sql new file mode 100644 index 00000000..28deb983 --- /dev/null +++ b/backend/migrations/068_add_announcement_notify_mode.sql @@ -0,0 +1 @@ +ALTER TABLE announcements ADD COLUMN IF NOT EXISTS notify_mode VARCHAR(20) NOT NULL DEFAULT 'silent'; diff --git a/backend/migrations/069_add_group_messages_dispatch.sql b/backend/migrations/069_add_group_messages_dispatch.sql new file mode 100644 index 00000000..7b9d5f5d --- /dev/null +++ b/backend/migrations/069_add_group_messages_dispatch.sql @@ -0,0 +1,2 @@ +ALTER TABLE groups ADD COLUMN allow_messages_dispatch BOOLEAN NOT NULL DEFAULT false; +ALTER TABLE groups ADD COLUMN default_mapped_model VARCHAR(100) NOT NULL DEFAULT ''; diff --git a/backend/migrations/070_add_scheduled_test_auto_recover.sql b/backend/migrations/070_add_scheduled_test_auto_recover.sql new file mode 100644 index 00000000..5f0c6789 --- /dev/null +++ b/backend/migrations/070_add_scheduled_test_auto_recover.sql @@ -0,0 +1,4 @@ +-- 070: Add auto_recover column to scheduled_test_plans +-- When enabled, automatically recovers account from error/rate-limited state on successful test + +ALTER TABLE scheduled_test_plans ADD COLUMN IF NOT EXISTS auto_recover BOOLEAN NOT NULL DEFAULT false; diff --git a/backend/migrations/070_add_usage_log_service_tier.sql b/backend/migrations/070_add_usage_log_service_tier.sql new file mode 100644 index 00000000..085ec0d6 --- /dev/null +++ b/backend/migrations/070_add_usage_log_service_tier.sql @@ -0,0 +1,5 @@ +ALTER TABLE usage_logs + ADD COLUMN IF NOT EXISTS service_tier VARCHAR(16); + +CREATE INDEX IF NOT EXISTS idx_usage_logs_service_tier_created_at + ON usage_logs (service_tier, created_at); diff --git a/backend/migrations/071_add_gemini25_flash_image_to_model_mapping.sql b/backend/migrations/071_add_gemini25_flash_image_to_model_mapping.sql new file mode 100644 index 
00000000..f3cb3d37 --- /dev/null +++ b/backend/migrations/071_add_gemini25_flash_image_to_model_mapping.sql @@ -0,0 +1,51 @@ +-- Add gemini-2.5-flash-image aliases to Antigravity model_mapping +-- +-- Background: +-- Gemini native image generation now relies on gemini-2.5-flash-image, and +-- existing Antigravity accounts with persisted model_mapping need this alias in +-- order to participate in mixed scheduling from gemini groups. +-- +-- Strategy: +-- Overwrite the stored model_mapping so it matches DefaultAntigravityModelMapping +-- in constants.go, including legacy gemini-3-pro-image aliases. + +UPDATE accounts +SET credentials = jsonb_set( + credentials, + '{model_mapping}', + '{ + "claude-opus-4-6-thinking": "claude-opus-4-6-thinking", + "claude-opus-4-6": "claude-opus-4-6-thinking", + "claude-opus-4-5-thinking": "claude-opus-4-6-thinking", + "claude-opus-4-5-20251101": "claude-opus-4-6-thinking", + "claude-sonnet-4-6": "claude-sonnet-4-6", + "claude-sonnet-4-5": "claude-sonnet-4-5", + "claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking", + "claude-sonnet-4-5-20250929": "claude-sonnet-4-5", + "claude-haiku-4-5": "claude-sonnet-4-5", + "claude-haiku-4-5-20251001": "claude-sonnet-4-5", + "gemini-2.5-flash": "gemini-2.5-flash", + "gemini-2.5-flash-image": "gemini-2.5-flash-image", + "gemini-2.5-flash-image-preview": "gemini-2.5-flash-image", + "gemini-2.5-flash-lite": "gemini-2.5-flash-lite", + "gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking", + "gemini-2.5-pro": "gemini-2.5-pro", + "gemini-3-flash": "gemini-3-flash", + "gemini-3-pro-high": "gemini-3-pro-high", + "gemini-3-pro-low": "gemini-3-pro-low", + "gemini-3-flash-preview": "gemini-3-flash", + "gemini-3-pro-preview": "gemini-3-pro-high", + "gemini-3.1-pro-high": "gemini-3.1-pro-high", + "gemini-3.1-pro-low": "gemini-3.1-pro-low", + "gemini-3.1-pro-preview": "gemini-3.1-pro-high", + "gemini-3.1-flash-image": "gemini-3.1-flash-image", + "gemini-3.1-flash-image-preview": 
"gemini-3.1-flash-image", + "gemini-3-pro-image": "gemini-3.1-flash-image", + "gemini-3-pro-image-preview": "gemini-3.1-flash-image", + "gpt-oss-120b-medium": "gpt-oss-120b-medium", + "tab_flash_lite_preview": "tab_flash_lite_preview" + }'::jsonb +) +WHERE platform = 'antigravity' + AND deleted_at IS NULL + AND credentials->'model_mapping' IS NOT NULL; diff --git a/backend/resources/model-pricing/model_prices_and_context_window.json b/backend/resources/model-pricing/model_prices_and_context_window.json index 650e128e..72860bf9 100644 --- a/backend/resources/model-pricing/model_prices_and_context_window.json +++ b/backend/resources/model-pricing/model_prices_and_context_window.json @@ -5140,6 +5140,39 @@ "supports_vision": true, "supports_web_search": true }, + "gpt-5.4": { + "cache_read_input_token_cost": 2.5e-07, + "input_cost_per_token": 2.5e-06, + "litellm_provider": "openai", + "max_input_tokens": 1050000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "chat", + "output_cost_per_token": 1.5e-05, + "supported_endpoints": [ + "/v1/chat/completions", + "/v1/responses" + ], + "supported_modalities": [ + "text", + "image" + ], + "supported_output_modalities": [ + "text", + "image" + ], + "supports_function_calling": true, + "supports_native_streaming": true, + "supports_parallel_function_calling": true, + "supports_pdf_input": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_response_schema": true, + "supports_service_tier": true, + "supports_system_messages": true, + "supports_tool_choice": true, + "supports_vision": true + }, "gpt-5.3-codex": { "cache_read_input_token_cost": 1.75e-07, "cache_read_input_token_cost_priority": 3.5e-07, diff --git a/deploy/Dockerfile b/deploy/Dockerfile index b3320300..ffe815e5 100644 --- a/deploy/Dockerfile +++ b/deploy/Dockerfile @@ -7,7 +7,7 @@ # ============================================================================= ARG NODE_IMAGE=node:24-alpine -ARG 
GOLANG_IMAGE=golang:1.25.5-alpine +ARG GOLANG_IMAGE=golang:1.26.1-alpine ARG ALPINE_IMAGE=alpine:3.20 ARG GOPROXY=https://goproxy.cn,direct ARG GOSUMDB=sum.golang.google.cn diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index e2eb3130..2058ced1 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -209,8 +209,9 @@ gateway: openai_ws: # 新版 WS mode 路由(默认关闭)。关闭时保持当前 legacy 实现行为。 mode_router_v2_enabled: false - # ingress 默认模式:off|shared|dedicated(仅 mode_router_v2_enabled=true 生效) - ingress_mode_default: shared + # ingress 默认模式:off|ctx_pool|passthrough(仅 mode_router_v2_enabled=true 生效) + # 兼容旧值:shared/dedicated 会按 ctx_pool 处理。 + ingress_mode_default: ctx_pool # 全局总开关,默认 true;关闭时所有请求保持原有 HTTP/SSE 路由 enabled: true # 按账号类型细分开关 diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index e5c97bf8..8715d75d 100644 --- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -47,13 +47,15 @@ services: # ======================================================================= # Database Configuration (PostgreSQL) + # Default: uses local postgres container + # External DB: set DATABASE_HOST and DATABASE_SSLMODE in .env # ======================================================================= - - DATABASE_HOST=postgres - - DATABASE_PORT=5432 + - DATABASE_HOST=${DATABASE_HOST:-postgres} + - DATABASE_PORT=${DATABASE_PORT:-5432} - DATABASE_USER=${POSTGRES_USER:-sub2api} - DATABASE_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required} - DATABASE_DBNAME=${POSTGRES_DB:-sub2api} - - DATABASE_SSLMODE=disable + - DATABASE_SSLMODE=${DATABASE_SSLMODE:-disable} - DATABASE_MAX_OPEN_CONNS=${DATABASE_MAX_OPEN_CONNS:-50} - DATABASE_MAX_IDLE_CONNS=${DATABASE_MAX_IDLE_CONNS:-10} - DATABASE_CONN_MAX_LIFETIME_MINUTES=${DATABASE_CONN_MAX_LIFETIME_MINUTES:-30} @@ -139,8 +141,6 @@ services: # Examples: http://host:port, socks5://host:port - UPDATE_PROXY_URL=${UPDATE_PROXY_URL:-} depends_on: - postgres: - condition: service_healthy 
redis: condition: service_healthy networks: diff --git a/deploy/docker-deploy.sh b/deploy/docker-deploy.sh index 1e4ce81f..a07f4f41 100644 --- a/deploy/docker-deploy.sh +++ b/deploy/docker-deploy.sh @@ -8,7 +8,7 @@ # - Creates necessary data directories # # After running this script, you can start services with: -# docker-compose -f docker-compose.local.yml up -d +# docker-compose up -d # ============================================================================= set -e @@ -65,7 +65,7 @@ main() { fi # Check if deployment already exists - if [ -f "docker-compose.local.yml" ] && [ -f ".env" ]; then + if [ -f "docker-compose.yml" ] && [ -f ".env" ]; then print_warning "Deployment files already exist in current directory." read -p "Overwrite existing files? (y/N): " -r echo @@ -75,17 +75,17 @@ main() { fi fi - # Download docker-compose.local.yml - print_info "Downloading docker-compose.local.yml..." + # Download docker-compose.local.yml and save as docker-compose.yml + print_info "Downloading docker-compose.yml..." if command_exists curl; then - curl -sSL "${GITHUB_RAW_URL}/docker-compose.local.yml" -o docker-compose.local.yml + curl -sSL "${GITHUB_RAW_URL}/docker-compose.local.yml" -o docker-compose.yml elif command_exists wget; then - wget -q "${GITHUB_RAW_URL}/docker-compose.local.yml" -O docker-compose.local.yml + wget -q "${GITHUB_RAW_URL}/docker-compose.local.yml" -O docker-compose.yml else print_error "Neither curl nor wget is installed. Please install one of them." exit 1 fi - print_success "Downloaded docker-compose.local.yml" + print_success "Downloaded docker-compose.yml" # Download .env.example print_info "Downloading .env.example..." @@ -144,7 +144,7 @@ main() { print_warning "Please keep them secure and do not share publicly!" 
echo "" echo "Directory structure:" - echo " docker-compose.local.yml - Docker Compose configuration" + echo " docker-compose.yml - Docker Compose configuration" echo " .env - Environment variables (generated secrets)" echo " .env.example - Example template (for reference)" echo " data/ - Application data (will be created on first run)" @@ -154,10 +154,10 @@ main() { echo "Next steps:" echo " 1. (Optional) Edit .env to customize configuration" echo " 2. Start services:" - echo " docker-compose -f docker-compose.local.yml up -d" + echo " docker-compose up -d" echo "" echo " 3. View logs:" - echo " docker-compose -f docker-compose.local.yml logs -f sub2api" + echo " docker-compose logs -f sub2api" echo "" echo " 4. Access Web UI:" echo " http://localhost:8080" diff --git a/docs/ADMIN_PAYMENT_INTEGRATION_API.md b/docs/ADMIN_PAYMENT_INTEGRATION_API.md index 4cc21594..f674f86c 100644 --- a/docs/ADMIN_PAYMENT_INTEGRATION_API.md +++ b/docs/ADMIN_PAYMENT_INTEGRATION_API.md @@ -99,16 +99,17 @@ curl -X POST "${BASE}/api/v1/admin/users/123/balance" \ }' ``` -### 4) 购买页 URL Query 透传(iframe / 新窗口一致) -当 Sub2API 打开 `purchase_subscription_url` 时,会统一追加: +### 4) 购买页 / 自定义页面 URL Query 透传(iframe / 新窗口一致) +当 Sub2API 打开 `purchase_subscription_url` 或用户侧自定义页面 iframe URL 时,会统一追加: - `user_id` - `token` - `theme`(`light` / `dark`) +- `lang`(例如 `zh` / `en`,用于向嵌入页传递当前界面语言) - `ui_mode`(固定 `embedded`) 示例: ```text -https://pay.example.com/pay?user_id=123&token=&theme=light&ui_mode=embedded +https://pay.example.com/pay?user_id=123&token=&theme=light&lang=zh&ui_mode=embedded ``` ### 5) 失败处理建议 @@ -218,16 +219,17 @@ curl -X POST "${BASE}/api/v1/admin/users/123/balance" \ }' ``` -### 4) Purchase URL query forwarding (iframe and new tab) -When Sub2API opens `purchase_subscription_url`, it appends: +### 4) Purchase / Custom Page URL query forwarding (iframe and new tab) +When Sub2API opens `purchase_subscription_url` or a user-facing custom page iframe URL, it appends: - `user_id` - `token` - `theme` 
(`light` / `dark`) +- `lang` (for example `zh` / `en`, used to pass the current UI language to the embedded page) - `ui_mode` (fixed: `embedded`) Example: ```text -https://pay.example.com/pay?user_id=123&token=&theme=light&ui_mode=embedded +https://pay.example.com/pay?user_id=123&token=&theme=light&lang=zh&ui_mode=embedded ``` ### 5) Failure handling recommendations diff --git a/frontend/public/wechat-qr.jpg b/frontend/public/wechat-qr.jpg new file mode 100644 index 00000000..659068d8 Binary files /dev/null and b/frontend/public/wechat-qr.jpg differ diff --git a/frontend/src/App.vue b/frontend/src/App.vue index b831c9ff..4fc6a7c8 100644 --- a/frontend/src/App.vue +++ b/frontend/src/App.vue @@ -1,9 +1,10 @@ diff --git a/frontend/src/components/account/BulkEditAccountModal.vue b/frontend/src/components/account/BulkEditAccountModal.vue index 1c83e658..b7639359 100644 --- a/frontend/src/components/account/BulkEditAccountModal.vue +++ b/frontend/src/components/account/BulkEditAccountModal.vue @@ -5,7 +5,7 @@ width="wide" @close="handleClose" > -
+

@@ -469,7 +469,7 @@

-
+
+
+
+ + +
+ +

{{ t('admin.accounts.loadFactorHint') }}

+
+ +
+
+
+ +

+ {{ t('admin.accounts.poolModeHint') }} +

+
+ +
+
+

+ + {{ t('admin.accounts.poolModeInfo') }} +

+
+
+ + +

+ {{ + t('admin.accounts.poolModeRetryCountHint', { + default: DEFAULT_POOL_MODE_RETRY_COUNT, + max: MAX_POOL_MODE_RETRY_COUNT + }) + }} +

+
+
+
@@ -1227,6 +1279,160 @@
+ +
+
+

{{ t('admin.accounts.quotaLimit') }}

+

+ {{ t('admin.accounts.quotaLimitHint') }} +

+
+ +
+ + +
+ + +
+

+ {{ t('admin.accounts.openai.modelRestrictionDisabledByPassthrough') }} +

+
+ + +
+
@@ -1749,10 +1955,24 @@
-
+
- + +
+
+ + +

{{ t('admin.accounts.loadFactorHint') }}

+
+
+ + +

{{ t('admin.accounts.loadFactorHint') }}

@@ -1807,7 +2027,7 @@
- +

- {{ t('admin.accounts.openai.wsModeConcurrencyHint') }} + {{ t(openAIWSModeConcurrencyHintKey) }}

@@ -2337,14 +2557,16 @@ import Icon from '@/components/icons/Icon.vue' import ProxySelector from '@/components/common/ProxySelector.vue' import GroupSelector from '@/components/common/GroupSelector.vue' import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue' +import QuotaLimitCard from '@/components/account/QuotaLimitCard.vue' import { applyInterceptWarmup } from '@/components/account/credentialsBuilder' import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format' import { createStableObjectKeyResolver } from '@/utils/stableObjectKey' import { - OPENAI_WS_MODE_DEDICATED, + // OPENAI_WS_MODE_CTX_POOL, OPENAI_WS_MODE_OFF, - OPENAI_WS_MODE_SHARED, + OPENAI_WS_MODE_PASSTHROUGH, isOpenAIWSModeEnabled, + resolveOpenAIWSModeConcurrencyHintKey, type OpenAIWSMode } from '@/utils/openaiWsMode' import OAuthAuthorizationFlow from './OAuthAuthorizationFlow.vue' @@ -2459,9 +2681,16 @@ const accountCategory = ref<'oauth-based' | 'apikey'>('oauth-based') // UI selec const addMethod = ref('oauth') // For oauth-based: 'oauth' or 'setup-token' const apiKeyBaseUrl = ref('https://api.anthropic.com') const apiKeyValue = ref('') +const editQuotaLimit = ref(null) +const editQuotaDailyLimit = ref(null) +const editQuotaWeeklyLimit = ref(null) const modelMappings = ref([]) const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist') const allowedModels = ref([]) +const DEFAULT_POOL_MODE_RETRY_COUNT = 3 +const MAX_POOL_MODE_RETRY_COUNT = 10 +const poolModeEnabled = ref(false) +const poolModeRetryCount = ref(DEFAULT_POOL_MODE_RETRY_COUNT) const customErrorCodesEnabled = ref(false) const selectedErrorCodes = ref([]) const customErrorCodeInput = ref(null) @@ -2541,8 +2770,9 @@ const geminiSelectedTier = computed(() => { const openAIWSModeOptions = computed(() => [ { value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') }, - { value: OPENAI_WS_MODE_SHARED, label: t('admin.accounts.openai.wsModeShared') }, - { value: 
OPENAI_WS_MODE_DEDICATED, label: t('admin.accounts.openai.wsModeDedicated') } + // TODO: ctx_pool 选项暂时隐藏,待测试完成后恢复 + // { value: OPENAI_WS_MODE_CTX_POOL, label: t('admin.accounts.openai.wsModeCtxPool') }, + { value: OPENAI_WS_MODE_PASSTHROUGH, label: t('admin.accounts.openai.wsModePassthrough') } ]) const openaiResponsesWebSocketV2Mode = computed({ @@ -2561,6 +2791,10 @@ const openaiResponsesWebSocketV2Mode = computed({ } }) +const openAIWSModeConcurrencyHintKey = computed(() => + resolveOpenAIWSModeConcurrencyHintKey(openaiResponsesWebSocketV2Mode.value) +) + const isOpenAIModelRestrictionDisabled = computed(() => form.platform === 'openai' && openaiPassthroughEnabled.value ) @@ -2627,6 +2861,7 @@ const form = reactive({ credentials: {} as Record, proxy_id: null as number | null, concurrency: 10, + load_factor: null as number | null, priority: 1, rate_multiplier: 1, group_ids: [] as number[], @@ -3106,6 +3341,7 @@ const resetForm = () => { form.credentials = {} form.proxy_id = null form.concurrency = 10 + form.load_factor = null form.priority = 1 form.rate_multiplier = 1 form.group_ids = [] @@ -3114,6 +3350,9 @@ const resetForm = () => { addMethod.value = 'oauth' apiKeyBaseUrl.value = 'https://api.anthropic.com' apiKeyValue.value = '' + editQuotaLimit.value = null + editQuotaDailyLimit.value = null + editQuotaWeeklyLimit.value = null modelMappings.value = [] modelRestrictionMode.value = 'whitelist' allowedModels.value = [...claudeModels] // Default fill related models @@ -3123,6 +3362,8 @@ const resetForm = () => { fetchAntigravityDefaultMappings().then(mappings => { antigravityModelMappings.value = [...mappings] }) + poolModeEnabled.value = false + poolModeRetryCount.value = DEFAULT_POOL_MODE_RETRY_COUNT customErrorCodesEnabled.value = false selectedErrorCodes.value = [] customErrorCodeInput.value = null @@ -3180,10 +3421,13 @@ const buildOpenAIExtra = (base?: Record): Record = { ...(base || {}) } - extra.openai_oauth_responses_websockets_v2_mode = 
openaiOAuthResponsesWebSocketV2Mode.value - extra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value - extra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiOAuthResponsesWebSocketV2Mode.value) - extra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiAPIKeyResponsesWebSocketV2Mode.value) + if (accountCategory.value === 'oauth-based') { + extra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value + extra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiOAuthResponsesWebSocketV2Mode.value) + } else if (accountCategory.value === 'apikey') { + extra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value + extra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiAPIKeyResponsesWebSocketV2Mode.value) + } // 清理兼容旧键,统一改用分类型开关。 delete extra.responses_websockets_v2_enabled delete extra.openai_ws_enabled @@ -3272,6 +3516,20 @@ const handleMixedChannelCancel = () => { clearMixedChannelDialog() } +const normalizePoolModeRetryCount = (value: number) => { + if (!Number.isFinite(value)) { + return DEFAULT_POOL_MODE_RETRY_COUNT + } + const normalized = Math.trunc(value) + if (normalized < 0) { + return 0 + } + if (normalized > MAX_POOL_MODE_RETRY_COUNT) { + return MAX_POOL_MODE_RETRY_COUNT + } + return normalized +} + const handleSubmit = async () => { // For OAuth-based type, handle OAuth flow (goes to step 2) if (isOAuthFlow.value) { @@ -3371,6 +3629,12 @@ const handleSubmit = async () => { } } + // Add pool mode if enabled + if (poolModeEnabled.value) { + credentials.pool_mode = true + credentials.pool_mode_retry_count = normalizePoolModeRetryCount(poolModeRetryCount.value) + } + // Add custom error codes if enabled if (customErrorCodesEnabled.value) { credentials.custom_error_codes_enabled = true @@ -3474,6 +3738,7 @@ const handleImportAccessToken = async (accessTokenInput: 
string) => { extra: soraExtra, proxy_id: form.proxy_id, concurrency: form.concurrency, + load_factor: form.load_factor ?? undefined, priority: form.priority, rate_multiplier: form.rate_multiplier, group_ids: form.group_ids, @@ -3524,15 +3789,33 @@ const createAccountAndFinish = async ( if (!applyTempUnschedConfig(credentials)) { return } + // Inject quota limits for apikey accounts + let finalExtra = extra + if (type === 'apikey') { + const quotaExtra: Record = { ...(extra || {}) } + if (editQuotaLimit.value != null && editQuotaLimit.value > 0) { + quotaExtra.quota_limit = editQuotaLimit.value + } + if (editQuotaDailyLimit.value != null && editQuotaDailyLimit.value > 0) { + quotaExtra.quota_daily_limit = editQuotaDailyLimit.value + } + if (editQuotaWeeklyLimit.value != null && editQuotaWeeklyLimit.value > 0) { + quotaExtra.quota_weekly_limit = editQuotaWeeklyLimit.value + } + if (Object.keys(quotaExtra).length > 0) { + finalExtra = quotaExtra + } + } await doCreateAccount({ name: form.name, notes: form.notes, platform, type, credentials, - extra, + extra: finalExtra, proxy_id: form.proxy_id, concurrency: form.concurrency, + load_factor: form.load_factor ?? 
undefined, priority: form.priority, rate_multiplier: form.rate_multiplier, group_ids: form.group_ids, @@ -3571,6 +3854,14 @@ const handleOpenAIExchange = async (authCode: string) => { const shouldCreateOpenAI = form.platform === 'openai' const shouldCreateSora = form.platform === 'sora' + // Add model mapping for OpenAI OAuth accounts(透传模式下不应用) + if (shouldCreateOpenAI && !isOpenAIModelRestrictionDisabled.value) { + const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value) + if (modelMapping) { + credentials.model_mapping = modelMapping + } + } + // 应用临时不可调度配置 if (!applyTempUnschedConfig(credentials)) { return @@ -3588,6 +3879,7 @@ const handleOpenAIExchange = async (authCode: string) => { extra, proxy_id: form.proxy_id, concurrency: form.concurrency, + load_factor: form.load_factor ?? undefined, priority: form.priority, rate_multiplier: form.rate_multiplier, group_ids: form.group_ids, @@ -3617,6 +3909,7 @@ const handleOpenAIExchange = async (authCode: string) => { extra: soraExtra, proxy_id: form.proxy_id, concurrency: form.concurrency, + load_factor: form.load_factor ?? undefined, priority: form.priority, rate_multiplier: form.rate_multiplier, group_ids: form.group_ids, @@ -3679,6 +3972,14 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => { const oauthExtra = oauthClient.buildExtraInfo(tokenInfo) as Record | undefined const extra = buildOpenAIExtra(oauthExtra) + // Add model mapping for OpenAI OAuth accounts(透传模式下不应用) + if (shouldCreateOpenAI && !isOpenAIModelRestrictionDisabled.value) { + const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value) + if (modelMapping) { + credentials.model_mapping = modelMapping + } + } + // Generate account name with index for batch const accountName = refreshTokens.length > 1 ? 
`${form.name} #${i + 1}` : form.name @@ -3694,6 +3995,7 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => { extra, proxy_id: form.proxy_id, concurrency: form.concurrency, + load_factor: form.load_factor ?? undefined, priority: form.priority, rate_multiplier: form.rate_multiplier, group_ids: form.group_ids, @@ -3721,6 +4023,7 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => { extra: soraExtra, proxy_id: form.proxy_id, concurrency: form.concurrency, + load_factor: form.load_factor ?? undefined, priority: form.priority, rate_multiplier: form.rate_multiplier, group_ids: form.group_ids, @@ -3809,6 +4112,7 @@ const handleSoraValidateST = async (sessionTokenInput: string) => { extra: soraExtra, proxy_id: form.proxy_id, concurrency: form.concurrency, + load_factor: form.load_factor ?? undefined, priority: form.priority, rate_multiplier: form.rate_multiplier, group_ids: form.group_ids, @@ -3897,6 +4201,7 @@ const handleAntigravityValidateRT = async (refreshTokenInput: string) => { extra: {}, proxy_id: form.proxy_id, concurrency: form.concurrency, + load_factor: form.load_factor ?? undefined, priority: form.priority, rate_multiplier: form.rate_multiplier, group_ids: form.group_ids, @@ -4055,8 +4360,11 @@ const handleAnthropicExchange = async (authCode: string) => { } // Add RPM limit settings - if (rpmLimitEnabled.value && baseRpm.value != null && baseRpm.value > 0) { - extra.base_rpm = baseRpm.value + if (rpmLimitEnabled.value) { + const DEFAULT_BASE_RPM = 15 + extra.base_rpm = (baseRpm.value != null && baseRpm.value > 0) + ? 
baseRpm.value + : DEFAULT_BASE_RPM extra.rpm_strategy = rpmStrategy.value if (rpmStickyBuffer.value != null && rpmStickyBuffer.value > 0) { extra.rpm_sticky_buffer = rpmStickyBuffer.value @@ -4167,8 +4475,11 @@ const handleCookieAuth = async (sessionKey: string) => { } // Add RPM limit settings - if (rpmLimitEnabled.value && baseRpm.value != null && baseRpm.value > 0) { - extra.base_rpm = baseRpm.value + if (rpmLimitEnabled.value) { + const DEFAULT_BASE_RPM = 15 + extra.base_rpm = (baseRpm.value != null && baseRpm.value > 0) + ? baseRpm.value + : DEFAULT_BASE_RPM extra.rpm_strategy = rpmStrategy.value if (rpmStickyBuffer.value != null && rpmStickyBuffer.value > 0) { extra.rpm_sticky_buffer = rpmStickyBuffer.value @@ -4214,6 +4525,7 @@ const handleCookieAuth = async (sessionKey: string) => { extra, proxy_id: form.proxy_id, concurrency: form.concurrency, + load_factor: form.load_factor ?? undefined, priority: form.priority, rate_multiplier: form.rate_multiplier, group_ids: form.group_ids, diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 24166a5c..14252351 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -251,6 +251,58 @@
+ +
+
+
+ +

+ {{ t('admin.accounts.poolModeHint') }} +

+
+ +
+
+

+ + {{ t('admin.accounts.poolModeInfo') }} +

+
+
+ + +

+ {{ + t('admin.accounts.poolModeRetryCountHint', { + default: DEFAULT_POOL_MODE_RETRY_COUNT, + max: MAX_POOL_MODE_RETRY_COUNT + }) + }} +

+
+
+
@@ -351,6 +403,142 @@
+ +
+ + +
+

+ {{ t('admin.accounts.openai.modelRestrictionDisabledByPassthrough') }} +

+
+ + +
+
@@ -650,10 +838,24 @@
-
+
- + +
+
+ + +

{{ t('admin.accounts.loadFactorHint') }}

+
+
+ + +

{{ t('admin.accounts.loadFactorHint') }}

@@ -708,7 +910,7 @@
- +

- {{ t('admin.accounts.openai.wsModeConcurrencyHint') }} + {{ t(openAIWSModeConcurrencyHintKey) }}

@@ -759,6 +961,24 @@
+ +
+
+

{{ t('admin.accounts.quotaLimit') }}

+

+ {{ t('admin.accounts.quotaLimitHint') }} +

+
+ +
+
([]) const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist') const allowedModels = ref([]) +const DEFAULT_POOL_MODE_RETRY_COUNT = 3 +const MAX_POOL_MODE_RETRY_COUNT = 10 +const poolModeEnabled = ref(false) +const poolModeRetryCount = ref(DEFAULT_POOL_MODE_RETRY_COUNT) const customErrorCodesEnabled = ref(false) const selectedErrorCodes = ref([]) const customErrorCodeInput = ref(null) @@ -1385,10 +1611,14 @@ const openaiOAuthResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF const openaiAPIKeyResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF) const codexCLIOnlyEnabled = ref(false) const anthropicPassthroughEnabled = ref(false) +const editQuotaLimit = ref(null) +const editQuotaDailyLimit = ref(null) +const editQuotaWeeklyLimit = ref(null) const openAIWSModeOptions = computed(() => [ { value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') }, - { value: OPENAI_WS_MODE_SHARED, label: t('admin.accounts.openai.wsModeShared') }, - { value: OPENAI_WS_MODE_DEDICATED, label: t('admin.accounts.openai.wsModeDedicated') } + // TODO: ctx_pool 选项暂时隐藏,待测试完成后恢复 + // { value: OPENAI_WS_MODE_CTX_POOL, label: t('admin.accounts.openai.wsModeCtxPool') }, + { value: OPENAI_WS_MODE_PASSTHROUGH, label: t('admin.accounts.openai.wsModePassthrough') } ]) const openaiResponsesWebSocketV2Mode = computed({ get: () => { @@ -1405,6 +1635,9 @@ const openaiResponsesWebSocketV2Mode = computed({ openaiOAuthResponsesWebSocketV2Mode.value = mode } }) +const openAIWSModeConcurrencyHintKey = computed(() => + resolveOpenAIWSModeConcurrencyHintKey(openaiResponsesWebSocketV2Mode.value) +) const isOpenAIModelRestrictionDisabled = computed(() => props.account?.platform === 'openai' && openaiPassthroughEnabled.value ) @@ -1460,17 +1693,24 @@ const form = reactive({ notes: '', proxy_id: null as number | null, concurrency: 1, + load_factor: null as number | null, priority: 1, rate_multiplier: 1, - status: 'active' as 'active' | 'inactive', + status: 'active' as 'active' | 'inactive' | 
'error', group_ids: [] as number[], expires_at: null as number | null }) -const statusOptions = computed(() => [ - { value: 'active', label: t('common.active') }, - { value: 'inactive', label: t('common.inactive') } -]) +const statusOptions = computed(() => { + const options = [ + { value: 'active', label: t('common.active') }, + { value: 'inactive', label: t('common.inactive') } + ] + if (form.status === 'error') { + options.push({ value: 'error', label: t('admin.accounts.status.error') }) + } + return options +}) const expiresAtInput = computed({ get: () => formatDateTimeLocal(form.expires_at), @@ -1480,6 +1720,20 @@ const expiresAtInput = computed({ }) // Watchers +const normalizePoolModeRetryCount = (value: number) => { + if (!Number.isFinite(value)) { + return DEFAULT_POOL_MODE_RETRY_COUNT + } + const normalized = Math.trunc(value) + if (normalized < 0) { + return 0 + } + if (normalized > MAX_POOL_MODE_RETRY_COUNT) { + return MAX_POOL_MODE_RETRY_COUNT + } + return normalized +} + watch( () => props.account, (newAccount) => { @@ -1493,9 +1747,12 @@ watch( form.notes = newAccount.notes || '' form.proxy_id = newAccount.proxy_id form.concurrency = newAccount.concurrency + form.load_factor = newAccount.load_factor ?? null form.priority = newAccount.priority form.rate_multiplier = newAccount.rate_multiplier ?? 1 - form.status = newAccount.status as 'active' | 'inactive' + form.status = (newAccount.status === 'active' || newAccount.status === 'inactive' || newAccount.status === 'error') + ? newAccount.status + : 'active' form.group_ids = newAccount.group_ids || [] form.expires_at = newAccount.expires_at ?? null @@ -1536,6 +1793,20 @@ watch( anthropicPassthroughEnabled.value = extra?.anthropic_passthrough === true } + // Load quota limit for apikey accounts + if (newAccount.type === 'apikey') { + const quotaVal = extra?.quota_limit as number | undefined + editQuotaLimit.value = (quotaVal && quotaVal > 0) ? 
quotaVal : null + const dailyVal = extra?.quota_daily_limit as number | undefined + editQuotaDailyLimit.value = (dailyVal && dailyVal > 0) ? dailyVal : null + const weeklyVal = extra?.quota_weekly_limit as number | undefined + editQuotaWeeklyLimit.value = (weeklyVal && weeklyVal > 0) ? weeklyVal : null + } else { + editQuotaLimit.value = null + editQuotaDailyLimit.value = null + editQuotaWeeklyLimit.value = null + } + // Load antigravity model mapping (Antigravity 只支持映射模式) if (newAccount.platform === 'antigravity') { const credentials = newAccount.credentials as Record | undefined @@ -1610,6 +1881,12 @@ watch( allowedModels.value = [] } + // Load pool mode + poolModeEnabled.value = credentials.pool_mode === true + poolModeRetryCount.value = normalizePoolModeRetryCount( + Number(credentials.pool_mode_retry_count ?? DEFAULT_POOL_MODE_RETRY_COUNT) + ) + // Load custom error codes customErrorCodesEnabled.value = credentials.custom_error_codes_enabled === true const existingErrorCodes = credentials.custom_error_codes as number[] | undefined @@ -1629,9 +1906,35 @@ watch( ? 
'https://generativelanguage.googleapis.com' : 'https://api.anthropic.com' editBaseUrl.value = platformDefaultUrl - modelRestrictionMode.value = 'whitelist' - modelMappings.value = [] - allowedModels.value = [] + + // Load model mappings for OpenAI OAuth accounts + if (newAccount.platform === 'openai' && newAccount.credentials) { + const oauthCredentials = newAccount.credentials as Record + const existingMappings = oauthCredentials.model_mapping as Record | undefined + if (existingMappings && typeof existingMappings === 'object') { + const entries = Object.entries(existingMappings) + const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to) + if (isWhitelistMode) { + modelRestrictionMode.value = 'whitelist' + allowedModels.value = entries.map(([from]) => from) + modelMappings.value = [] + } else { + modelRestrictionMode.value = 'mapping' + modelMappings.value = entries.map(([from, to]) => ({ from, to })) + allowedModels.value = [] + } + } else { + modelRestrictionMode.value = 'whitelist' + modelMappings.value = [] + allowedModels.value = [] + } + } else { + modelRestrictionMode.value = 'whitelist' + modelMappings.value = [] + allowedModels.value = [] + } + poolModeEnabled.value = false + poolModeRetryCount.value = DEFAULT_POOL_MODE_RETRY_COUNT customErrorCodesEnabled.value = false selectedErrorCodes.value = [] } @@ -2035,6 +2338,11 @@ const handleSubmit = async () => { if (!props.account) return const accountID = props.account.id + if (form.status !== 'active' && form.status !== 'inactive' && form.status !== 'error') { + appStore.showError(t('admin.accounts.pleaseSelectStatus')) + return + } + const updatePayload: Record = { ...form } try { // 后端期望 proxy_id: 0 表示清除代理,而不是 null @@ -2044,6 +2352,11 @@ const handleSubmit = async () => { if (form.expires_at === null) { updatePayload.expires_at = 0 } + // load_factor: 空值/NaN/0/负数 时发送 0(后端约定 <= 0 = 清除) + const lf = form.load_factor + if (lf == null || Number.isNaN(lf) || lf <= 0) { + 
updatePayload.load_factor = 0 + } updatePayload.auto_pause_on_expired = autoPauseOnExpired.value // For apikey type, handle credentials update @@ -2054,6 +2367,7 @@ const handleSubmit = async () => { // Always update credentials for apikey type to handle model mapping changes const newCredentials: Record = { + ...currentCredentials, base_url: newBaseUrl } @@ -2074,15 +2388,29 @@ const handleSubmit = async () => { const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value) if (modelMapping) { newCredentials.model_mapping = modelMapping + } else { + delete newCredentials.model_mapping } } else if (currentCredentials.model_mapping) { newCredentials.model_mapping = currentCredentials.model_mapping } + // Add pool mode if enabled + if (poolModeEnabled.value) { + newCredentials.pool_mode = true + newCredentials.pool_mode_retry_count = normalizePoolModeRetryCount(poolModeRetryCount.value) + } else { + delete newCredentials.pool_mode + delete newCredentials.pool_mode_retry_count + } + // Add custom error codes if enabled if (customErrorCodesEnabled.value) { newCredentials.custom_error_codes_enabled = true newCredentials.custom_error_codes = [...selectedErrorCodes.value] + } else { + delete newCredentials.custom_error_codes_enabled + delete newCredentials.custom_error_codes } // Add intercept warmup requests setting @@ -2123,6 +2451,28 @@ const handleSubmit = async () => { updatePayload.credentials = newCredentials } + // OpenAI OAuth: persist model mapping to credentials + if (props.account.platform === 'openai' && props.account.type === 'oauth') { + const currentCredentials = (updatePayload.credentials as Record) || + ((props.account.credentials as Record) || {}) + const newCredentials: Record = { ...currentCredentials } + const shouldApplyModelMapping = !openaiPassthroughEnabled.value + + if (shouldApplyModelMapping) { + const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, 
modelMappings.value) + if (modelMapping) { + newCredentials.model_mapping = modelMapping + } else { + delete newCredentials.model_mapping + } + } else if (currentCredentials.model_mapping) { + // 透传模式保留现有映射 + newCredentials.model_mapping = currentCredentials.model_mapping + } + + updatePayload.credentials = newCredentials + } + // Antigravity: persist model mapping to credentials (applies to all antigravity types) // Antigravity 只支持映射模式 if (props.account.platform === 'antigravity') { @@ -2183,8 +2533,11 @@ const handleSubmit = async () => { } // RPM limit settings - if (rpmLimitEnabled.value && baseRpm.value != null && baseRpm.value > 0) { - newExtra.base_rpm = baseRpm.value + if (rpmLimitEnabled.value) { + const DEFAULT_BASE_RPM = 15 + newExtra.base_rpm = (baseRpm.value != null && baseRpm.value > 0) + ? baseRpm.value + : DEFAULT_BASE_RPM newExtra.rpm_strategy = rpmStrategy.value if (rpmStickyBuffer.value != null && rpmStickyBuffer.value > 0) { newExtra.rpm_sticky_buffer = rpmStickyBuffer.value @@ -2248,10 +2601,13 @@ const handleSubmit = async () => { const currentExtra = (props.account.extra as Record) || {} const newExtra: Record = { ...currentExtra } const hadCodexCLIOnlyEnabled = currentExtra.codex_cli_only === true - newExtra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value - newExtra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value - newExtra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiOAuthResponsesWebSocketV2Mode.value) - newExtra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiAPIKeyResponsesWebSocketV2Mode.value) + if (props.account.type === 'oauth') { + newExtra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value + newExtra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiOAuthResponsesWebSocketV2Mode.value) + } else if (props.account.type === 'apikey') { + 
newExtra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value + newExtra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiAPIKeyResponsesWebSocketV2Mode.value) + } delete newExtra.responses_websockets_v2_enabled delete newExtra.openai_ws_enabled if (openaiPassthroughEnabled.value) { @@ -2275,6 +2631,29 @@ const handleSubmit = async () => { updatePayload.extra = newExtra } + // For apikey accounts, handle quota_limit in extra + if (props.account.type === 'apikey') { + const currentExtra = (updatePayload.extra as Record) || + (props.account.extra as Record) || {} + const newExtra: Record = { ...currentExtra } + if (editQuotaLimit.value != null && editQuotaLimit.value > 0) { + newExtra.quota_limit = editQuotaLimit.value + } else { + delete newExtra.quota_limit + } + if (editQuotaDailyLimit.value != null && editQuotaDailyLimit.value > 0) { + newExtra.quota_daily_limit = editQuotaDailyLimit.value + } else { + delete newExtra.quota_daily_limit + } + if (editQuotaWeeklyLimit.value != null && editQuotaWeeklyLimit.value > 0) { + newExtra.quota_weekly_limit = editQuotaWeeklyLimit.value + } else { + delete newExtra.quota_weekly_limit + } + updatePayload.extra = newExtra + } + const canContinue = await ensureAntigravityMixedChannelConfirmed(async () => { await submitUpdateAccount(accountID, updatePayload) }) diff --git a/frontend/src/components/account/QuotaBadge.vue b/frontend/src/components/account/QuotaBadge.vue new file mode 100644 index 00000000..7cf0f59d --- /dev/null +++ b/frontend/src/components/account/QuotaBadge.vue @@ -0,0 +1,49 @@ + + + diff --git a/frontend/src/components/account/QuotaLimitCard.vue b/frontend/src/components/account/QuotaLimitCard.vue new file mode 100644 index 00000000..505118ba --- /dev/null +++ b/frontend/src/components/account/QuotaLimitCard.vue @@ -0,0 +1,137 @@ + + + diff --git a/frontend/src/components/account/TempUnschedStatusModal.vue 
b/frontend/src/components/account/TempUnschedStatusModal.vue index b2c0b71b..a3e64c48 100644 --- a/frontend/src/components/account/TempUnschedStatusModal.vue +++ b/frontend/src/components/account/TempUnschedStatusModal.vue @@ -29,6 +29,10 @@
+
+ {{ t('admin.accounts.recoverStateHint') }} +
+

{{ t('admin.accounts.tempUnschedulable.accountName') }} @@ -131,7 +135,7 @@ d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z" > - {{ t('admin.accounts.tempUnschedulable.reset') }} + {{ t('admin.accounts.recoverState') }}

@@ -154,7 +158,7 @@ const props = defineProps<{ const emit = defineEmits<{ close: [] - reset: [] + reset: [account: Account] }>() const { t } = useI18n() @@ -225,12 +229,12 @@ const handleReset = async () => { if (!props.account) return resetting.value = true try { - await adminAPI.accounts.resetTempUnschedulable(props.account.id) - appStore.showSuccess(t('admin.accounts.tempUnschedulable.resetSuccess')) - emit('reset') + const updated = await adminAPI.accounts.recoverState(props.account.id) + appStore.showSuccess(t('admin.accounts.recoverStateSuccess')) + emit('reset', updated) handleClose() } catch (error: any) { - appStore.showError(error?.message || t('admin.accounts.tempUnschedulable.resetFailed')) + appStore.showError(error?.message || t('admin.accounts.recoverStateFailed')) } finally { resetting.value = false } diff --git a/frontend/src/components/account/UsageProgressBar.vue b/frontend/src/components/account/UsageProgressBar.vue index 93844295..ea6c71f3 100644 --- a/frontend/src/components/account/UsageProgressBar.vue +++ b/frontend/src/components/account/UsageProgressBar.vue @@ -1,30 +1,5 @@
@@ -51,7 +55,7 @@ import { Icon } from '@/components/icons' import type { Account } from '@/types' const props = defineProps<{ show: boolean; account: Account | null; position: { top: number; left: number } | null }>() -const emit = defineEmits(['close', 'test', 'stats', 'reauth', 'refresh-token', 'reset-status', 'clear-rate-limit']) +const emit = defineEmits(['close', 'test', 'stats', 'schedule', 'reauth', 'refresh-token', 'recover-state', 'reset-status', 'clear-rate-limit', 'reset-quota']) const { t } = useI18n() const isRateLimited = computed(() => { if (props.account?.rate_limit_reset_at && new Date(props.account.rate_limit_reset_at) > new Date()) { @@ -67,6 +71,17 @@ const isRateLimited = computed(() => { return false }) const isOverloaded = computed(() => props.account?.overload_until && new Date(props.account.overload_until) > new Date()) +const isTempUnschedulable = computed(() => props.account?.temp_unschedulable_until && new Date(props.account.temp_unschedulable_until) > new Date()) +const hasRecoverableState = computed(() => { + return props.account?.status === 'error' || Boolean(isRateLimited.value) || Boolean(isOverloaded.value) || Boolean(isTempUnschedulable.value) +}) +const hasQuotaLimit = computed(() => { + return props.account?.type === 'apikey' && ( + (props.account?.quota_limit ?? 0) > 0 || + (props.account?.quota_daily_limit ?? 0) > 0 || + (props.account?.quota_weekly_limit ?? 0) > 0 + ) +}) const handleKeydown = (event: KeyboardEvent) => { if (event.key === 'Escape') emit('close') diff --git a/frontend/src/components/admin/account/AccountBulkActionsBar.vue b/frontend/src/components/admin/account/AccountBulkActionsBar.vue index 41111484..3b987bd0 100644 --- a/frontend/src/components/admin/account/AccountBulkActionsBar.vue +++ b/frontend/src/components/admin/account/AccountBulkActionsBar.vue @@ -20,6 +20,8 @@
+ + @@ -29,5 +31,5 @@ \ No newline at end of file +defineProps(['selectedIds']); defineEmits(['delete', 'edit', 'clear', 'select-page', 'toggle-schedulable', 'reset-status', 'refresh-token']); const { t } = useI18n() + diff --git a/frontend/src/components/admin/account/AccountTableFilters.vue b/frontend/src/components/admin/account/AccountTableFilters.vue index 5280e787..abffbaa2 100644 --- a/frontend/src/components/admin/account/AccountTableFilters.vue +++ b/frontend/src/components/admin/account/AccountTableFilters.vue @@ -25,6 +25,6 @@ const updateStatus = (value: string | number | boolean | null) => { emit('update const updateGroup = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, group: value }) } const pOpts = computed(() => [{ value: '', label: t('admin.accounts.allPlatforms') }, { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, { value: 'antigravity', label: 'Antigravity' }, { value: 'sora', label: 'Sora' }]) const tOpts = computed(() => [{ value: '', label: t('admin.accounts.allTypes') }, { value: 'oauth', label: t('admin.accounts.oauthType') }, { value: 'setup-token', label: t('admin.accounts.setupToken') }, { value: 'apikey', label: t('admin.accounts.apiKey') }]) -const sOpts = computed(() => [{ value: '', label: t('admin.accounts.allStatus') }, { value: 'active', label: t('admin.accounts.status.active') }, { value: 'inactive', label: t('admin.accounts.status.inactive') }, { value: 'error', label: t('admin.accounts.status.error') }, { value: 'rate_limited', label: t('admin.accounts.status.rateLimited') }]) +const sOpts = computed(() => [{ value: '', label: t('admin.accounts.allStatus') }, { value: 'active', label: t('admin.accounts.status.active') }, { value: 'inactive', label: t('admin.accounts.status.inactive') }, { value: 'error', label: t('admin.accounts.status.error') }, { value: 'rate_limited', label: t('admin.accounts.status.rateLimited') 
}, { value: 'temp_unschedulable', label: t('admin.accounts.status.tempUnschedulable') }]) const gOpts = computed(() => [{ value: '', label: t('admin.accounts.allGroups') }, ...(props.groups || []).map(g => ({ value: String(g.id), label: g.name }))]) diff --git a/frontend/src/components/admin/account/AccountTestModal.vue b/frontend/src/components/admin/account/AccountTestModal.vue index a25c25cc..e731a7b1 100644 --- a/frontend/src/components/admin/account/AccountTestModal.vue +++ b/frontend/src/components/admin/account/AccountTestModal.vue @@ -61,6 +61,17 @@ {{ t('admin.accounts.soraTestHint') }}
+
+