mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-04 07:22:13 +08:00
Compare commits
100 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a817cafe3d | ||
|
|
2857fa2ef7 | ||
|
|
e681431454 | ||
|
|
5b568aa9d4 | ||
|
|
471943269c | ||
|
|
28a5e2f0e6 | ||
|
|
b4c22ce6ce | ||
|
|
5248097f90 | ||
|
|
8e2c22d0bd | ||
|
|
be56a282f2 | ||
|
|
5f4eb9f9d0 | ||
|
|
d1cd5c0a73 | ||
|
|
5429c74c10 | ||
|
|
fe1d46a8ea | ||
|
|
c7b42148a5 | ||
|
|
bc1abb6a23 | ||
|
|
d307d48def | ||
|
|
1bb40084fc | ||
|
|
8f0efa16ca | ||
|
|
ef2c35dbb1 | ||
|
|
04a1a7c2b5 | ||
|
|
d21d70a5cf | ||
|
|
e73b778d2b | ||
|
|
723102766b | ||
|
|
a4a46a8618 | ||
|
|
6ae82e04d5 | ||
|
|
19cca11e00 | ||
|
|
c8f87a9c92 | ||
|
|
ae6fed15cc | ||
|
|
378e476e48 | ||
|
|
2a1067c82b | ||
|
|
a54b81cf74 | ||
|
|
2d4236f76e | ||
|
|
84ced1c497 | ||
|
|
b161312183 | ||
|
|
1f647b120a | ||
|
|
7d0a30fa8f | ||
|
|
d95e04fd1f | ||
|
|
5dd83d3cf2 | ||
|
|
14e1aac9b5 | ||
|
|
6114f69cca | ||
|
|
d6c2921f2b | ||
|
|
61c73287dc | ||
|
|
89905ec43d | ||
|
|
aa4b102108 | ||
|
|
e4bc35151f | ||
|
|
56da498b7e | ||
|
|
1bba1a62b1 | ||
|
|
4a84ca9a02 | ||
|
|
a70d37a676 | ||
|
|
6892e84ad2 | ||
|
|
73f455745c | ||
|
|
021abfca18 | ||
|
|
7d66f7ff0d | ||
|
|
470b37be7e | ||
|
|
f6cfab9901 | ||
|
|
51572b5da0 | ||
|
|
91ca28b7e3 | ||
|
|
04cedce9a1 | ||
|
|
5e0d789440 | ||
|
|
149e4267cd | ||
|
|
9a479d1b55 | ||
|
|
fc095bf054 | ||
|
|
1af06aed96 | ||
|
|
9236936a55 | ||
|
|
125152460f | ||
|
|
6d90fb0bc3 | ||
|
|
b889d5017b | ||
|
|
72b08f9cc5 | ||
|
|
681950dadd | ||
|
|
a67d9337b8 | ||
|
|
2f1182e8a9 | ||
|
|
cbb4d854ab | ||
|
|
35598d5648 | ||
|
|
5c76b9e45a | ||
|
|
0b8fea4cb4 | ||
|
|
5fa93ebdc7 | ||
|
|
8aa0aed566 | ||
|
|
2eb32a0ed7 | ||
|
|
bac9e2bfd5 | ||
|
|
e4d74ae11d | ||
|
|
8a0a8558cf | ||
|
|
2185a3b674 | ||
|
|
9e3c306a5b | ||
|
|
b1c30df8e3 | ||
|
|
69816f8691 | ||
|
|
b4ec65785d | ||
|
|
3c93644146 | ||
|
|
fb58560d15 | ||
|
|
6ab77f5eb5 | ||
|
|
4f57d7f761 | ||
|
|
1563bd3dda | ||
|
|
df3346387f | ||
|
|
77b66653ed | ||
|
|
3077fd279d | ||
|
|
e3748da860 | ||
|
|
36e6fb5fc8 | ||
|
|
86b503f87f | ||
|
|
50a783ff01 | ||
|
|
e1a68497d6 |
2
.github/workflows/security-scan.yml
vendored
2
.github/workflows/security-scan.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
||||
working-directory: backend
|
||||
run: |
|
||||
go install github.com/securego/gosec/v2/cmd/gosec@latest
|
||||
gosec -severity high -confidence high ./...
|
||||
gosec -conf .gosec.json -severity high -confidence high ./...
|
||||
|
||||
frontend-security:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
323
DEV_GUIDE.md
Normal file
323
DEV_GUIDE.md
Normal file
@@ -0,0 +1,323 @@
|
||||
# sub2api 项目开发指南
|
||||
|
||||
> 本文档记录项目环境配置、常见坑点和注意事项,供 Claude Code 和团队成员参考。
|
||||
|
||||
## 一、项目基本信息
|
||||
|
||||
| 项目 | 说明 |
|
||||
|------|------|
|
||||
| **上游仓库** | Wei-Shaw/sub2api |
|
||||
| **Fork 仓库** | bayma888/sub2api-bmai |
|
||||
| **技术栈** | Go 后端 (Ent ORM + Gin) + Vue3 前端 (pnpm) |
|
||||
| **数据库** | PostgreSQL 16 + Redis |
|
||||
| **包管理** | 后端: go modules, 前端: **pnpm**(不是 npm) |
|
||||
|
||||
## 二、本地环境配置
|
||||
|
||||
### PostgreSQL 16 (Windows 服务)
|
||||
|
||||
| 配置项 | 值 |
|
||||
|--------|-----|
|
||||
| 端口 | 5432 |
|
||||
| psql 路径 | `C:\Program Files\PostgreSQL\16\bin\psql.exe` |
|
||||
| pg_hba.conf | `C:\Program Files\PostgreSQL\16\data\pg_hba.conf` |
|
||||
| 数据库凭据 | user=`sub2api`, password=`sub2api`, dbname=`sub2api` |
|
||||
| 超级用户 | user=`postgres`, password=`postgres` |
|
||||
|
||||
### Redis
|
||||
|
||||
| 配置项 | 值 |
|
||||
|--------|-----|
|
||||
| 端口 | 6379 |
|
||||
| 密码 | 无 |
|
||||
|
||||
### 开发工具
|
||||
|
||||
```bash
|
||||
# golangci-lint v2.7
|
||||
go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@v2.7
|
||||
|
||||
# pnpm (前端包管理)
|
||||
npm install -g pnpm
|
||||
```
|
||||
|
||||
## 三、CI/CD 流水线
|
||||
|
||||
### GitHub Actions Workflows
|
||||
|
||||
| Workflow | 触发条件 | 检查内容 |
|
||||
|----------|----------|----------|
|
||||
| **backend-ci.yml** | push, pull_request | 单元测试 + 集成测试 + golangci-lint v2.7 |
|
||||
| **security-scan.yml** | push, pull_request, 每周一 | govulncheck + gosec + pnpm audit |
|
||||
| **release.yml** | tag `v*` | 构建发布(PR 不触发) |
|
||||
|
||||
### CI 要求
|
||||
|
||||
- Go 版本必须是 **1.25.7**
|
||||
- 前端使用 `pnpm install --frozen-lockfile`,必须提交 `pnpm-lock.yaml`
|
||||
|
||||
### 本地测试命令
|
||||
|
||||
```bash
|
||||
# 后端单元测试
|
||||
cd backend && go test -tags=unit ./...
|
||||
|
||||
# 后端集成测试
|
||||
cd backend && go test -tags=integration ./...
|
||||
|
||||
# 代码质量检查
|
||||
cd backend && golangci-lint run ./...
|
||||
|
||||
# 前端依赖安装(必须用 pnpm)
|
||||
cd frontend && pnpm install
|
||||
```
|
||||
|
||||
## 四、常见坑点 & 解决方案
|
||||
|
||||
### 坑 1:pnpm-lock.yaml 必须同步提交
|
||||
|
||||
**问题**:`package.json` 新增依赖后,CI 的 `pnpm install --frozen-lockfile` 失败。
|
||||
|
||||
**原因**:上游 CI 使用 pnpm,lock 文件不同步会报错。
|
||||
|
||||
**解决**:
|
||||
```bash
|
||||
cd frontend
|
||||
pnpm install # 更新 pnpm-lock.yaml
|
||||
git add pnpm-lock.yaml
|
||||
git commit -m "chore: update pnpm-lock.yaml"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 坑 2:npm 和 pnpm 的 node_modules 冲突
|
||||
|
||||
**问题**:之前用 npm 装过 `node_modules`,pnpm install 报 `EPERM` 错误。
|
||||
|
||||
**解决**:
|
||||
```bash
|
||||
cd frontend
|
||||
rm -rf node_modules # 或 PowerShell: Remove-Item -Recurse -Force node_modules
|
||||
pnpm install
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 坑 3:PowerShell 中 bcrypt hash 的 `$` 被转义
|
||||
|
||||
**问题**:bcrypt hash 格式如 `$2a$10$xxx...`,PowerShell 把 `$2a` 当变量解析,导致数据丢失。
|
||||
|
||||
**解决**:将 SQL 写入文件,用 `psql -f` 执行:
|
||||
```bash
|
||||
# 错误示范(PowerShell 会吃掉 $)
|
||||
psql -c "INSERT INTO users ... VALUES ('$2a$10$...')"
|
||||
|
||||
# 正确做法
|
||||
echo "INSERT INTO users ... VALUES ('\$2a\$10\$...')" > temp.sql
|
||||
psql -U sub2api -h 127.0.0.1 -d sub2api -f temp.sql
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 坑 4:psql 不支持中文路径
|
||||
|
||||
**问题**:`psql -f "D:\中文路径\file.sql"` 报错找不到文件。
|
||||
|
||||
**解决**:复制到纯英文路径再执行:
|
||||
```bash
|
||||
cp "D:\中文路径\file.sql" "C:\temp.sql"
|
||||
psql -f "C:\temp.sql"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 坑 5:PostgreSQL 密码重置流程
|
||||
|
||||
**场景**:忘记 PostgreSQL 密码。
|
||||
|
||||
**步骤**:
|
||||
1. 修改 `C:\Program Files\PostgreSQL\16\data\pg_hba.conf`
|
||||
```
|
||||
# 将 scram-sha-256 改为 trust
|
||||
host all all 127.0.0.1/32 trust
|
||||
```
|
||||
2. 重启 PostgreSQL 服务
|
||||
```powershell
|
||||
Restart-Service postgresql-x64-16
|
||||
```
|
||||
3. 无密码登录并重置
|
||||
```bash
|
||||
psql -U postgres -h 127.0.0.1
|
||||
ALTER USER sub2api WITH PASSWORD 'sub2api';
|
||||
ALTER USER postgres WITH PASSWORD 'postgres';
|
||||
```
|
||||
4. 改回 `scram-sha-256` 并重启
|
||||
|
||||
---
|
||||
|
||||
### 坑 6:Go interface 新增方法后 test stub 必须补全
|
||||
|
||||
**问题**:给 interface 新增方法后,编译报错 `does not implement interface (missing method XXX)`。
|
||||
|
||||
**原因**:所有测试文件中实现该 interface 的 stub/mock 都必须补上新方法。
|
||||
|
||||
**解决**:
|
||||
```bash
|
||||
# 搜索所有实现该 interface 的 struct
|
||||
cd backend
|
||||
grep -r "type.*Stub.*struct" internal/
|
||||
grep -r "type.*Mock.*struct" internal/
|
||||
|
||||
# 逐一补全新方法
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 坑 7:Windows 上 psql 连 localhost 的 IPv6 问题
|
||||
|
||||
**问题**:psql 连 `localhost` 先尝试 IPv6 (::1),可能报错后再回退 IPv4。
|
||||
|
||||
**建议**:直接用 `127.0.0.1` 代替 `localhost`。
|
||||
|
||||
---
|
||||
|
||||
### 坑 8:Windows 没有 make 命令
|
||||
|
||||
**问题**:CI 里用 `make test-unit`,本地 Windows 没有 make。
|
||||
|
||||
**解决**:直接用 Makefile 里的原始命令:
|
||||
```bash
|
||||
# 代替 make test-unit
|
||||
go test -tags=unit ./...
|
||||
|
||||
# 代替 make test-integration
|
||||
go test -tags=integration ./...
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 坑 9:Ent Schema 修改后必须重新生成
|
||||
|
||||
**问题**:修改 `ent/schema/*.go` 后,代码不生效。
|
||||
|
||||
**解决**:
|
||||
```bash
|
||||
cd backend
|
||||
go generate ./ent # 重新生成 ent 代码
|
||||
git add ent/ # 生成的文件也要提交
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 坑 10:PR 提交前检查清单
|
||||
|
||||
提交 PR 前务必本地验证:
|
||||
|
||||
- [ ] `go test -tags=unit ./...` 通过
|
||||
- [ ] `go test -tags=integration ./...` 通过
|
||||
- [ ] `golangci-lint run ./...` 无新增问题
|
||||
- [ ] `pnpm-lock.yaml` 已同步(如果改了 package.json)
|
||||
- [ ] 所有 test stub 补全新接口方法(如果改了 interface)
|
||||
- [ ] Ent 生成的代码已提交(如果改了 schema)
|
||||
|
||||
## 五、常用命令速查
|
||||
|
||||
### 数据库操作
|
||||
|
||||
```bash
|
||||
# 连接数据库
|
||||
psql -U sub2api -h 127.0.0.1 -d sub2api
|
||||
|
||||
# 查看所有用户
|
||||
psql -U postgres -h 127.0.0.1 -c "\du"
|
||||
|
||||
# 查看所有数据库
|
||||
psql -U postgres -h 127.0.0.1 -c "\l"
|
||||
|
||||
# 执行 SQL 文件
|
||||
psql -U sub2api -h 127.0.0.1 -d sub2api -f migration.sql
|
||||
```
|
||||
|
||||
### Git 操作
|
||||
|
||||
```bash
|
||||
# 同步上游
|
||||
git fetch upstream
|
||||
git checkout main
|
||||
git merge upstream/main
|
||||
git push origin main
|
||||
|
||||
# 创建功能分支
|
||||
git checkout -b feature/xxx
|
||||
|
||||
# Rebase 到最新 main
|
||||
git fetch upstream
|
||||
git rebase upstream/main
|
||||
```
|
||||
|
||||
### 前端操作
|
||||
|
||||
```bash
|
||||
# 安装依赖(必须用 pnpm)
|
||||
cd frontend
|
||||
pnpm install
|
||||
|
||||
# 开发服务器
|
||||
pnpm dev
|
||||
|
||||
# 构建
|
||||
pnpm build
|
||||
```
|
||||
|
||||
### 后端操作
|
||||
|
||||
```bash
|
||||
# 运行服务器
|
||||
cd backend
|
||||
go run ./cmd/server/
|
||||
|
||||
# 生成 Ent 代码
|
||||
go generate ./ent
|
||||
|
||||
# 运行测试
|
||||
go test -tags=unit ./...
|
||||
go test -tags=integration ./...
|
||||
|
||||
# Lint 检查
|
||||
golangci-lint run ./...
|
||||
```
|
||||
|
||||
## 六、项目结构速览
|
||||
|
||||
```
|
||||
sub2api-bmai/
|
||||
├── backend/
|
||||
│ ├── cmd/server/ # 主程序入口
|
||||
│ ├── ent/ # Ent ORM 生成代码
|
||||
│ │ └── schema/ # 数据库 Schema 定义
|
||||
│ ├── internal/
|
||||
│ │ ├── handler/ # HTTP 处理器
|
||||
│ │ ├── service/ # 业务逻辑
|
||||
│ │ ├── repository/ # 数据访问层
|
||||
│ │ └── server/ # 服务器配置
|
||||
│ ├── migrations/ # 数据库迁移脚本
|
||||
│ └── config.yaml # 配置文件
|
||||
├── frontend/
|
||||
│ ├── src/
|
||||
│ │ ├── api/ # API 调用
|
||||
│ │ ├── components/ # Vue 组件
|
||||
│ │ ├── views/ # 页面视图
|
||||
│ │ ├── types/ # TypeScript 类型
|
||||
│ │ └── i18n/ # 国际化
|
||||
│ ├── package.json # 依赖配置
|
||||
│ └── pnpm-lock.yaml # pnpm 锁文件(必须提交)
|
||||
└── .claude/
|
||||
└── CLAUDE.md # 本文档
|
||||
```
|
||||
|
||||
## 七、参考资源
|
||||
|
||||
- [上游仓库](https://github.com/Wei-Shaw/sub2api)
|
||||
- [Ent 文档](https://entgo.io/docs/getting-started)
|
||||
- [Vue3 文档](https://vuejs.org/)
|
||||
- [pnpm 文档](https://pnpm.io/)
|
||||
5
backend/.gosec.json
Normal file
5
backend/.gosec.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"global": {
|
||||
"exclude": "G704"
|
||||
}
|
||||
}
|
||||
@@ -1 +1 @@
|
||||
0.1.70
|
||||
0.1.76
|
||||
@@ -102,7 +102,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator)
|
||||
adminUserHandler := admin.NewUserHandler(adminService)
|
||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
||||
groupHandler := admin.NewGroupHandler(adminService)
|
||||
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
||||
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
|
||||
@@ -126,13 +128,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache)
|
||||
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
|
||||
gatewayCache := repository.NewGatewayCache(redisClient)
|
||||
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
|
||||
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
||||
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
|
||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
|
||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
||||
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
|
||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, compositeTokenCacheInvalidator)
|
||||
@@ -154,7 +154,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
identityService := service.NewIdentityService(identityCache)
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache)
|
||||
digestSessionStore := service.NewDigestSessionStore()
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, digestSessionStore)
|
||||
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
|
||||
@@ -44,6 +44,8 @@ type ErrorPassthroughRule struct {
|
||||
PassthroughBody bool `json:"passthrough_body,omitempty"`
|
||||
// CustomMessage holds the value of the "custom_message" field.
|
||||
CustomMessage *string `json:"custom_message,omitempty"`
|
||||
// SkipMonitoring holds the value of the "skip_monitoring" field.
|
||||
SkipMonitoring bool `json:"skip_monitoring,omitempty"`
|
||||
// Description holds the value of the "description" field.
|
||||
Description *string `json:"description,omitempty"`
|
||||
selectValues sql.SelectValues
|
||||
@@ -56,7 +58,7 @@ func (*ErrorPassthroughRule) scanValues(columns []string) ([]any, error) {
|
||||
switch columns[i] {
|
||||
case errorpassthroughrule.FieldErrorCodes, errorpassthroughrule.FieldKeywords, errorpassthroughrule.FieldPlatforms:
|
||||
values[i] = new([]byte)
|
||||
case errorpassthroughrule.FieldEnabled, errorpassthroughrule.FieldPassthroughCode, errorpassthroughrule.FieldPassthroughBody:
|
||||
case errorpassthroughrule.FieldEnabled, errorpassthroughrule.FieldPassthroughCode, errorpassthroughrule.FieldPassthroughBody, errorpassthroughrule.FieldSkipMonitoring:
|
||||
values[i] = new(sql.NullBool)
|
||||
case errorpassthroughrule.FieldID, errorpassthroughrule.FieldPriority, errorpassthroughrule.FieldResponseCode:
|
||||
values[i] = new(sql.NullInt64)
|
||||
@@ -171,6 +173,12 @@ func (_m *ErrorPassthroughRule) assignValues(columns []string, values []any) err
|
||||
_m.CustomMessage = new(string)
|
||||
*_m.CustomMessage = value.String
|
||||
}
|
||||
case errorpassthroughrule.FieldSkipMonitoring:
|
||||
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field skip_monitoring", values[i])
|
||||
} else if value.Valid {
|
||||
_m.SkipMonitoring = value.Bool
|
||||
}
|
||||
case errorpassthroughrule.FieldDescription:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field description", values[i])
|
||||
@@ -257,6 +265,9 @@ func (_m *ErrorPassthroughRule) String() string {
|
||||
builder.WriteString(*v)
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("skip_monitoring=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.SkipMonitoring))
|
||||
builder.WriteString(", ")
|
||||
if v := _m.Description; v != nil {
|
||||
builder.WriteString("description=")
|
||||
builder.WriteString(*v)
|
||||
|
||||
@@ -39,6 +39,8 @@ const (
|
||||
FieldPassthroughBody = "passthrough_body"
|
||||
// FieldCustomMessage holds the string denoting the custom_message field in the database.
|
||||
FieldCustomMessage = "custom_message"
|
||||
// FieldSkipMonitoring holds the string denoting the skip_monitoring field in the database.
|
||||
FieldSkipMonitoring = "skip_monitoring"
|
||||
// FieldDescription holds the string denoting the description field in the database.
|
||||
FieldDescription = "description"
|
||||
// Table holds the table name of the errorpassthroughrule in the database.
|
||||
@@ -61,6 +63,7 @@ var Columns = []string{
|
||||
FieldResponseCode,
|
||||
FieldPassthroughBody,
|
||||
FieldCustomMessage,
|
||||
FieldSkipMonitoring,
|
||||
FieldDescription,
|
||||
}
|
||||
|
||||
@@ -95,6 +98,8 @@ var (
|
||||
DefaultPassthroughCode bool
|
||||
// DefaultPassthroughBody holds the default value on creation for the "passthrough_body" field.
|
||||
DefaultPassthroughBody bool
|
||||
// DefaultSkipMonitoring holds the default value on creation for the "skip_monitoring" field.
|
||||
DefaultSkipMonitoring bool
|
||||
)
|
||||
|
||||
// OrderOption defines the ordering options for the ErrorPassthroughRule queries.
|
||||
@@ -155,6 +160,11 @@ func ByCustomMessage(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldCustomMessage, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// BySkipMonitoring orders the results by the skip_monitoring field.
|
||||
func BySkipMonitoring(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldSkipMonitoring, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByDescription orders the results by the description field.
|
||||
func ByDescription(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldDescription, opts...).ToFunc()
|
||||
|
||||
@@ -104,6 +104,11 @@ func CustomMessage(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldCustomMessage, v))
|
||||
}
|
||||
|
||||
// SkipMonitoring applies equality check predicate on the "skip_monitoring" field. It's identical to SkipMonitoringEQ.
|
||||
func SkipMonitoring(v bool) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldSkipMonitoring, v))
|
||||
}
|
||||
|
||||
// Description applies equality check predicate on the "description" field. It's identical to DescriptionEQ.
|
||||
func Description(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldDescription, v))
|
||||
@@ -544,6 +549,16 @@ func CustomMessageContainsFold(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldContainsFold(FieldCustomMessage, v))
|
||||
}
|
||||
|
||||
// SkipMonitoringEQ applies the EQ predicate on the "skip_monitoring" field.
|
||||
func SkipMonitoringEQ(v bool) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldSkipMonitoring, v))
|
||||
}
|
||||
|
||||
// SkipMonitoringNEQ applies the NEQ predicate on the "skip_monitoring" field.
|
||||
func SkipMonitoringNEQ(v bool) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldSkipMonitoring, v))
|
||||
}
|
||||
|
||||
// DescriptionEQ applies the EQ predicate on the "description" field.
|
||||
func DescriptionEQ(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldDescription, v))
|
||||
|
||||
@@ -172,6 +172,20 @@ func (_c *ErrorPassthroughRuleCreate) SetNillableCustomMessage(v *string) *Error
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetSkipMonitoring sets the "skip_monitoring" field.
|
||||
func (_c *ErrorPassthroughRuleCreate) SetSkipMonitoring(v bool) *ErrorPassthroughRuleCreate {
|
||||
_c.mutation.SetSkipMonitoring(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableSkipMonitoring sets the "skip_monitoring" field if the given value is not nil.
|
||||
func (_c *ErrorPassthroughRuleCreate) SetNillableSkipMonitoring(v *bool) *ErrorPassthroughRuleCreate {
|
||||
if v != nil {
|
||||
_c.SetSkipMonitoring(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetDescription sets the "description" field.
|
||||
func (_c *ErrorPassthroughRuleCreate) SetDescription(v string) *ErrorPassthroughRuleCreate {
|
||||
_c.mutation.SetDescription(v)
|
||||
@@ -249,6 +263,10 @@ func (_c *ErrorPassthroughRuleCreate) defaults() {
|
||||
v := errorpassthroughrule.DefaultPassthroughBody
|
||||
_c.mutation.SetPassthroughBody(v)
|
||||
}
|
||||
if _, ok := _c.mutation.SkipMonitoring(); !ok {
|
||||
v := errorpassthroughrule.DefaultSkipMonitoring
|
||||
_c.mutation.SetSkipMonitoring(v)
|
||||
}
|
||||
}
|
||||
|
||||
// check runs all checks and user-defined validators on the builder.
|
||||
@@ -287,6 +305,9 @@ func (_c *ErrorPassthroughRuleCreate) check() error {
|
||||
if _, ok := _c.mutation.PassthroughBody(); !ok {
|
||||
return &ValidationError{Name: "passthrough_body", err: errors.New(`ent: missing required field "ErrorPassthroughRule.passthrough_body"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.SkipMonitoring(); !ok {
|
||||
return &ValidationError{Name: "skip_monitoring", err: errors.New(`ent: missing required field "ErrorPassthroughRule.skip_monitoring"`)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -366,6 +387,10 @@ func (_c *ErrorPassthroughRuleCreate) createSpec() (*ErrorPassthroughRule, *sqlg
|
||||
_spec.SetField(errorpassthroughrule.FieldCustomMessage, field.TypeString, value)
|
||||
_node.CustomMessage = &value
|
||||
}
|
||||
if value, ok := _c.mutation.SkipMonitoring(); ok {
|
||||
_spec.SetField(errorpassthroughrule.FieldSkipMonitoring, field.TypeBool, value)
|
||||
_node.SkipMonitoring = value
|
||||
}
|
||||
if value, ok := _c.mutation.Description(); ok {
|
||||
_spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value)
|
||||
_node.Description = &value
|
||||
@@ -608,6 +633,18 @@ func (u *ErrorPassthroughRuleUpsert) ClearCustomMessage() *ErrorPassthroughRuleU
|
||||
return u
|
||||
}
|
||||
|
||||
// SetSkipMonitoring sets the "skip_monitoring" field.
|
||||
func (u *ErrorPassthroughRuleUpsert) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpsert {
|
||||
u.Set(errorpassthroughrule.FieldSkipMonitoring, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateSkipMonitoring sets the "skip_monitoring" field to the value that was provided on create.
|
||||
func (u *ErrorPassthroughRuleUpsert) UpdateSkipMonitoring() *ErrorPassthroughRuleUpsert {
|
||||
u.SetExcluded(errorpassthroughrule.FieldSkipMonitoring)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetDescription sets the "description" field.
|
||||
func (u *ErrorPassthroughRuleUpsert) SetDescription(v string) *ErrorPassthroughRuleUpsert {
|
||||
u.Set(errorpassthroughrule.FieldDescription, v)
|
||||
@@ -888,6 +925,20 @@ func (u *ErrorPassthroughRuleUpsertOne) ClearCustomMessage() *ErrorPassthroughRu
|
||||
})
|
||||
}
|
||||
|
||||
// SetSkipMonitoring sets the "skip_monitoring" field.
|
||||
func (u *ErrorPassthroughRuleUpsertOne) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpsertOne {
|
||||
return u.Update(func(s *ErrorPassthroughRuleUpsert) {
|
||||
s.SetSkipMonitoring(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSkipMonitoring sets the "skip_monitoring" field to the value that was provided on create.
|
||||
func (u *ErrorPassthroughRuleUpsertOne) UpdateSkipMonitoring() *ErrorPassthroughRuleUpsertOne {
|
||||
return u.Update(func(s *ErrorPassthroughRuleUpsert) {
|
||||
s.UpdateSkipMonitoring()
|
||||
})
|
||||
}
|
||||
|
||||
// SetDescription sets the "description" field.
|
||||
func (u *ErrorPassthroughRuleUpsertOne) SetDescription(v string) *ErrorPassthroughRuleUpsertOne {
|
||||
return u.Update(func(s *ErrorPassthroughRuleUpsert) {
|
||||
@@ -1337,6 +1388,20 @@ func (u *ErrorPassthroughRuleUpsertBulk) ClearCustomMessage() *ErrorPassthroughR
|
||||
})
|
||||
}
|
||||
|
||||
// SetSkipMonitoring sets the "skip_monitoring" field.
|
||||
func (u *ErrorPassthroughRuleUpsertBulk) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpsertBulk {
|
||||
return u.Update(func(s *ErrorPassthroughRuleUpsert) {
|
||||
s.SetSkipMonitoring(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSkipMonitoring sets the "skip_monitoring" field to the value that was provided on create.
|
||||
func (u *ErrorPassthroughRuleUpsertBulk) UpdateSkipMonitoring() *ErrorPassthroughRuleUpsertBulk {
|
||||
return u.Update(func(s *ErrorPassthroughRuleUpsert) {
|
||||
s.UpdateSkipMonitoring()
|
||||
})
|
||||
}
|
||||
|
||||
// SetDescription sets the "description" field.
|
||||
func (u *ErrorPassthroughRuleUpsertBulk) SetDescription(v string) *ErrorPassthroughRuleUpsertBulk {
|
||||
return u.Update(func(s *ErrorPassthroughRuleUpsert) {
|
||||
|
||||
@@ -227,6 +227,20 @@ func (_u *ErrorPassthroughRuleUpdate) ClearCustomMessage() *ErrorPassthroughRule
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetSkipMonitoring sets the "skip_monitoring" field.
|
||||
func (_u *ErrorPassthroughRuleUpdate) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpdate {
|
||||
_u.mutation.SetSkipMonitoring(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableSkipMonitoring sets the "skip_monitoring" field if the given value is not nil.
|
||||
func (_u *ErrorPassthroughRuleUpdate) SetNillableSkipMonitoring(v *bool) *ErrorPassthroughRuleUpdate {
|
||||
if v != nil {
|
||||
_u.SetSkipMonitoring(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDescription sets the "description" field.
|
||||
func (_u *ErrorPassthroughRuleUpdate) SetDescription(v string) *ErrorPassthroughRuleUpdate {
|
||||
_u.mutation.SetDescription(v)
|
||||
@@ -387,6 +401,9 @@ func (_u *ErrorPassthroughRuleUpdate) sqlSave(ctx context.Context) (_node int, e
|
||||
if _u.mutation.CustomMessageCleared() {
|
||||
_spec.ClearField(errorpassthroughrule.FieldCustomMessage, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.SkipMonitoring(); ok {
|
||||
_spec.SetField(errorpassthroughrule.FieldSkipMonitoring, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Description(); ok {
|
||||
_spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value)
|
||||
}
|
||||
@@ -611,6 +628,20 @@ func (_u *ErrorPassthroughRuleUpdateOne) ClearCustomMessage() *ErrorPassthroughR
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetSkipMonitoring sets the "skip_monitoring" field.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpdateOne {
|
||||
_u.mutation.SetSkipMonitoring(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableSkipMonitoring sets the "skip_monitoring" field if the given value is not nil.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) SetNillableSkipMonitoring(v *bool) *ErrorPassthroughRuleUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetSkipMonitoring(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDescription sets the "description" field.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) SetDescription(v string) *ErrorPassthroughRuleUpdateOne {
|
||||
_u.mutation.SetDescription(v)
|
||||
@@ -801,6 +832,9 @@ func (_u *ErrorPassthroughRuleUpdateOne) sqlSave(ctx context.Context) (_node *Er
|
||||
if _u.mutation.CustomMessageCleared() {
|
||||
_spec.ClearField(errorpassthroughrule.FieldCustomMessage, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.SkipMonitoring(); ok {
|
||||
_spec.SetField(errorpassthroughrule.FieldSkipMonitoring, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Description(); ok {
|
||||
_spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value)
|
||||
}
|
||||
|
||||
@@ -66,6 +66,8 @@ type Group struct {
|
||||
McpXMLInject bool `json:"mcp_xml_inject,omitempty"`
|
||||
// 支持的模型系列:claude, gemini_text, gemini_image
|
||||
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
|
||||
// 分组显示排序,数值越小越靠前
|
||||
SortOrder int `json:"sort_order,omitempty"`
|
||||
// Edges holds the relations/edges for other nodes in the graph.
|
||||
// The values are being populated by the GroupQuery when eager-loading is set.
|
||||
Edges GroupEdges `json:"edges"`
|
||||
@@ -178,7 +180,7 @@ func (*Group) scanValues(columns []string) ([]any, error) {
|
||||
values[i] = new(sql.NullBool)
|
||||
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest:
|
||||
case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType:
|
||||
values[i] = new(sql.NullString)
|
||||
@@ -363,6 +365,12 @@ func (_m *Group) assignValues(columns []string, values []any) error {
|
||||
return fmt.Errorf("unmarshal field supported_model_scopes: %w", err)
|
||||
}
|
||||
}
|
||||
case group.FieldSortOrder:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field sort_order", values[i])
|
||||
} else if value.Valid {
|
||||
_m.SortOrder = int(value.Int64)
|
||||
}
|
||||
default:
|
||||
_m.selectValues.Set(columns[i], values[i])
|
||||
}
|
||||
@@ -530,6 +538,9 @@ func (_m *Group) String() string {
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("supported_model_scopes=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.SupportedModelScopes))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("sort_order=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.SortOrder))
|
||||
builder.WriteByte(')')
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
@@ -63,6 +63,8 @@ const (
|
||||
FieldMcpXMLInject = "mcp_xml_inject"
|
||||
// FieldSupportedModelScopes holds the string denoting the supported_model_scopes field in the database.
|
||||
FieldSupportedModelScopes = "supported_model_scopes"
|
||||
// FieldSortOrder holds the string denoting the sort_order field in the database.
|
||||
FieldSortOrder = "sort_order"
|
||||
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
||||
EdgeAPIKeys = "api_keys"
|
||||
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
||||
@@ -162,6 +164,7 @@ var Columns = []string{
|
||||
FieldModelRoutingEnabled,
|
||||
FieldMcpXMLInject,
|
||||
FieldSupportedModelScopes,
|
||||
FieldSortOrder,
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -225,6 +228,8 @@ var (
|
||||
DefaultMcpXMLInject bool
|
||||
// DefaultSupportedModelScopes holds the default value on creation for the "supported_model_scopes" field.
|
||||
DefaultSupportedModelScopes []string
|
||||
// DefaultSortOrder holds the default value on creation for the "sort_order" field.
|
||||
DefaultSortOrder int
|
||||
)
|
||||
|
||||
// OrderOption defines the ordering options for the Group queries.
|
||||
@@ -345,6 +350,11 @@ func ByMcpXMLInject(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldMcpXMLInject, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// BySortOrder orders the results by the sort_order field.
|
||||
func BySortOrder(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldSortOrder, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByAPIKeysCount orders the results by api_keys count.
|
||||
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
||||
return func(s *sql.Selector) {
|
||||
|
||||
@@ -165,6 +165,11 @@ func McpXMLInject(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldMcpXMLInject, v))
|
||||
}
|
||||
|
||||
// SortOrder applies equality check predicate on the "sort_order" field. It's identical to SortOrderEQ.
|
||||
func SortOrder(v int) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldSortOrder, v))
|
||||
}
|
||||
|
||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||
func CreatedAtEQ(v time.Time) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
|
||||
@@ -1160,6 +1165,46 @@ func McpXMLInjectNEQ(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldMcpXMLInject, v))
|
||||
}
|
||||
|
||||
// SortOrderEQ applies the EQ predicate on the "sort_order" field.
|
||||
func SortOrderEQ(v int) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldSortOrder, v))
|
||||
}
|
||||
|
||||
// SortOrderNEQ applies the NEQ predicate on the "sort_order" field.
|
||||
func SortOrderNEQ(v int) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldSortOrder, v))
|
||||
}
|
||||
|
||||
// SortOrderIn applies the In predicate on the "sort_order" field.
|
||||
func SortOrderIn(vs ...int) predicate.Group {
|
||||
return predicate.Group(sql.FieldIn(FieldSortOrder, vs...))
|
||||
}
|
||||
|
||||
// SortOrderNotIn applies the NotIn predicate on the "sort_order" field.
|
||||
func SortOrderNotIn(vs ...int) predicate.Group {
|
||||
return predicate.Group(sql.FieldNotIn(FieldSortOrder, vs...))
|
||||
}
|
||||
|
||||
// SortOrderGT applies the GT predicate on the "sort_order" field.
|
||||
func SortOrderGT(v int) predicate.Group {
|
||||
return predicate.Group(sql.FieldGT(FieldSortOrder, v))
|
||||
}
|
||||
|
||||
// SortOrderGTE applies the GTE predicate on the "sort_order" field.
|
||||
func SortOrderGTE(v int) predicate.Group {
|
||||
return predicate.Group(sql.FieldGTE(FieldSortOrder, v))
|
||||
}
|
||||
|
||||
// SortOrderLT applies the LT predicate on the "sort_order" field.
|
||||
func SortOrderLT(v int) predicate.Group {
|
||||
return predicate.Group(sql.FieldLT(FieldSortOrder, v))
|
||||
}
|
||||
|
||||
// SortOrderLTE applies the LTE predicate on the "sort_order" field.
|
||||
func SortOrderLTE(v int) predicate.Group {
|
||||
return predicate.Group(sql.FieldLTE(FieldSortOrder, v))
|
||||
}
|
||||
|
||||
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
||||
func HasAPIKeys() predicate.Group {
|
||||
return predicate.Group(func(s *sql.Selector) {
|
||||
|
||||
@@ -340,6 +340,20 @@ func (_c *GroupCreate) SetSupportedModelScopes(v []string) *GroupCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetSortOrder sets the "sort_order" field.
|
||||
func (_c *GroupCreate) SetSortOrder(v int) *GroupCreate {
|
||||
_c.mutation.SetSortOrder(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableSortOrder sets the "sort_order" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableSortOrder(v *int) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetSortOrder(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
|
||||
_c.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -521,6 +535,10 @@ func (_c *GroupCreate) defaults() error {
|
||||
v := group.DefaultSupportedModelScopes
|
||||
_c.mutation.SetSupportedModelScopes(v)
|
||||
}
|
||||
if _, ok := _c.mutation.SortOrder(); !ok {
|
||||
v := group.DefaultSortOrder
|
||||
_c.mutation.SetSortOrder(v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -585,6 +603,9 @@ func (_c *GroupCreate) check() error {
|
||||
if _, ok := _c.mutation.SupportedModelScopes(); !ok {
|
||||
return &ValidationError{Name: "supported_model_scopes", err: errors.New(`ent: missing required field "Group.supported_model_scopes"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.SortOrder(); !ok {
|
||||
return &ValidationError{Name: "sort_order", err: errors.New(`ent: missing required field "Group.sort_order"`)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -708,6 +729,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(group.FieldSupportedModelScopes, field.TypeJSON, value)
|
||||
_node.SupportedModelScopes = value
|
||||
}
|
||||
if value, ok := _c.mutation.SortOrder(); ok {
|
||||
_spec.SetField(group.FieldSortOrder, field.TypeInt, value)
|
||||
_node.SortOrder = value
|
||||
}
|
||||
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
@@ -1266,6 +1291,24 @@ func (u *GroupUpsert) UpdateSupportedModelScopes() *GroupUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetSortOrder sets the "sort_order" field.
|
||||
func (u *GroupUpsert) SetSortOrder(v int) *GroupUpsert {
|
||||
u.Set(group.FieldSortOrder, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateSortOrder sets the "sort_order" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateSortOrder() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldSortOrder)
|
||||
return u
|
||||
}
|
||||
|
||||
// AddSortOrder adds v to the "sort_order" field.
|
||||
func (u *GroupUpsert) AddSortOrder(v int) *GroupUpsert {
|
||||
u.Add(group.FieldSortOrder, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||
// Using this option is equivalent to using:
|
||||
//
|
||||
@@ -1780,6 +1823,27 @@ func (u *GroupUpsertOne) UpdateSupportedModelScopes() *GroupUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetSortOrder sets the "sort_order" field.
|
||||
func (u *GroupUpsertOne) SetSortOrder(v int) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetSortOrder(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddSortOrder adds v to the "sort_order" field.
|
||||
func (u *GroupUpsertOne) AddSortOrder(v int) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.AddSortOrder(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSortOrder sets the "sort_order" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateSortOrder() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateSortOrder()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
|
||||
if len(u.create.conflict) == 0 {
|
||||
@@ -2460,6 +2524,27 @@ func (u *GroupUpsertBulk) UpdateSupportedModelScopes() *GroupUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetSortOrder sets the "sort_order" field.
|
||||
func (u *GroupUpsertBulk) SetSortOrder(v int) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetSortOrder(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddSortOrder adds v to the "sort_order" field.
|
||||
func (u *GroupUpsertBulk) AddSortOrder(v int) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.AddSortOrder(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSortOrder sets the "sort_order" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateSortOrder() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateSortOrder()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
|
||||
if u.create.err != nil {
|
||||
|
||||
@@ -475,6 +475,27 @@ func (_u *GroupUpdate) AppendSupportedModelScopes(v []string) *GroupUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetSortOrder sets the "sort_order" field.
|
||||
func (_u *GroupUpdate) SetSortOrder(v int) *GroupUpdate {
|
||||
_u.mutation.ResetSortOrder()
|
||||
_u.mutation.SetSortOrder(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableSortOrder sets the "sort_order" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableSortOrder(v *int) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetSortOrder(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddSortOrder adds value to the "sort_order" field.
|
||||
func (_u *GroupUpdate) AddSortOrder(v int) *GroupUpdate {
|
||||
_u.mutation.AddSortOrder(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
|
||||
_u.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -912,6 +933,12 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
sqljson.Append(u, group.FieldSupportedModelScopes, value)
|
||||
})
|
||||
}
|
||||
if value, ok := _u.mutation.SortOrder(); ok {
|
||||
_spec.SetField(group.FieldSortOrder, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedSortOrder(); ok {
|
||||
_spec.AddField(group.FieldSortOrder, field.TypeInt, value)
|
||||
}
|
||||
if _u.mutation.APIKeysCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
@@ -1666,6 +1693,27 @@ func (_u *GroupUpdateOne) AppendSupportedModelScopes(v []string) *GroupUpdateOne
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetSortOrder sets the "sort_order" field.
|
||||
func (_u *GroupUpdateOne) SetSortOrder(v int) *GroupUpdateOne {
|
||||
_u.mutation.ResetSortOrder()
|
||||
_u.mutation.SetSortOrder(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableSortOrder sets the "sort_order" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableSortOrder(v *int) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetSortOrder(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddSortOrder adds value to the "sort_order" field.
|
||||
func (_u *GroupUpdateOne) AddSortOrder(v int) *GroupUpdateOne {
|
||||
_u.mutation.AddSortOrder(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
|
||||
_u.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -2133,6 +2181,12 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
|
||||
sqljson.Append(u, group.FieldSupportedModelScopes, value)
|
||||
})
|
||||
}
|
||||
if value, ok := _u.mutation.SortOrder(); ok {
|
||||
_spec.SetField(group.FieldSortOrder, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedSortOrder(); ok {
|
||||
_spec.AddField(group.FieldSortOrder, field.TypeInt, value)
|
||||
}
|
||||
if _u.mutation.APIKeysCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
|
||||
@@ -325,6 +325,7 @@ var (
|
||||
{Name: "response_code", Type: field.TypeInt, Nullable: true},
|
||||
{Name: "passthrough_body", Type: field.TypeBool, Default: true},
|
||||
{Name: "custom_message", Type: field.TypeString, Nullable: true, Size: 2147483647},
|
||||
{Name: "skip_monitoring", Type: field.TypeBool, Default: false},
|
||||
{Name: "description", Type: field.TypeString, Nullable: true, Size: 2147483647},
|
||||
}
|
||||
// ErrorPassthroughRulesTable holds the schema information for the "error_passthrough_rules" table.
|
||||
@@ -372,6 +373,7 @@ var (
|
||||
{Name: "model_routing_enabled", Type: field.TypeBool, Default: false},
|
||||
{Name: "mcp_xml_inject", Type: field.TypeBool, Default: true},
|
||||
{Name: "supported_model_scopes", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "sort_order", Type: field.TypeInt, Default: 0},
|
||||
}
|
||||
// GroupsTable holds the schema information for the "groups" table.
|
||||
GroupsTable = &schema.Table{
|
||||
@@ -404,6 +406,11 @@ var (
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{GroupsColumns[3]},
|
||||
},
|
||||
{
|
||||
Name: "group_sort_order",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{GroupsColumns[25]},
|
||||
},
|
||||
},
|
||||
}
|
||||
// PromoCodesColumns holds the columns for the "promo_codes" table.
|
||||
|
||||
@@ -5776,6 +5776,7 @@ type ErrorPassthroughRuleMutation struct {
|
||||
addresponse_code *int
|
||||
passthrough_body *bool
|
||||
custom_message *string
|
||||
skip_monitoring *bool
|
||||
description *string
|
||||
clearedFields map[string]struct{}
|
||||
done bool
|
||||
@@ -6503,6 +6504,42 @@ func (m *ErrorPassthroughRuleMutation) ResetCustomMessage() {
|
||||
delete(m.clearedFields, errorpassthroughrule.FieldCustomMessage)
|
||||
}
|
||||
|
||||
// SetSkipMonitoring sets the "skip_monitoring" field.
|
||||
func (m *ErrorPassthroughRuleMutation) SetSkipMonitoring(b bool) {
|
||||
m.skip_monitoring = &b
|
||||
}
|
||||
|
||||
// SkipMonitoring returns the value of the "skip_monitoring" field in the mutation.
|
||||
func (m *ErrorPassthroughRuleMutation) SkipMonitoring() (r bool, exists bool) {
|
||||
v := m.skip_monitoring
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldSkipMonitoring returns the old "skip_monitoring" field's value of the ErrorPassthroughRule entity.
|
||||
// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *ErrorPassthroughRuleMutation) OldSkipMonitoring(ctx context.Context) (v bool, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldSkipMonitoring is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldSkipMonitoring requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldSkipMonitoring: %w", err)
|
||||
}
|
||||
return oldValue.SkipMonitoring, nil
|
||||
}
|
||||
|
||||
// ResetSkipMonitoring resets all changes to the "skip_monitoring" field.
|
||||
func (m *ErrorPassthroughRuleMutation) ResetSkipMonitoring() {
|
||||
m.skip_monitoring = nil
|
||||
}
|
||||
|
||||
// SetDescription sets the "description" field.
|
||||
func (m *ErrorPassthroughRuleMutation) SetDescription(s string) {
|
||||
m.description = &s
|
||||
@@ -6586,7 +6623,7 @@ func (m *ErrorPassthroughRuleMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *ErrorPassthroughRuleMutation) Fields() []string {
|
||||
fields := make([]string, 0, 14)
|
||||
fields := make([]string, 0, 15)
|
||||
if m.created_at != nil {
|
||||
fields = append(fields, errorpassthroughrule.FieldCreatedAt)
|
||||
}
|
||||
@@ -6626,6 +6663,9 @@ func (m *ErrorPassthroughRuleMutation) Fields() []string {
|
||||
if m.custom_message != nil {
|
||||
fields = append(fields, errorpassthroughrule.FieldCustomMessage)
|
||||
}
|
||||
if m.skip_monitoring != nil {
|
||||
fields = append(fields, errorpassthroughrule.FieldSkipMonitoring)
|
||||
}
|
||||
if m.description != nil {
|
||||
fields = append(fields, errorpassthroughrule.FieldDescription)
|
||||
}
|
||||
@@ -6663,6 +6703,8 @@ func (m *ErrorPassthroughRuleMutation) Field(name string) (ent.Value, bool) {
|
||||
return m.PassthroughBody()
|
||||
case errorpassthroughrule.FieldCustomMessage:
|
||||
return m.CustomMessage()
|
||||
case errorpassthroughrule.FieldSkipMonitoring:
|
||||
return m.SkipMonitoring()
|
||||
case errorpassthroughrule.FieldDescription:
|
||||
return m.Description()
|
||||
}
|
||||
@@ -6700,6 +6742,8 @@ func (m *ErrorPassthroughRuleMutation) OldField(ctx context.Context, name string
|
||||
return m.OldPassthroughBody(ctx)
|
||||
case errorpassthroughrule.FieldCustomMessage:
|
||||
return m.OldCustomMessage(ctx)
|
||||
case errorpassthroughrule.FieldSkipMonitoring:
|
||||
return m.OldSkipMonitoring(ctx)
|
||||
case errorpassthroughrule.FieldDescription:
|
||||
return m.OldDescription(ctx)
|
||||
}
|
||||
@@ -6802,6 +6846,13 @@ func (m *ErrorPassthroughRuleMutation) SetField(name string, value ent.Value) er
|
||||
}
|
||||
m.SetCustomMessage(v)
|
||||
return nil
|
||||
case errorpassthroughrule.FieldSkipMonitoring:
|
||||
v, ok := value.(bool)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetSkipMonitoring(v)
|
||||
return nil
|
||||
case errorpassthroughrule.FieldDescription:
|
||||
v, ok := value.(string)
|
||||
if !ok {
|
||||
@@ -6963,6 +7014,9 @@ func (m *ErrorPassthroughRuleMutation) ResetField(name string) error {
|
||||
case errorpassthroughrule.FieldCustomMessage:
|
||||
m.ResetCustomMessage()
|
||||
return nil
|
||||
case errorpassthroughrule.FieldSkipMonitoring:
|
||||
m.ResetSkipMonitoring()
|
||||
return nil
|
||||
case errorpassthroughrule.FieldDescription:
|
||||
m.ResetDescription()
|
||||
return nil
|
||||
@@ -7059,6 +7113,8 @@ type GroupMutation struct {
|
||||
mcp_xml_inject *bool
|
||||
supported_model_scopes *[]string
|
||||
appendsupported_model_scopes []string
|
||||
sort_order *int
|
||||
addsort_order *int
|
||||
clearedFields map[string]struct{}
|
||||
api_keys map[int64]struct{}
|
||||
removedapi_keys map[int64]struct{}
|
||||
@@ -8411,6 +8467,62 @@ func (m *GroupMutation) ResetSupportedModelScopes() {
|
||||
m.appendsupported_model_scopes = nil
|
||||
}
|
||||
|
||||
// SetSortOrder sets the "sort_order" field.
|
||||
func (m *GroupMutation) SetSortOrder(i int) {
|
||||
m.sort_order = &i
|
||||
m.addsort_order = nil
|
||||
}
|
||||
|
||||
// SortOrder returns the value of the "sort_order" field in the mutation.
|
||||
func (m *GroupMutation) SortOrder() (r int, exists bool) {
|
||||
v := m.sort_order
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldSortOrder returns the old "sort_order" field's value of the Group entity.
|
||||
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *GroupMutation) OldSortOrder(ctx context.Context) (v int, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldSortOrder is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldSortOrder requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldSortOrder: %w", err)
|
||||
}
|
||||
return oldValue.SortOrder, nil
|
||||
}
|
||||
|
||||
// AddSortOrder adds i to the "sort_order" field.
|
||||
func (m *GroupMutation) AddSortOrder(i int) {
|
||||
if m.addsort_order != nil {
|
||||
*m.addsort_order += i
|
||||
} else {
|
||||
m.addsort_order = &i
|
||||
}
|
||||
}
|
||||
|
||||
// AddedSortOrder returns the value that was added to the "sort_order" field in this mutation.
|
||||
func (m *GroupMutation) AddedSortOrder() (r int, exists bool) {
|
||||
v := m.addsort_order
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// ResetSortOrder resets all changes to the "sort_order" field.
|
||||
func (m *GroupMutation) ResetSortOrder() {
|
||||
m.sort_order = nil
|
||||
m.addsort_order = nil
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
|
||||
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
|
||||
if m.api_keys == nil {
|
||||
@@ -8769,7 +8881,7 @@ func (m *GroupMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *GroupMutation) Fields() []string {
|
||||
fields := make([]string, 0, 24)
|
||||
fields := make([]string, 0, 25)
|
||||
if m.created_at != nil {
|
||||
fields = append(fields, group.FieldCreatedAt)
|
||||
}
|
||||
@@ -8842,6 +8954,9 @@ func (m *GroupMutation) Fields() []string {
|
||||
if m.supported_model_scopes != nil {
|
||||
fields = append(fields, group.FieldSupportedModelScopes)
|
||||
}
|
||||
if m.sort_order != nil {
|
||||
fields = append(fields, group.FieldSortOrder)
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
@@ -8898,6 +9013,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
|
||||
return m.McpXMLInject()
|
||||
case group.FieldSupportedModelScopes:
|
||||
return m.SupportedModelScopes()
|
||||
case group.FieldSortOrder:
|
||||
return m.SortOrder()
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
@@ -8955,6 +9072,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
|
||||
return m.OldMcpXMLInject(ctx)
|
||||
case group.FieldSupportedModelScopes:
|
||||
return m.OldSupportedModelScopes(ctx)
|
||||
case group.FieldSortOrder:
|
||||
return m.OldSortOrder(ctx)
|
||||
}
|
||||
return nil, fmt.Errorf("unknown Group field %s", name)
|
||||
}
|
||||
@@ -9132,6 +9251,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
|
||||
}
|
||||
m.SetSupportedModelScopes(v)
|
||||
return nil
|
||||
case group.FieldSortOrder:
|
||||
v, ok := value.(int)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetSortOrder(v)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown Group field %s", name)
|
||||
}
|
||||
@@ -9170,6 +9296,9 @@ func (m *GroupMutation) AddedFields() []string {
|
||||
if m.addfallback_group_id_on_invalid_request != nil {
|
||||
fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest)
|
||||
}
|
||||
if m.addsort_order != nil {
|
||||
fields = append(fields, group.FieldSortOrder)
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
@@ -9198,6 +9327,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) {
|
||||
return m.AddedFallbackGroupID()
|
||||
case group.FieldFallbackGroupIDOnInvalidRequest:
|
||||
return m.AddedFallbackGroupIDOnInvalidRequest()
|
||||
case group.FieldSortOrder:
|
||||
return m.AddedSortOrder()
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
@@ -9277,6 +9408,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error {
|
||||
}
|
||||
m.AddFallbackGroupIDOnInvalidRequest(v)
|
||||
return nil
|
||||
case group.FieldSortOrder:
|
||||
v, ok := value.(int)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.AddSortOrder(v)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown Group numeric field %s", name)
|
||||
}
|
||||
@@ -9445,6 +9583,9 @@ func (m *GroupMutation) ResetField(name string) error {
|
||||
case group.FieldSupportedModelScopes:
|
||||
m.ResetSupportedModelScopes()
|
||||
return nil
|
||||
case group.FieldSortOrder:
|
||||
m.ResetSortOrder()
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown Group field %s", name)
|
||||
}
|
||||
|
||||
@@ -326,6 +326,10 @@ func init() {
|
||||
errorpassthroughruleDescPassthroughBody := errorpassthroughruleFields[9].Descriptor()
|
||||
// errorpassthroughrule.DefaultPassthroughBody holds the default value on creation for the passthrough_body field.
|
||||
errorpassthroughrule.DefaultPassthroughBody = errorpassthroughruleDescPassthroughBody.Default.(bool)
|
||||
// errorpassthroughruleDescSkipMonitoring is the schema descriptor for skip_monitoring field.
|
||||
errorpassthroughruleDescSkipMonitoring := errorpassthroughruleFields[11].Descriptor()
|
||||
// errorpassthroughrule.DefaultSkipMonitoring holds the default value on creation for the skip_monitoring field.
|
||||
errorpassthroughrule.DefaultSkipMonitoring = errorpassthroughruleDescSkipMonitoring.Default.(bool)
|
||||
groupMixin := schema.Group{}.Mixin()
|
||||
groupMixinHooks1 := groupMixin[1].Hooks()
|
||||
group.Hooks[0] = groupMixinHooks1[0]
|
||||
@@ -409,6 +413,10 @@ func init() {
|
||||
groupDescSupportedModelScopes := groupFields[20].Descriptor()
|
||||
// group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field.
|
||||
group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string)
|
||||
// groupDescSortOrder is the schema descriptor for sort_order field.
|
||||
groupDescSortOrder := groupFields[21].Descriptor()
|
||||
// group.DefaultSortOrder holds the default value on creation for the sort_order field.
|
||||
group.DefaultSortOrder = groupDescSortOrder.Default.(int)
|
||||
promocodeFields := schema.PromoCode{}.Fields()
|
||||
_ = promocodeFields
|
||||
// promocodeDescCode is the schema descriptor for code field.
|
||||
|
||||
@@ -105,6 +105,12 @@ func (ErrorPassthroughRule) Fields() []ent.Field {
|
||||
Optional().
|
||||
Nillable(),
|
||||
|
||||
// skip_monitoring: 是否跳过运维监控记录
|
||||
// true: 匹配此规则的错误不会被记录到 ops_error_logs
|
||||
// false: 正常记录到运维监控(默认行为)
|
||||
field.Bool("skip_monitoring").
|
||||
Default(false),
|
||||
|
||||
// description: 规则描述,用于说明规则的用途
|
||||
field.Text("description").
|
||||
Optional().
|
||||
|
||||
@@ -121,6 +121,11 @@ func (Group) Fields() []ent.Field {
|
||||
Default([]string{"claude", "gemini_text", "gemini_image"}).
|
||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
|
||||
Comment("支持的模型系列:claude, gemini_text, gemini_image"),
|
||||
|
||||
// 分组排序 (added by migration 052)
|
||||
field.Int("sort_order").
|
||||
Default(0).
|
||||
Comment("分组显示排序,数值越小越靠前"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -149,5 +154,6 @@ func (Group) Indexes() []ent.Index {
|
||||
index.Fields("subscription_type"),
|
||||
index.Fields("is_exclusive"),
|
||||
index.Fields("deleted_at"),
|
||||
index.Fields("sort_order"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,6 +103,7 @@ require (
|
||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||
github.com/opencontainers/image-spec v1.1.1 // indirect
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
|
||||
@@ -135,6 +135,8 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
|
||||
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
|
||||
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
|
||||
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
@@ -170,6 +172,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
|
||||
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
||||
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
||||
@@ -203,10 +207,14 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
||||
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
@@ -230,6 +238,8 @@ github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkr
|
||||
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
|
||||
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
|
||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||
@@ -252,6 +262,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
|
||||
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
||||
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
|
||||
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
|
||||
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
|
||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
|
||||
|
||||
@@ -341,7 +341,7 @@ func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, acc
|
||||
pageSize := dataPageCap
|
||||
var out []service.Account
|
||||
for {
|
||||
items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search)
|
||||
items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -156,7 +156,12 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
search = search[:100]
|
||||
}
|
||||
|
||||
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search)
|
||||
var groupID int64
|
||||
if groupIDStr := c.Query("group"); groupIDStr != "" {
|
||||
groupID, _ = strconv.ParseInt(groupIDStr, 10, 64)
|
||||
}
|
||||
|
||||
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -424,10 +429,17 @@ type TestAccountRequest struct {
|
||||
}
|
||||
|
||||
type SyncFromCRSRequest struct {
|
||||
BaseURL string `json:"base_url" binding:"required"`
|
||||
Username string `json:"username" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
SyncProxies *bool `json:"sync_proxies"`
|
||||
BaseURL string `json:"base_url" binding:"required"`
|
||||
Username string `json:"username" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
SyncProxies *bool `json:"sync_proxies"`
|
||||
SelectedAccountIDs []string `json:"selected_account_ids"`
|
||||
}
|
||||
|
||||
type PreviewFromCRSRequest struct {
|
||||
BaseURL string `json:"base_url" binding:"required"`
|
||||
Username string `json:"username" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
// Test handles testing account connectivity with SSE streaming
|
||||
@@ -466,10 +478,11 @@ func (h *AccountHandler) SyncFromCRS(c *gin.Context) {
|
||||
}
|
||||
|
||||
result, err := h.crsSyncService.SyncFromCRS(c.Request.Context(), service.SyncFromCRSInput{
|
||||
BaseURL: req.BaseURL,
|
||||
Username: req.Username,
|
||||
Password: req.Password,
|
||||
SyncProxies: syncProxies,
|
||||
BaseURL: req.BaseURL,
|
||||
Username: req.Username,
|
||||
Password: req.Password,
|
||||
SyncProxies: syncProxies,
|
||||
SelectedAccountIDs: req.SelectedAccountIDs,
|
||||
})
|
||||
if err != nil {
|
||||
// Provide detailed error message for CRS sync failures
|
||||
@@ -480,6 +493,28 @@ func (h *AccountHandler) SyncFromCRS(c *gin.Context) {
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// PreviewFromCRS handles previewing accounts from CRS before sync
|
||||
// POST /api/v1/admin/accounts/sync/crs/preview
|
||||
func (h *AccountHandler) PreviewFromCRS(c *gin.Context) {
|
||||
var req PreviewFromCRSRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.crsSyncService.PreviewFromCRS(c.Request.Context(), service.SyncFromCRSInput{
|
||||
BaseURL: req.BaseURL,
|
||||
Username: req.Username,
|
||||
Password: req.Password,
|
||||
})
|
||||
if err != nil {
|
||||
response.InternalError(c, "CRS preview failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// Refresh handles refreshing account credentials
|
||||
// POST /api/v1/admin/accounts/:id/refresh
|
||||
func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
@@ -1399,7 +1434,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
|
||||
accounts := make([]*service.Account, 0)
|
||||
|
||||
if len(req.AccountIDs) == 0 {
|
||||
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "")
|
||||
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@@ -16,7 +16,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
|
||||
router := gin.New()
|
||||
adminSvc := newStubAdminService()
|
||||
|
||||
userHandler := NewUserHandler(adminSvc)
|
||||
userHandler := NewUserHandler(adminSvc, nil)
|
||||
groupHandler := NewGroupHandler(adminSvc)
|
||||
proxyHandler := NewProxyHandler(adminSvc)
|
||||
redeemHandler := NewRedeemHandler(adminSvc)
|
||||
|
||||
@@ -166,7 +166,7 @@ func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, p
|
||||
return s.apiKeys, int64(len(s.apiKeys)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]service.Account, int64, error) {
|
||||
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) {
|
||||
return s.accounts, int64(len(s.accounts)), nil
|
||||
}
|
||||
|
||||
@@ -357,5 +357,9 @@ func (s *stubAdminService) GetUserBalanceHistory(ctx context.Context, userID int
|
||||
return s.redeems, int64(len(s.redeems)), 100.0, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) UpdateGroupSortOrders(ctx context.Context, updates []service.GroupSortOrderUpdate) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure stub implements interface.
|
||||
var _ service.AdminService = (*stubAdminService)(nil)
|
||||
|
||||
@@ -65,3 +65,27 @@ func (h *AntigravityOAuthHandler) ExchangeCode(c *gin.Context) {
|
||||
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
// AntigravityRefreshTokenRequest represents the request for validating Antigravity refresh token
|
||||
type AntigravityRefreshTokenRequest struct {
|
||||
RefreshToken string `json:"refresh_token" binding:"required"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
|
||||
// RefreshToken validates an Antigravity refresh token and returns full token info
|
||||
// POST /api/v1/admin/antigravity/oauth/refresh-token
|
||||
func (h *AntigravityOAuthHandler) RefreshToken(c *gin.Context) {
|
||||
var req AntigravityRefreshTokenRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "请求无效: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
tokenInfo, err := h.antigravityOAuthService.ValidateRefreshToken(c.Request.Context(), req.RefreshToken, req.ProxyID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
@@ -32,6 +32,7 @@ type CreateErrorPassthroughRuleRequest struct {
|
||||
ResponseCode *int `json:"response_code"`
|
||||
PassthroughBody *bool `json:"passthrough_body"`
|
||||
CustomMessage *string `json:"custom_message"`
|
||||
SkipMonitoring *bool `json:"skip_monitoring"`
|
||||
Description *string `json:"description"`
|
||||
}
|
||||
|
||||
@@ -48,6 +49,7 @@ type UpdateErrorPassthroughRuleRequest struct {
|
||||
ResponseCode *int `json:"response_code"`
|
||||
PassthroughBody *bool `json:"passthrough_body"`
|
||||
CustomMessage *string `json:"custom_message"`
|
||||
SkipMonitoring *bool `json:"skip_monitoring"`
|
||||
Description *string `json:"description"`
|
||||
}
|
||||
|
||||
@@ -122,6 +124,9 @@ func (h *ErrorPassthroughHandler) Create(c *gin.Context) {
|
||||
} else {
|
||||
rule.PassthroughBody = true
|
||||
}
|
||||
if req.SkipMonitoring != nil {
|
||||
rule.SkipMonitoring = *req.SkipMonitoring
|
||||
}
|
||||
rule.ResponseCode = req.ResponseCode
|
||||
rule.CustomMessage = req.CustomMessage
|
||||
rule.Description = req.Description
|
||||
@@ -190,6 +195,7 @@ func (h *ErrorPassthroughHandler) Update(c *gin.Context) {
|
||||
ResponseCode: existing.ResponseCode,
|
||||
PassthroughBody: existing.PassthroughBody,
|
||||
CustomMessage: existing.CustomMessage,
|
||||
SkipMonitoring: existing.SkipMonitoring,
|
||||
Description: existing.Description,
|
||||
}
|
||||
|
||||
@@ -230,6 +236,9 @@ func (h *ErrorPassthroughHandler) Update(c *gin.Context) {
|
||||
if req.Description != nil {
|
||||
rule.Description = req.Description
|
||||
}
|
||||
if req.SkipMonitoring != nil {
|
||||
rule.SkipMonitoring = *req.SkipMonitoring
|
||||
}
|
||||
|
||||
// 确保切片不为 nil
|
||||
if rule.ErrorCodes == nil {
|
||||
|
||||
@@ -302,3 +302,36 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
|
||||
}
|
||||
response.Paginated(c, outKeys, total, page, pageSize)
|
||||
}
|
||||
|
||||
// UpdateSortOrderRequest represents the request to update group sort orders
|
||||
type UpdateSortOrderRequest struct {
|
||||
Updates []struct {
|
||||
ID int64 `json:"id" binding:"required"`
|
||||
SortOrder int `json:"sort_order"`
|
||||
} `json:"updates" binding:"required,min=1"`
|
||||
}
|
||||
|
||||
// UpdateSortOrder handles updating group sort orders
|
||||
// PUT /api/v1/admin/groups/sort-order
|
||||
func (h *GroupHandler) UpdateSortOrder(c *gin.Context) {
|
||||
var req UpdateSortOrderRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
updates := make([]service.GroupSortOrderUpdate, 0, len(req.Updates))
|
||||
for _, u := range req.Updates {
|
||||
updates = append(updates, service.GroupSortOrderUpdate{
|
||||
ID: u.ID,
|
||||
SortOrder: u.SortOrder,
|
||||
})
|
||||
}
|
||||
|
||||
if err := h.adminService.UpdateGroupSortOrders(c.Request.Context(), updates); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Sort order updated successfully"})
|
||||
}
|
||||
|
||||
@@ -202,7 +202,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
||||
writer := csv.NewWriter(&buf)
|
||||
|
||||
// Write header
|
||||
if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_at", "created_at"}); err != nil {
|
||||
if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_by_email", "used_at", "created_at"}); err != nil {
|
||||
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
||||
return
|
||||
}
|
||||
@@ -213,6 +213,10 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
||||
if code.UsedBy != nil {
|
||||
usedBy = fmt.Sprintf("%d", *code.UsedBy)
|
||||
}
|
||||
usedByEmail := ""
|
||||
if code.User != nil {
|
||||
usedByEmail = code.User.Email
|
||||
}
|
||||
usedAt := ""
|
||||
if code.UsedAt != nil {
|
||||
usedAt = code.UsedAt.Format("2006-01-02 15:04:05")
|
||||
@@ -224,6 +228,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
||||
fmt.Sprintf("%.2f", code.Value),
|
||||
code.Status,
|
||||
usedBy,
|
||||
usedByEmail,
|
||||
usedAt,
|
||||
code.CreatedAt.Format("2006-01-02 15:04:05"),
|
||||
}); err != nil {
|
||||
|
||||
@@ -11,15 +11,23 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// UserWithConcurrency wraps AdminUser with current concurrency info
|
||||
type UserWithConcurrency struct {
|
||||
dto.AdminUser
|
||||
CurrentConcurrency int `json:"current_concurrency"`
|
||||
}
|
||||
|
||||
// UserHandler handles admin user management
|
||||
type UserHandler struct {
|
||||
adminService service.AdminService
|
||||
adminService service.AdminService
|
||||
concurrencyService *service.ConcurrencyService
|
||||
}
|
||||
|
||||
// NewUserHandler creates a new admin user handler
|
||||
func NewUserHandler(adminService service.AdminService) *UserHandler {
|
||||
func NewUserHandler(adminService service.AdminService, concurrencyService *service.ConcurrencyService) *UserHandler {
|
||||
return &UserHandler{
|
||||
adminService: adminService,
|
||||
adminService: adminService,
|
||||
concurrencyService: concurrencyService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,10 +95,30 @@ func (h *UserHandler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.AdminUser, 0, len(users))
|
||||
for i := range users {
|
||||
out = append(out, *dto.UserFromServiceAdmin(&users[i]))
|
||||
// Batch get current concurrency (nil map if unavailable)
|
||||
var loadInfo map[int64]*service.UserLoadInfo
|
||||
if len(users) > 0 && h.concurrencyService != nil {
|
||||
usersConcurrency := make([]service.UserWithConcurrency, len(users))
|
||||
for i := range users {
|
||||
usersConcurrency[i] = service.UserWithConcurrency{
|
||||
ID: users[i].ID,
|
||||
MaxConcurrency: users[i].Concurrency,
|
||||
}
|
||||
}
|
||||
loadInfo, _ = h.concurrencyService.GetUsersLoadBatch(c.Request.Context(), usersConcurrency)
|
||||
}
|
||||
|
||||
// Build response with concurrency info
|
||||
out := make([]UserWithConcurrency, len(users))
|
||||
for i := range users {
|
||||
out[i] = UserWithConcurrency{
|
||||
AdminUser: *dto.UserFromServiceAdmin(&users[i]),
|
||||
}
|
||||
if info := loadInfo[users[i].ID]; info != nil {
|
||||
out[i].CurrentConcurrency = info.CurrentConcurrency
|
||||
}
|
||||
}
|
||||
|
||||
response.Paginated(c, out, total, page, pageSize)
|
||||
}
|
||||
|
||||
|
||||
@@ -115,6 +115,7 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
|
||||
MCPXMLInject: g.MCPXMLInject,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
AccountCount: g.AccountCount,
|
||||
SortOrder: g.SortOrder,
|
||||
}
|
||||
if len(g.AccountGroups) > 0 {
|
||||
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
|
||||
|
||||
@@ -2,11 +2,6 @@ package dto
|
||||
|
||||
import "time"
|
||||
|
||||
type ScopeRateLimitInfo struct {
|
||||
ResetAt time.Time `json:"reset_at"`
|
||||
RemainingSec int64 `json:"remaining_sec"`
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID int64 `json:"id"`
|
||||
Email string `json:"email"`
|
||||
@@ -98,6 +93,9 @@ type AdminGroup struct {
|
||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
AccountCount int64 `json:"account_count,omitempty"`
|
||||
|
||||
// 分组排序
|
||||
SortOrder int `json:"sort_order"`
|
||||
}
|
||||
|
||||
type Account struct {
|
||||
@@ -126,9 +124,6 @@ type Account struct {
|
||||
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
|
||||
OverloadUntil *time.Time `json:"overload_until"`
|
||||
|
||||
// Antigravity scope 级限流状态(从 extra 提取)
|
||||
ScopeRateLimits map[string]ScopeRateLimitInfo `json:"scope_rate_limits,omitempty"`
|
||||
|
||||
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until"`
|
||||
TempUnschedulableReason string `json:"temp_unschedulable_reason"`
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
@@ -114,7 +115,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
setOpsRequestContext(c, "", false, body)
|
||||
|
||||
parsedReq, err := service.ParseGatewayRequest(body)
|
||||
parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
@@ -203,6 +204,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 计算粘性会话hash
|
||||
parsedReq.SessionContext = &service.SessionContext{
|
||||
ClientIP: ip.GetClientIP(c),
|
||||
UserAgent: c.GetHeader("User-Agent"),
|
||||
APIKeyID: apiKey.ID,
|
||||
}
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
|
||||
// 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则使用分组平台
|
||||
@@ -229,9 +235,17 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
sameAccountRetryCount := make(map[int64]int) // 同账号重试计数
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
|
||||
|
||||
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
|
||||
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
|
||||
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) {
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
|
||||
if err != nil {
|
||||
@@ -239,6 +253,19 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
|
||||
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
|
||||
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
|
||||
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
|
||||
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
|
||||
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches)
|
||||
failedAccountIDs = make(map[int64]struct{})
|
||||
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
continue
|
||||
}
|
||||
}
|
||||
if lastFailoverErr != nil {
|
||||
h.handleFailoverExhausted(c, lastFailoverErr, service.PlatformGemini, streamStarted)
|
||||
} else {
|
||||
@@ -333,17 +360,39 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if failoverErr.ForceCacheBilling {
|
||||
if needForceCacheBilling(hasBoundSession, failoverErr) {
|
||||
forceCacheBilling = true
|
||||
}
|
||||
|
||||
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
|
||||
if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
|
||||
account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries)
|
||||
if !sleepSameAccountRetryDelay(c.Request.Context()) {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 同账号重试用尽,执行临时封禁并切换账号
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr)
|
||||
}
|
||||
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
|
||||
return
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
// 错误响应已在Forward中处理,这里只记录日志
|
||||
@@ -385,10 +434,18 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
fallbackUsed := false
|
||||
|
||||
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
|
||||
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
|
||||
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), currentAPIKey.GroupID) {
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
|
||||
for {
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
sameAccountRetryCount := make(map[int64]int) // 同账号重试计数
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
retryWithFallback := false
|
||||
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
|
||||
@@ -401,6 +458,19 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
|
||||
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
|
||||
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
|
||||
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
|
||||
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
|
||||
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches)
|
||||
failedAccountIDs = make(map[int64]struct{})
|
||||
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
continue
|
||||
}
|
||||
}
|
||||
if lastFailoverErr != nil {
|
||||
h.handleFailoverExhausted(c, lastFailoverErr, platform, streamStarted)
|
||||
} else {
|
||||
@@ -482,7 +552,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if switchCount > 0 {
|
||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
|
||||
}
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
||||
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
|
||||
} else {
|
||||
result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq)
|
||||
@@ -528,17 +598,39 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if failoverErr.ForceCacheBilling {
|
||||
if needForceCacheBilling(hasBoundSession, failoverErr) {
|
||||
forceCacheBilling = true
|
||||
}
|
||||
|
||||
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
|
||||
if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
|
||||
account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries)
|
||||
if !sleepSameAccountRetryDelay(c.Request.Context()) {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 同账号重试用尽,执行临时封禁并切换账号
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr)
|
||||
}
|
||||
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
|
||||
return
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
// 错误响应已在Forward中处理,这里只记录日志
|
||||
@@ -801,6 +893,65 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
|
||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||
}
|
||||
|
||||
// needForceCacheBilling 判断 failover 时是否需要强制缓存计费
|
||||
// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费
|
||||
func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool {
|
||||
return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling)
|
||||
}
|
||||
|
||||
const (
|
||||
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
|
||||
maxSameAccountRetries = 2
|
||||
// sameAccountRetryDelay 同账号重试间隔
|
||||
sameAccountRetryDelay = 500 * time.Millisecond
|
||||
)
|
||||
|
||||
// sleepSameAccountRetryDelay 同账号重试固定延时,返回 false 表示 context 已取消。
|
||||
func sleepSameAccountRetryDelay(ctx context.Context) bool {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// sleepFailoverDelay 账号切换线性递增延时:第1次0s、第2次1s、第3次2s…
|
||||
// 返回 false 表示 context 已取消。
|
||||
func sleepFailoverDelay(ctx context.Context, switchCount int) bool {
|
||||
delay := time.Duration(switchCount-1) * time.Second
|
||||
if delay <= 0 {
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case <-time.After(delay):
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// sleepAntigravitySingleAccountBackoff Antigravity 平台单账号分组的 503 退避重试延时。
|
||||
// 当分组内只有一个可用账号且上游返回 503(MODEL_CAPACITY_EXHAUSTED)时使用,
|
||||
// 采用短固定延时策略。Service 层在 SingleAccountRetry 模式下已经做了充分的原地重试
|
||||
// (最多 3 次、总等待 30s),所以 Handler 层的退避只需短暂等待即可。
|
||||
// 返回 false 表示 context 已取消。
|
||||
func sleepAntigravitySingleAccountBackoff(ctx context.Context, retryCount int) bool {
|
||||
// 固定短延时:2s
|
||||
// Service 层已经在原地等待了足够长的时间(retryDelay × 重试次数),
|
||||
// Handler 层只需短暂间隔后重新进入 Service 层即可。
|
||||
const delay = 2 * time.Second
|
||||
|
||||
log.Printf("Antigravity single-account 503 backoff: waiting %v before retry (attempt %d)", delay, retryCount)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case <-time.After(delay):
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
|
||||
statusCode := failoverErr.StatusCode
|
||||
responseBody := failoverErr.ResponseBody
|
||||
@@ -820,6 +971,10 @@ func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *se
|
||||
msg = *rule.CustomMessage
|
||||
}
|
||||
|
||||
if rule.SkipMonitoring {
|
||||
c.Set(service.OpsSkipPassthroughKey, true)
|
||||
}
|
||||
|
||||
h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted)
|
||||
return
|
||||
}
|
||||
@@ -934,7 +1089,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
|
||||
setOpsRequestContext(c, "", false, body)
|
||||
|
||||
parsedReq, err := service.ParseGatewayRequest(body)
|
||||
parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
@@ -962,6 +1117,11 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 计算粘性会话 hash
|
||||
parsedReq.SessionContext = &service.SessionContext{
|
||||
ClientIP: ip.GetClientIP(c),
|
||||
UserAgent: c.GetHeader("User-Agent"),
|
||||
APIKeyID: apiKey.ID,
|
||||
}
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
|
||||
// 选择支持该模型的账号
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// sleepAntigravitySingleAccountBackoff 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestSleepAntigravitySingleAccountBackoff_ReturnsTrue(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
start := time.Now()
|
||||
ok := sleepAntigravitySingleAccountBackoff(ctx, 1)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.True(t, ok, "should return true when context is not canceled")
|
||||
// 固定延迟 2s
|
||||
require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond, "should wait approximately 2s")
|
||||
require.Less(t, elapsed, 5*time.Second, "should not wait too long")
|
||||
}
|
||||
|
||||
func TestSleepAntigravitySingleAccountBackoff_ContextCanceled(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // 立即取消
|
||||
|
||||
start := time.Now()
|
||||
ok := sleepAntigravitySingleAccountBackoff(ctx, 1)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.False(t, ok, "should return false when context is canceled")
|
||||
require.Less(t, elapsed, 500*time.Millisecond, "should return immediately on cancel")
|
||||
}
|
||||
|
||||
func TestSleepAntigravitySingleAccountBackoff_FixedDelay(t *testing.T) {
|
||||
// 验证不同 retryCount 都使用固定 2s 延迟
|
||||
ctx := context.Background()
|
||||
|
||||
start := time.Now()
|
||||
ok := sleepAntigravitySingleAccountBackoff(ctx, 5)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.True(t, ok)
|
||||
// 即使 retryCount=5,延迟仍然是固定的 2s
|
||||
require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond)
|
||||
require.Less(t, elapsed, 5*time.Second)
|
||||
}
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
|
||||
@@ -30,13 +31,6 @@ import (
|
||||
// 匹配格式: /Users/xxx/.gemini/tmp/[64位十六进制哈希]
|
||||
var geminiCLITmpDirRegex = regexp.MustCompile(`/\.gemini/tmp/([A-Fa-f0-9]{64})`)
|
||||
|
||||
func isGeminiCLIRequest(c *gin.Context, body []byte) bool {
|
||||
if strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id")) != "" {
|
||||
return true
|
||||
}
|
||||
return geminiCLITmpDirRegex.Match(body)
|
||||
}
|
||||
|
||||
// GeminiV1BetaListModels proxies:
|
||||
// GET /v1beta/models
|
||||
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
|
||||
@@ -239,7 +233,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
sessionHash := extractGeminiCLISessionHash(c, body)
|
||||
if sessionHash == "" {
|
||||
// Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端)
|
||||
parsedReq, _ := service.ParseGatewayRequest(body)
|
||||
parsedReq, _ := service.ParseGatewayRequest(body, domain.PlatformGemini)
|
||||
if parsedReq != nil {
|
||||
parsedReq.SessionContext = &service.SessionContext{
|
||||
ClientIP: ip.GetClientIP(c),
|
||||
UserAgent: c.GetHeader("User-Agent"),
|
||||
APIKeyID: apiKey.ID,
|
||||
}
|
||||
}
|
||||
sessionHash = h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
}
|
||||
sessionKey := sessionHash
|
||||
@@ -258,6 +259,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
var geminiDigestChain string
|
||||
var geminiPrefixHash string
|
||||
var geminiSessionUUID string
|
||||
var matchedDigestChain string
|
||||
useDigestFallback := sessionBoundAccountID == 0
|
||||
|
||||
if useDigestFallback {
|
||||
@@ -284,13 +286,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
)
|
||||
|
||||
// 查找会话
|
||||
foundUUID, foundAccountID, found := h.gatewayService.FindGeminiSession(
|
||||
foundUUID, foundAccountID, foundMatchedChain, found := h.gatewayService.FindGeminiSession(
|
||||
c.Request.Context(),
|
||||
derefGroupID(apiKey.GroupID),
|
||||
geminiPrefixHash,
|
||||
geminiDigestChain,
|
||||
)
|
||||
if found {
|
||||
matchedDigestChain = foundMatchedChain
|
||||
sessionBoundAccountID = foundAccountID
|
||||
geminiSessionUUID = foundUUID
|
||||
log.Printf("[Gemini] Digest fallback matched: uuid=%s, accountID=%d, chain=%s",
|
||||
@@ -316,7 +319,6 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
|
||||
// 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号
|
||||
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
|
||||
isCLI := isGeminiCLIRequest(c, body)
|
||||
cleanedForUnknownBinding := false
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
||||
@@ -325,6 +327,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
|
||||
|
||||
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
|
||||
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
|
||||
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) {
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制
|
||||
if err != nil {
|
||||
@@ -332,6 +341,19 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
|
||||
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
|
||||
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
|
||||
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
|
||||
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
|
||||
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches)
|
||||
failedAccountIDs = make(map[int64]struct{})
|
||||
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
|
||||
return
|
||||
}
|
||||
@@ -344,10 +366,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID)
|
||||
body = service.CleanGeminiNativeThoughtSignatures(body)
|
||||
sessionBoundAccountID = account.ID
|
||||
} else if sessionKey != "" && sessionBoundAccountID == 0 && isCLI && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
|
||||
// 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,CLI 继续携带旧签名。
|
||||
} else if sessionKey != "" && sessionBoundAccountID == 0 && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
|
||||
// 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,客户端继续携带旧签名。
|
||||
// 为避免第一次转发就 400,这里做一次确定性清理,让新账号重新生成签名链路。
|
||||
log.Printf("[Gemini] Sticky session binding missing for CLI request, cleaning thoughtSignature proactively")
|
||||
log.Printf("[Gemini] Sticky session binding missing, cleaning thoughtSignature proactively")
|
||||
body = service.CleanGeminiNativeThoughtSignatures(body)
|
||||
cleanedForUnknownBinding = true
|
||||
sessionBoundAccountID = account.ID
|
||||
@@ -410,7 +432,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
if switchCount > 0 {
|
||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
|
||||
}
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
||||
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession)
|
||||
} else {
|
||||
result, err = h.geminiCompatService.ForwardNative(requestCtx, c, account, modelName, action, stream, body)
|
||||
@@ -422,7 +444,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if failoverErr.ForceCacheBilling {
|
||||
if needForceCacheBilling(hasBoundSession, failoverErr) {
|
||||
forceCacheBilling = true
|
||||
}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
@@ -433,6 +455,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
lastFailoverErr = failoverErr
|
||||
switchCount++
|
||||
log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
|
||||
return
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
// ForwardNative already wrote the response
|
||||
@@ -453,6 +480,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
geminiDigestChain,
|
||||
geminiSessionUUID,
|
||||
account.ID,
|
||||
matchedDigestChain,
|
||||
); err != nil {
|
||||
log.Printf("[Gemini] Failed to save digest session: %v", err)
|
||||
}
|
||||
@@ -526,6 +554,10 @@ func (h *GatewayHandler) handleGeminiFailoverExhausted(c *gin.Context, failoverE
|
||||
msg = *rule.CustomMessage
|
||||
}
|
||||
|
||||
if rule.SkipMonitoring {
|
||||
c.Set(service.OpsSkipPassthroughKey, true)
|
||||
}
|
||||
|
||||
googleError(c, respCode, msg)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -354,6 +354,10 @@ func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverE
|
||||
msg = *rule.CustomMessage
|
||||
}
|
||||
|
||||
if rule.SkipMonitoring {
|
||||
c.Set(service.OpsSkipPassthroughKey, true)
|
||||
}
|
||||
|
||||
h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -537,6 +537,13 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
// Store request headers/body only when an upstream error occurred to keep overhead minimal.
|
||||
entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c)
|
||||
|
||||
// Skip logging if a passthrough rule with skip_monitoring=true matched.
|
||||
if v, ok := c.Get(service.OpsSkipPassthroughKey); ok {
|
||||
if skip, _ := v.(bool); skip {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
enqueueOpsErrorLog(ops, entry, requestBody)
|
||||
return
|
||||
}
|
||||
@@ -544,6 +551,13 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
body := w.buf.Bytes()
|
||||
parsed := parseOpsErrorResponse(body)
|
||||
|
||||
// Skip logging if a passthrough rule with skip_monitoring=true matched.
|
||||
if v, ok := c.Get(service.OpsSkipPassthroughKey); ok {
|
||||
if skip, _ := v.(bool); skip {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Skip logging if the error should be filtered based on settings
|
||||
if shouldSkipOpsErrorLog(c.Request.Context(), ops, parsed.Message, string(body), c.Request.URL.Path) {
|
||||
return
|
||||
|
||||
@@ -18,6 +18,7 @@ type ErrorPassthroughRule struct {
|
||||
ResponseCode *int `json:"response_code"` // 自定义状态码(passthrough_code=false 时使用)
|
||||
PassthroughBody bool `json:"passthrough_body"` // 是否透传原始错误信息
|
||||
CustomMessage *string `json:"custom_message"` // 自定义错误信息(passthrough_body=false 时使用)
|
||||
SkipMonitoring bool `json:"skip_monitoring"` // 是否跳过运维监控记录
|
||||
Description *string `json:"description"` // 规则描述
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
@@ -27,7 +27,7 @@ type ClaudeMessage struct {
|
||||
|
||||
// ThinkingConfig Thinking 配置
|
||||
type ThinkingConfig struct {
|
||||
Type string `json:"type"` // "enabled" or "disabled"
|
||||
Type string `json:"type"` // "enabled" / "adaptive" / "disabled"
|
||||
BudgetTokens int `json:"budget_tokens,omitempty"` // thinking budget
|
||||
}
|
||||
|
||||
|
||||
@@ -115,6 +115,23 @@ type LoadCodeAssistResponse struct {
|
||||
IneligibleTiers []*IneligibleTier `json:"ineligibleTiers,omitempty"`
|
||||
}
|
||||
|
||||
// OnboardUserRequest onboardUser 请求
|
||||
type OnboardUserRequest struct {
|
||||
TierID string `json:"tierId"`
|
||||
Metadata struct {
|
||||
IDEType string `json:"ideType"`
|
||||
Platform string `json:"platform,omitempty"`
|
||||
PluginType string `json:"pluginType,omitempty"`
|
||||
} `json:"metadata"`
|
||||
}
|
||||
|
||||
// OnboardUserResponse onboardUser 响应
|
||||
type OnboardUserResponse struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Done bool `json:"done"`
|
||||
Response map[string]any `json:"response,omitempty"`
|
||||
}
|
||||
|
||||
// GetTier 获取账户类型
|
||||
// 优先返回 paidTier(付费订阅级别),否则返回 currentTier
|
||||
func (r *LoadCodeAssistResponse) GetTier() string {
|
||||
@@ -361,6 +378,117 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
|
||||
return nil, nil, lastErr
|
||||
}
|
||||
|
||||
// OnboardUser 触发账号 onboarding,并返回 project_id
|
||||
// 说明:
|
||||
// 1) 部分账号 loadCodeAssist 不会立即返回 cloudaicompanionProject;
|
||||
// 2) 这时需要调用 onboardUser 完成初始化,之后才能拿到 project_id。
|
||||
func (c *Client) OnboardUser(ctx context.Context, accessToken, tierID string) (string, error) {
|
||||
tierID = strings.TrimSpace(tierID)
|
||||
if tierID == "" {
|
||||
return "", fmt.Errorf("tier_id 为空")
|
||||
}
|
||||
|
||||
reqBody := OnboardUserRequest{TierID: tierID}
|
||||
reqBody.Metadata.IDEType = "ANTIGRAVITY"
|
||||
reqBody.Metadata.Platform = "PLATFORM_UNSPECIFIED"
|
||||
reqBody.Metadata.PluginType = "GEMINI"
|
||||
|
||||
bodyBytes, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
availableURLs := BaseURLs
|
||||
var lastErr error
|
||||
|
||||
for urlIdx, baseURL := range availableURLs {
|
||||
apiURL := baseURL + "/v1internal:onboardUser"
|
||||
|
||||
for attempt := 1; attempt <= 5; attempt++ {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("创建请求失败: %w", err)
|
||||
break
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("onboardUser 请求失败: %w", err)
|
||||
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||||
log.Printf("[antigravity] onboardUser URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
|
||||
break
|
||||
}
|
||||
return "", lastErr
|
||||
}
|
||||
|
||||
respBodyBytes, err := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
|
||||
log.Printf("[antigravity] onboardUser URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
|
||||
break
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
lastErr = fmt.Errorf("onboardUser 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
|
||||
return "", lastErr
|
||||
}
|
||||
|
||||
var onboardResp OnboardUserResponse
|
||||
if err := json.Unmarshal(respBodyBytes, &onboardResp); err != nil {
|
||||
lastErr = fmt.Errorf("onboardUser 响应解析失败: %w", err)
|
||||
return "", lastErr
|
||||
}
|
||||
|
||||
if onboardResp.Done {
|
||||
if projectID := extractProjectIDFromOnboardResponse(onboardResp.Response); projectID != "" {
|
||||
DefaultURLAvailability.MarkSuccess(baseURL)
|
||||
return projectID, nil
|
||||
}
|
||||
lastErr = fmt.Errorf("onboardUser 完成但未返回 project_id")
|
||||
return "", lastErr
|
||||
}
|
||||
|
||||
// done=false 时等待后重试(与 CLIProxyAPI 行为一致)
|
||||
select {
|
||||
case <-time.After(2 * time.Second):
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if lastErr != nil {
|
||||
return "", lastErr
|
||||
}
|
||||
return "", fmt.Errorf("onboardUser 未返回 project_id")
|
||||
}
|
||||
|
||||
func extractProjectIDFromOnboardResponse(resp map[string]any) string {
|
||||
if len(resp) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if v, ok := resp["cloudaicompanionProject"]; ok {
|
||||
switch project := v.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(project)
|
||||
case map[string]any:
|
||||
if id, ok := project["id"].(string); ok {
|
||||
return strings.TrimSpace(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// ModelQuotaInfo 模型配额信息
|
||||
type ModelQuotaInfo struct {
|
||||
RemainingFraction float64 `json:"remainingFraction"`
|
||||
|
||||
76
backend/internal/pkg/antigravity/client_test.go
Normal file
76
backend/internal/pkg/antigravity/client_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExtractProjectIDFromOnboardResponse(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
resp map[string]any
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "nil response",
|
||||
resp: nil,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "empty response",
|
||||
resp: map[string]any{},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "project as string",
|
||||
resp: map[string]any{
|
||||
"cloudaicompanionProject": "my-project-123",
|
||||
},
|
||||
want: "my-project-123",
|
||||
},
|
||||
{
|
||||
name: "project as string with spaces",
|
||||
resp: map[string]any{
|
||||
"cloudaicompanionProject": " my-project-123 ",
|
||||
},
|
||||
want: "my-project-123",
|
||||
},
|
||||
{
|
||||
name: "project as map with id",
|
||||
resp: map[string]any{
|
||||
"cloudaicompanionProject": map[string]any{
|
||||
"id": "proj-from-map",
|
||||
},
|
||||
},
|
||||
want: "proj-from-map",
|
||||
},
|
||||
{
|
||||
name: "project as map without id",
|
||||
resp: map[string]any{
|
||||
"cloudaicompanionProject": map[string]any{
|
||||
"name": "some-name",
|
||||
},
|
||||
},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "missing cloudaicompanionProject key",
|
||||
resp: map[string]any{
|
||||
"otherField": "value",
|
||||
},
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := extractProjectIDFromOnboardResponse(tc.resp)
|
||||
if got != tc.want {
|
||||
t.Fatalf("extractProjectIDFromOnboardResponse() = %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -155,6 +155,7 @@ type GeminiUsageMetadata struct {
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
|
||||
CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"`
|
||||
TotalTokenCount int `json:"totalTokenCount,omitempty"`
|
||||
ThoughtsTokenCount int `json:"thoughtsTokenCount,omitempty"` // thinking tokens(按输出价格计费)
|
||||
}
|
||||
|
||||
// GeminiGroundingMetadata Gemini grounding 元数据(Web Search)
|
||||
|
||||
@@ -64,6 +64,10 @@ const MaxTokensBudgetPadding = 1000
|
||||
// Gemini 2.5 Flash thinking budget 上限
|
||||
const Gemini25FlashThinkingBudgetLimit = 24576
|
||||
|
||||
// 对于 Antigravity 的 Claude(budget-only)模型,该语义最终等价为 thinkingBudget=24576。
|
||||
// 这里复用相同数值以保持行为一致。
|
||||
const ClaudeAdaptiveHighThinkingBudgetTokens = Gemini25FlashThinkingBudgetLimit
|
||||
|
||||
// ensureMaxTokensGreaterThanBudget 确保 max_tokens > budget_tokens
|
||||
// Claude API 要求启用 thinking 时,max_tokens 必须大于 thinking.budget_tokens
|
||||
// 返回调整后的 maxTokens 和是否进行了调整
|
||||
@@ -96,7 +100,7 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
|
||||
}
|
||||
|
||||
// 检测是否启用 thinking
|
||||
isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
|
||||
isThinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive")
|
||||
|
||||
// 只有 Gemini 模型支持 dummy thought workaround
|
||||
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
|
||||
@@ -198,8 +202,7 @@ type modelInfo struct {
|
||||
|
||||
// modelInfoMap 模型前缀 → 模型信息映射
|
||||
// 只有在此映射表中的模型才会注入身份提示词
|
||||
// 注意:当前 claude-opus-4-6 会被映射到 claude-opus-4-5-thinking,
|
||||
// 但保留此条目以便后续 Antigravity 上游支持 4.6 时快速切换
|
||||
// 注意:模型映射逻辑在网关层完成;这里仅用于按模型前缀判断是否注入身份提示词。
|
||||
var modelInfoMap = map[string]modelInfo{
|
||||
"claude-opus-4-5": {DisplayName: "Claude Opus 4.5", CanonicalID: "claude-opus-4-5-20250929"},
|
||||
"claude-opus-4-6": {DisplayName: "Claude Opus 4.6", CanonicalID: "claude-opus-4-6"},
|
||||
@@ -271,6 +274,21 @@ func filterOpenCodePrompt(text string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// systemBlockFilterPrefixes 需要从 system 中过滤的文本前缀列表
|
||||
var systemBlockFilterPrefixes = []string{
|
||||
"x-anthropic-billing-header",
|
||||
}
|
||||
|
||||
// filterSystemBlockByPrefix 如果文本匹配过滤前缀,返回空字符串
|
||||
func filterSystemBlockByPrefix(text string) string {
|
||||
for _, prefix := range systemBlockFilterPrefixes {
|
||||
if strings.HasPrefix(text, prefix) {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
// buildSystemInstruction 构建 systemInstruction(与 Antigravity-Manager 保持一致)
|
||||
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions, tools []ClaudeTool) *GeminiContent {
|
||||
var parts []GeminiPart
|
||||
@@ -287,8 +305,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
|
||||
if strings.Contains(sysStr, "You are Antigravity") {
|
||||
userHasAntigravityIdentity = true
|
||||
}
|
||||
// 过滤 OpenCode 默认提示词
|
||||
filtered := filterOpenCodePrompt(sysStr)
|
||||
// 过滤 OpenCode 默认提示词和黑名单前缀
|
||||
filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(sysStr))
|
||||
if filtered != "" {
|
||||
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
|
||||
}
|
||||
@@ -302,8 +320,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
|
||||
if strings.Contains(block.Text, "You are Antigravity") {
|
||||
userHasAntigravityIdentity = true
|
||||
}
|
||||
// 过滤 OpenCode 默认提示词
|
||||
filtered := filterOpenCodePrompt(block.Text)
|
||||
// 过滤 OpenCode 默认提示词和黑名单前缀
|
||||
filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(block.Text))
|
||||
if filtered != "" {
|
||||
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
|
||||
}
|
||||
@@ -578,6 +596,10 @@ func maxOutputTokensLimit(model string) int {
|
||||
return maxOutputTokensUpperBound
|
||||
}
|
||||
|
||||
func isAntigravityOpus46Model(model string) bool {
|
||||
return strings.HasPrefix(strings.ToLower(model), "claude-opus-4-6")
|
||||
}
|
||||
|
||||
func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
|
||||
maxLimit := maxOutputTokensLimit(req.Model)
|
||||
config := &GeminiGenerationConfig{
|
||||
@@ -591,25 +613,36 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
|
||||
}
|
||||
|
||||
// Thinking 配置
|
||||
if req.Thinking != nil && req.Thinking.Type == "enabled" {
|
||||
if req.Thinking != nil && (req.Thinking.Type == "enabled" || req.Thinking.Type == "adaptive") {
|
||||
config.ThinkingConfig = &GeminiThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
|
||||
// - thinking.type=enabled:budget_tokens>0 用显式预算
|
||||
// - thinking.type=adaptive:仅在 Antigravity 的 Opus 4.6 上覆写为 (24576)
|
||||
budget := -1
|
||||
if req.Thinking.BudgetTokens > 0 {
|
||||
budget := req.Thinking.BudgetTokens
|
||||
budget = req.Thinking.BudgetTokens
|
||||
}
|
||||
if req.Thinking.Type == "adaptive" && isAntigravityOpus46Model(req.Model) {
|
||||
budget = ClaudeAdaptiveHighThinkingBudgetTokens
|
||||
}
|
||||
|
||||
// 正预算需要做上限与 max_tokens 约束;动态预算(-1)直接透传给上游。
|
||||
if budget > 0 {
|
||||
// gemini-2.5-flash 上限
|
||||
if strings.Contains(req.Model, "gemini-2.5-flash") && budget > Gemini25FlashThinkingBudgetLimit {
|
||||
budget = Gemini25FlashThinkingBudgetLimit
|
||||
}
|
||||
config.ThinkingConfig.ThinkingBudget = budget
|
||||
|
||||
// 自动修正:max_tokens 必须大于 budget_tokens
|
||||
// 自动修正:max_tokens 必须大于 budget_tokens(Claude 上游要求)
|
||||
if adjusted, ok := ensureMaxTokensGreaterThanBudget(config.MaxOutputTokens, budget); ok {
|
||||
log.Printf("[Antigravity] Auto-adjusted max_tokens from %d to %d (must be > budget_tokens=%d)",
|
||||
config.MaxOutputTokens, adjusted, budget)
|
||||
config.MaxOutputTokens = adjusted
|
||||
}
|
||||
}
|
||||
config.ThinkingConfig.ThinkingBudget = budget
|
||||
}
|
||||
|
||||
if config.MaxOutputTokens > maxLimit {
|
||||
|
||||
@@ -259,3 +259,93 @@ func TestBuildTools_CustomTypeTools(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildGenerationConfig_ThinkingDynamicBudget(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
thinking *ThinkingConfig
|
||||
wantBudget int
|
||||
wantPresent bool
|
||||
}{
|
||||
{
|
||||
name: "enabled without budget defaults to dynamic (-1)",
|
||||
model: "claude-opus-4-6-thinking",
|
||||
thinking: &ThinkingConfig{Type: "enabled"},
|
||||
wantBudget: -1,
|
||||
wantPresent: true,
|
||||
},
|
||||
{
|
||||
name: "enabled with budget uses the provided value",
|
||||
model: "claude-opus-4-6-thinking",
|
||||
thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1024},
|
||||
wantBudget: 1024,
|
||||
wantPresent: true,
|
||||
},
|
||||
{
|
||||
name: "enabled with -1 budget uses dynamic (-1)",
|
||||
model: "claude-opus-4-6-thinking",
|
||||
thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: -1},
|
||||
wantBudget: -1,
|
||||
wantPresent: true,
|
||||
},
|
||||
{
|
||||
name: "adaptive on opus4.6 maps to high budget (24576)",
|
||||
model: "claude-opus-4-6-thinking",
|
||||
thinking: &ThinkingConfig{Type: "adaptive", BudgetTokens: 20000},
|
||||
wantBudget: ClaudeAdaptiveHighThinkingBudgetTokens,
|
||||
wantPresent: true,
|
||||
},
|
||||
{
|
||||
name: "adaptive on non-opus model keeps default dynamic (-1)",
|
||||
model: "claude-sonnet-4-5-thinking",
|
||||
thinking: &ThinkingConfig{Type: "adaptive"},
|
||||
wantBudget: -1,
|
||||
wantPresent: true,
|
||||
},
|
||||
{
|
||||
name: "disabled does not emit thinkingConfig",
|
||||
model: "claude-opus-4-6-thinking",
|
||||
thinking: &ThinkingConfig{Type: "disabled", BudgetTokens: 1024},
|
||||
wantBudget: 0,
|
||||
wantPresent: false,
|
||||
},
|
||||
{
|
||||
name: "nil thinking does not emit thinkingConfig",
|
||||
model: "claude-opus-4-6-thinking",
|
||||
thinking: nil,
|
||||
wantBudget: 0,
|
||||
wantPresent: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := &ClaudeRequest{
|
||||
Model: tt.model,
|
||||
Thinking: tt.thinking,
|
||||
}
|
||||
cfg := buildGenerationConfig(req)
|
||||
if cfg == nil {
|
||||
t.Fatalf("expected non-nil generationConfig")
|
||||
}
|
||||
|
||||
if tt.wantPresent {
|
||||
if cfg.ThinkingConfig == nil {
|
||||
t.Fatalf("expected thinkingConfig to be present")
|
||||
}
|
||||
if !cfg.ThinkingConfig.IncludeThoughts {
|
||||
t.Fatalf("expected includeThoughts=true")
|
||||
}
|
||||
if cfg.ThinkingConfig.ThinkingBudget != tt.wantBudget {
|
||||
t.Fatalf("expected thinkingBudget=%d, got %d", tt.wantBudget, cfg.ThinkingConfig.ThinkingBudget)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if cfg.ThinkingConfig != nil {
|
||||
t.Fatalf("expected thinkingConfig to be nil, got %+v", cfg.ThinkingConfig)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -279,7 +279,7 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
|
||||
if geminiResp.UsageMetadata != nil {
|
||||
cached := geminiResp.UsageMetadata.CachedContentTokenCount
|
||||
usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
|
||||
usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount
|
||||
usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount
|
||||
usage.CacheReadInputTokens = cached
|
||||
}
|
||||
|
||||
|
||||
@@ -85,7 +85,7 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
||||
if geminiResp.UsageMetadata != nil {
|
||||
cached := geminiResp.UsageMetadata.CachedContentTokenCount
|
||||
p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
|
||||
p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount
|
||||
p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount
|
||||
p.cacheReadTokens = cached
|
||||
}
|
||||
|
||||
@@ -146,7 +146,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
|
||||
if v1Resp.Response.UsageMetadata != nil {
|
||||
cached := v1Resp.Response.UsageMetadata.CachedContentTokenCount
|
||||
usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached
|
||||
usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount
|
||||
usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount + v1Resp.Response.UsageMetadata.ThoughtsTokenCount
|
||||
usage.CacheReadInputTokens = cached
|
||||
}
|
||||
|
||||
|
||||
@@ -28,4 +28,8 @@ const (
|
||||
// IsMaxTokensOneHaikuRequest 标识当前请求是否为 max_tokens=1 + haiku 模型的探测请求
|
||||
// 用于 ClaudeCodeOnly 验证绕过(绕过 system prompt 检查,但仍需验证 User-Agent)
|
||||
IsMaxTokensOneHaikuRequest Key = "ctx_is_max_tokens_one_haiku"
|
||||
|
||||
// SingleAccountRetry 标识当前请求处于单账号 503 退避重试模式。
|
||||
// 在此模式下,Service 层的模型限流预检查将等待限流过期而非直接切换账号。
|
||||
SingleAccountRetry Key = "ctx_single_account_retry"
|
||||
)
|
||||
|
||||
@@ -282,6 +282,34 @@ func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID
|
||||
return &accounts[0], nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
|
||||
rows, err := r.sql.QueryContext(ctx, `
|
||||
SELECT id, extra->>'crs_account_id'
|
||||
FROM accounts
|
||||
WHERE deleted_at IS NULL
|
||||
AND extra->>'crs_account_id' IS NOT NULL
|
||||
AND extra->>'crs_account_id' != ''
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
result := make(map[string]int64)
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
var crsID string
|
||||
if err := rows.Scan(&id, &crsID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[crsID] = id
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) Update(ctx context.Context, account *service.Account) error {
|
||||
if account == nil {
|
||||
return nil
|
||||
@@ -407,10 +435,10 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
|
||||
}
|
||||
|
||||
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", "", "")
|
||||
return r.ListWithFilters(ctx, params, "", "", "", "", 0)
|
||||
}
|
||||
|
||||
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
q := r.client.Account.Query()
|
||||
|
||||
if platform != "" {
|
||||
@@ -420,11 +448,19 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
|
||||
q = q.Where(dbaccount.TypeEQ(accountType))
|
||||
}
|
||||
if status != "" {
|
||||
q = q.Where(dbaccount.StatusEQ(status))
|
||||
switch status {
|
||||
case "rate_limited":
|
||||
q = q.Where(dbaccount.RateLimitResetAtGT(time.Now()))
|
||||
default:
|
||||
q = q.Where(dbaccount.StatusEQ(status))
|
||||
}
|
||||
}
|
||||
if search != "" {
|
||||
q = q.Where(dbaccount.NameContainsFold(search))
|
||||
}
|
||||
if groupID > 0 {
|
||||
q = q.Where(dbaccount.HasAccountGroupsWith(dbaccountgroup.GroupIDEQ(groupID)))
|
||||
}
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
@@ -798,53 +834,6 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
|
||||
now := time.Now().UTC()
|
||||
payload := map[string]string{
|
||||
"rate_limited_at": now.Format(time.RFC3339),
|
||||
"rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339),
|
||||
}
|
||||
raw, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
scopeKey := string(scope)
|
||||
client := clientFromContext(ctx, r.client)
|
||||
result, err := client.ExecContext(
|
||||
ctx,
|
||||
`UPDATE accounts SET
|
||||
extra = jsonb_set(
|
||||
jsonb_set(COALESCE(extra, '{}'::jsonb), '{antigravity_quota_scopes}'::text[], COALESCE(extra->'antigravity_quota_scopes', '{}'::jsonb), true),
|
||||
ARRAY['antigravity_quota_scopes', $1]::text[],
|
||||
$2::jsonb,
|
||||
true
|
||||
),
|
||||
updated_at = NOW(),
|
||||
last_used_at = NOW()
|
||||
WHERE id = $3 AND deleted_at IS NULL`,
|
||||
scopeKey,
|
||||
raw,
|
||||
id,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue quota scope failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||
if scope == "" {
|
||||
return nil
|
||||
|
||||
@@ -238,7 +238,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
||||
|
||||
tt.setup(client)
|
||||
|
||||
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search)
|
||||
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, 0)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(accounts, tt.wantCount)
|
||||
if tt.validate != nil {
|
||||
@@ -305,7 +305,7 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
|
||||
s.Require().Len(got.Groups, 1, "expected Groups to be populated")
|
||||
s.Require().Equal(group.ID, got.Groups[0].ID)
|
||||
|
||||
accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc")
|
||||
accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc", 0)
|
||||
s.Require().NoError(err, "ListWithFilters")
|
||||
s.Require().Equal(int64(1), page.Total)
|
||||
s.Require().Len(accounts, 1)
|
||||
|
||||
@@ -485,6 +485,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
MCPXMLInject: g.McpXMLInject,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
SortOrder: g.SortOrder,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
}
|
||||
|
||||
@@ -54,7 +54,8 @@ func (r *errorPassthroughRepository) Create(ctx context.Context, rule *model.Err
|
||||
SetPriority(rule.Priority).
|
||||
SetMatchMode(rule.MatchMode).
|
||||
SetPassthroughCode(rule.PassthroughCode).
|
||||
SetPassthroughBody(rule.PassthroughBody)
|
||||
SetPassthroughBody(rule.PassthroughBody).
|
||||
SetSkipMonitoring(rule.SkipMonitoring)
|
||||
|
||||
if len(rule.ErrorCodes) > 0 {
|
||||
builder.SetErrorCodes(rule.ErrorCodes)
|
||||
@@ -90,7 +91,8 @@ func (r *errorPassthroughRepository) Update(ctx context.Context, rule *model.Err
|
||||
SetPriority(rule.Priority).
|
||||
SetMatchMode(rule.MatchMode).
|
||||
SetPassthroughCode(rule.PassthroughCode).
|
||||
SetPassthroughBody(rule.PassthroughBody)
|
||||
SetPassthroughBody(rule.PassthroughBody).
|
||||
SetSkipMonitoring(rule.SkipMonitoring)
|
||||
|
||||
// 处理可选字段
|
||||
if len(rule.ErrorCodes) > 0 {
|
||||
@@ -149,6 +151,7 @@ func (r *errorPassthroughRepository) toModel(e *ent.ErrorPassthroughRule) *model
|
||||
Platforms: e.Platforms,
|
||||
PassthroughCode: e.PassthroughCode,
|
||||
PassthroughBody: e.PassthroughBody,
|
||||
SkipMonitoring: e.SkipMonitoring,
|
||||
CreatedAt: e.CreatedAt,
|
||||
UpdatedAt: e.UpdatedAt,
|
||||
}
|
||||
|
||||
@@ -11,63 +11,6 @@ import (
|
||||
|
||||
const stickySessionPrefix = "sticky_session:"
|
||||
|
||||
// Gemini Trie Lua 脚本
|
||||
const (
|
||||
// geminiTrieFindScript 查找最长前缀匹配的 Lua 脚本
|
||||
// KEYS[1] = trie key
|
||||
// ARGV[1] = digestChain (如 "u:a-m:b-u:c-m:d")
|
||||
// ARGV[2] = TTL seconds (用于刷新)
|
||||
// 返回: 最长匹配的 value (uuid:accountID) 或 nil
|
||||
// 查找成功时自动刷新 TTL,防止活跃会话意外过期
|
||||
geminiTrieFindScript = `
|
||||
local chain = ARGV[1]
|
||||
local ttl = tonumber(ARGV[2])
|
||||
local lastMatch = nil
|
||||
local path = ""
|
||||
|
||||
for part in string.gmatch(chain, "[^-]+") do
|
||||
path = path == "" and part or path .. "-" .. part
|
||||
local val = redis.call('HGET', KEYS[1], path)
|
||||
if val and val ~= "" then
|
||||
lastMatch = val
|
||||
end
|
||||
end
|
||||
|
||||
if lastMatch then
|
||||
redis.call('EXPIRE', KEYS[1], ttl)
|
||||
end
|
||||
|
||||
return lastMatch
|
||||
`
|
||||
|
||||
// geminiTrieSaveScript 保存会话到 Trie 的 Lua 脚本
|
||||
// KEYS[1] = trie key
|
||||
// ARGV[1] = digestChain
|
||||
// ARGV[2] = value (uuid:accountID)
|
||||
// ARGV[3] = TTL seconds
|
||||
geminiTrieSaveScript = `
|
||||
local chain = ARGV[1]
|
||||
local value = ARGV[2]
|
||||
local ttl = tonumber(ARGV[3])
|
||||
local path = ""
|
||||
|
||||
for part in string.gmatch(chain, "[^-]+") do
|
||||
path = path == "" and part or path .. "-" .. part
|
||||
end
|
||||
redis.call('HSET', KEYS[1], path, value)
|
||||
redis.call('EXPIRE', KEYS[1], ttl)
|
||||
return "OK"
|
||||
`
|
||||
)
|
||||
|
||||
// 模型负载统计相关常量
|
||||
const (
|
||||
modelLoadKeyPrefix = "ag:model_load:" // 模型调用次数 key 前缀
|
||||
modelLastUsedKeyPrefix = "ag:model_last_used:" // 模型最后调度时间 key 前缀
|
||||
modelLoadTTL = 24 * time.Hour // 调用次数 TTL(24 小时无调用后清零)
|
||||
modelLastUsedTTL = 24 * time.Hour // 最后调度时间 TTL
|
||||
)
|
||||
|
||||
type gatewayCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
@@ -108,133 +51,3 @@ func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64
|
||||
key := buildSessionKey(groupID, sessionHash)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
// ============ Antigravity 模型负载统计方法 ============
|
||||
|
||||
// modelLoadKey 构建模型调用次数 key
|
||||
// 格式: ag:model_load:{accountID}:{model}
|
||||
func modelLoadKey(accountID int64, model string) string {
|
||||
return fmt.Sprintf("%s%d:%s", modelLoadKeyPrefix, accountID, model)
|
||||
}
|
||||
|
||||
// modelLastUsedKey 构建模型最后调度时间 key
|
||||
// 格式: ag:model_last_used:{accountID}:{model}
|
||||
func modelLastUsedKey(accountID int64, model string) string {
|
||||
return fmt.Sprintf("%s%d:%s", modelLastUsedKeyPrefix, accountID, model)
|
||||
}
|
||||
|
||||
// IncrModelCallCount 增加模型调用次数并更新最后调度时间
|
||||
// 返回更新后的调用次数
|
||||
func (c *gatewayCache) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
|
||||
loadKey := modelLoadKey(accountID, model)
|
||||
lastUsedKey := modelLastUsedKey(accountID, model)
|
||||
|
||||
pipe := c.rdb.Pipeline()
|
||||
incrCmd := pipe.Incr(ctx, loadKey)
|
||||
pipe.Expire(ctx, loadKey, modelLoadTTL) // 每次调用刷新 TTL
|
||||
pipe.Set(ctx, lastUsedKey, time.Now().Unix(), modelLastUsedTTL)
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return incrCmd.Val(), nil
|
||||
}
|
||||
|
||||
// GetModelLoadBatch 批量获取账号的模型负载信息
|
||||
func (c *gatewayCache) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*service.ModelLoadInfo, error) {
|
||||
if len(accountIDs) == 0 {
|
||||
return make(map[int64]*service.ModelLoadInfo), nil
|
||||
}
|
||||
|
||||
loadCmds, lastUsedCmds := c.pipelineModelLoadGet(ctx, accountIDs, model)
|
||||
return c.parseModelLoadResults(accountIDs, loadCmds, lastUsedCmds), nil
|
||||
}
|
||||
|
||||
// pipelineModelLoadGet 批量获取模型负载的 Pipeline 操作
|
||||
func (c *gatewayCache) pipelineModelLoadGet(
|
||||
ctx context.Context,
|
||||
accountIDs []int64,
|
||||
model string,
|
||||
) (map[int64]*redis.StringCmd, map[int64]*redis.StringCmd) {
|
||||
pipe := c.rdb.Pipeline()
|
||||
loadCmds := make(map[int64]*redis.StringCmd, len(accountIDs))
|
||||
lastUsedCmds := make(map[int64]*redis.StringCmd, len(accountIDs))
|
||||
|
||||
for _, id := range accountIDs {
|
||||
loadCmds[id] = pipe.Get(ctx, modelLoadKey(id, model))
|
||||
lastUsedCmds[id] = pipe.Get(ctx, modelLastUsedKey(id, model))
|
||||
}
|
||||
_, _ = pipe.Exec(ctx) // 忽略错误,key 不存在是正常的
|
||||
return loadCmds, lastUsedCmds
|
||||
}
|
||||
|
||||
// parseModelLoadResults 解析 Pipeline 结果
|
||||
func (c *gatewayCache) parseModelLoadResults(
|
||||
accountIDs []int64,
|
||||
loadCmds map[int64]*redis.StringCmd,
|
||||
lastUsedCmds map[int64]*redis.StringCmd,
|
||||
) map[int64]*service.ModelLoadInfo {
|
||||
result := make(map[int64]*service.ModelLoadInfo, len(accountIDs))
|
||||
for _, id := range accountIDs {
|
||||
result[id] = &service.ModelLoadInfo{
|
||||
CallCount: getInt64OrZero(loadCmds[id]),
|
||||
LastUsedAt: getTimeOrZero(lastUsedCmds[id]),
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// getInt64OrZero 从 StringCmd 获取 int64 值,失败返回 0
|
||||
func getInt64OrZero(cmd *redis.StringCmd) int64 {
|
||||
val, _ := cmd.Int64()
|
||||
return val
|
||||
}
|
||||
|
||||
// getTimeOrZero 从 StringCmd 获取 time.Time,失败返回零值
|
||||
func getTimeOrZero(cmd *redis.StringCmd) time.Time {
|
||||
val, err := cmd.Int64()
|
||||
if err != nil {
|
||||
return time.Time{}
|
||||
}
|
||||
return time.Unix(val, 0)
|
||||
}
|
||||
|
||||
// ============ Gemini 会话 Fallback 方法 (Trie 实现) ============
|
||||
|
||||
// FindGeminiSession 查找 Gemini 会话(使用 Trie + Lua 脚本实现 O(L) 查询)
|
||||
// 返回最长匹配的会话信息,匹配成功时自动刷新 TTL
|
||||
func (c *gatewayCache) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
||||
if digestChain == "" {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
trieKey := service.BuildGeminiTrieKey(groupID, prefixHash)
|
||||
ttlSeconds := int(service.GeminiSessionTTL().Seconds())
|
||||
|
||||
// 使用 Lua 脚本在 Redis 端执行 Trie 查找,O(L) 次 HGET,1 次网络往返
|
||||
// 查找成功时自动刷新 TTL,防止活跃会话意外过期
|
||||
result, err := c.rdb.Eval(ctx, geminiTrieFindScript, []string{trieKey}, digestChain, ttlSeconds).Result()
|
||||
if err != nil || result == nil {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
value, ok := result.(string)
|
||||
if !ok || value == "" {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
uuid, accountID, ok = service.ParseGeminiSessionValue(value)
|
||||
return uuid, accountID, ok
|
||||
}
|
||||
|
||||
// SaveGeminiSession 保存 Gemini 会话(使用 Trie + Lua 脚本)
|
||||
func (c *gatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
||||
if digestChain == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
trieKey := service.BuildGeminiTrieKey(groupID, prefixHash)
|
||||
value := service.FormatGeminiSessionValue(uuid, accountID)
|
||||
ttlSeconds := int(service.GeminiSessionTTL().Seconds())
|
||||
|
||||
return c.rdb.Eval(ctx, geminiTrieSaveScript, []string{trieKey}, digestChain, value, ttlSeconds).Err()
|
||||
}
|
||||
|
||||
@@ -104,157 +104,6 @@ func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
|
||||
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
|
||||
}
|
||||
|
||||
// ============ Gemini Trie 会话测试 ============
|
||||
|
||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_SaveAndFind() {
|
||||
groupID := int64(1)
|
||||
prefixHash := "testprefix"
|
||||
digestChain := "u:hash1-m:hash2-u:hash3"
|
||||
uuid := "test-uuid-123"
|
||||
accountID := int64(42)
|
||||
|
||||
// 保存会话
|
||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, digestChain, uuid, accountID)
|
||||
require.NoError(s.T(), err, "SaveGeminiSession")
|
||||
|
||||
// 精确匹配查找
|
||||
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, digestChain)
|
||||
require.True(s.T(), found, "should find exact match")
|
||||
require.Equal(s.T(), uuid, foundUUID)
|
||||
require.Equal(s.T(), accountID, foundAccountID)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_PrefixMatch() {
|
||||
groupID := int64(1)
|
||||
prefixHash := "prefixmatch"
|
||||
shortChain := "u:a-m:b"
|
||||
longChain := "u:a-m:b-u:c-m:d"
|
||||
uuid := "uuid-prefix"
|
||||
accountID := int64(100)
|
||||
|
||||
// 保存短链
|
||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, shortChain, uuid, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// 用长链查找,应该匹配到短链(前缀匹配)
|
||||
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, longChain)
|
||||
require.True(s.T(), found, "should find prefix match")
|
||||
require.Equal(s.T(), uuid, foundUUID)
|
||||
require.Equal(s.T(), accountID, foundAccountID)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_LongestPrefixMatch() {
|
||||
groupID := int64(1)
|
||||
prefixHash := "longestmatch"
|
||||
|
||||
// 保存多个不同长度的链
|
||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a", "uuid-short", 1)
|
||||
require.NoError(s.T(), err)
|
||||
err = s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b", "uuid-medium", 2)
|
||||
require.NoError(s.T(), err)
|
||||
err = s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:c", "uuid-long", 3)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// 查找更长的链,应该匹配到最长的前缀
|
||||
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:c-m:d-u:e")
|
||||
require.True(s.T(), found, "should find longest prefix match")
|
||||
require.Equal(s.T(), "uuid-long", foundUUID)
|
||||
require.Equal(s.T(), int64(3), foundAccountID)
|
||||
|
||||
// 查找中等长度的链
|
||||
foundUUID, foundAccountID, found = s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:x")
|
||||
require.True(s.T(), found)
|
||||
require.Equal(s.T(), "uuid-medium", foundUUID)
|
||||
require.Equal(s.T(), int64(2), foundAccountID)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_NoMatch() {
|
||||
groupID := int64(1)
|
||||
prefixHash := "nomatch"
|
||||
digestChain := "u:a-m:b"
|
||||
|
||||
// 保存一个会话
|
||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, digestChain, "uuid", 1)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// 用不同的链查找,应该找不到
|
||||
_, _, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:x-m:y")
|
||||
require.False(s.T(), found, "should not find non-matching chain")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_DifferentPrefixHash() {
|
||||
groupID := int64(1)
|
||||
digestChain := "u:a-m:b"
|
||||
|
||||
// 保存到 prefixHash1
|
||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, "prefix1", digestChain, "uuid1", 1)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// 用 prefixHash2 查找,应该找不到(不同用户/客户端隔离)
|
||||
_, _, found := s.cache.FindGeminiSession(s.ctx, groupID, "prefix2", digestChain)
|
||||
require.False(s.T(), found, "different prefixHash should be isolated")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_DifferentGroupID() {
|
||||
prefixHash := "sameprefix"
|
||||
digestChain := "u:a-m:b"
|
||||
|
||||
// 保存到 groupID 1
|
||||
err := s.cache.SaveGeminiSession(s.ctx, 1, prefixHash, digestChain, "uuid1", 1)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// 用 groupID 2 查找,应该找不到(分组隔离)
|
||||
_, _, found := s.cache.FindGeminiSession(s.ctx, 2, prefixHash, digestChain)
|
||||
require.False(s.T(), found, "different groupID should be isolated")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_EmptyDigestChain() {
|
||||
groupID := int64(1)
|
||||
prefixHash := "emptytest"
|
||||
|
||||
// 空链不应该保存
|
||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "", "uuid", 1)
|
||||
require.NoError(s.T(), err, "empty chain should not error")
|
||||
|
||||
// 空链查找应该返回 false
|
||||
_, _, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "")
|
||||
require.False(s.T(), found, "empty chain should not match")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_MultipleSessions() {
|
||||
groupID := int64(1)
|
||||
prefixHash := "multisession"
|
||||
|
||||
// 保存多个不同会话(模拟 1000 个并发会话的场景)
|
||||
sessions := []struct {
|
||||
chain string
|
||||
uuid string
|
||||
accountID int64
|
||||
}{
|
||||
{"u:session1", "uuid-1", 1},
|
||||
{"u:session2-m:reply2", "uuid-2", 2},
|
||||
{"u:session3-m:reply3-u:msg3", "uuid-3", 3},
|
||||
}
|
||||
|
||||
for _, sess := range sessions {
|
||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, sess.chain, sess.uuid, sess.accountID)
|
||||
require.NoError(s.T(), err)
|
||||
}
|
||||
|
||||
// 验证每个会话都能正确查找
|
||||
for _, sess := range sessions {
|
||||
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, sess.chain)
|
||||
require.True(s.T(), found, "should find session: %s", sess.chain)
|
||||
require.Equal(s.T(), sess.uuid, foundUUID)
|
||||
require.Equal(s.T(), sess.accountID, foundAccountID)
|
||||
}
|
||||
|
||||
// 验证继续对话的场景
|
||||
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:session2-m:reply2-u:newmsg")
|
||||
require.True(s.T(), found)
|
||||
require.Equal(s.T(), "uuid-2", foundUUID)
|
||||
require.Equal(s.T(), int64(2), foundAccountID)
|
||||
}
|
||||
|
||||
func TestGatewayCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(GatewayCacheSuite))
|
||||
|
||||
@@ -1,234 +0,0 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// ============ Gateway Cache 模型负载统计集成测试 ============
|
||||
|
||||
type GatewayCacheModelLoadSuite struct {
|
||||
suite.Suite
|
||||
}
|
||||
|
||||
func TestGatewayCacheModelLoadSuite(t *testing.T) {
|
||||
suite.Run(t, new(GatewayCacheModelLoadSuite))
|
||||
}
|
||||
|
||||
func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_Basic() {
|
||||
t := s.T()
|
||||
rdb := testRedis(t)
|
||||
cache := &gatewayCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
accountID := int64(123)
|
||||
model := "claude-sonnet-4-20250514"
|
||||
|
||||
// 首次调用应返回 1
|
||||
count1, err := cache.IncrModelCallCount(ctx, accountID, model)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), count1)
|
||||
|
||||
// 第二次调用应返回 2
|
||||
count2, err := cache.IncrModelCallCount(ctx, accountID, model)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), count2)
|
||||
|
||||
// 第三次调用应返回 3
|
||||
count3, err := cache.IncrModelCallCount(ctx, accountID, model)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(3), count3)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_DifferentModels() {
|
||||
t := s.T()
|
||||
rdb := testRedis(t)
|
||||
cache := &gatewayCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
accountID := int64(456)
|
||||
model1 := "claude-sonnet-4-20250514"
|
||||
model2 := "claude-opus-4-5-20251101"
|
||||
|
||||
// 不同模型应该独立计数
|
||||
count1, err := cache.IncrModelCallCount(ctx, accountID, model1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), count1)
|
||||
|
||||
count2, err := cache.IncrModelCallCount(ctx, accountID, model2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), count2)
|
||||
|
||||
count1Again, err := cache.IncrModelCallCount(ctx, accountID, model1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), count1Again)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_DifferentAccounts() {
|
||||
t := s.T()
|
||||
rdb := testRedis(t)
|
||||
cache := &gatewayCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
account1 := int64(111)
|
||||
account2 := int64(222)
|
||||
model := "gemini-2.5-pro"
|
||||
|
||||
// 不同账号应该独立计数
|
||||
count1, err := cache.IncrModelCallCount(ctx, account1, model)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), count1)
|
||||
|
||||
count2, err := cache.IncrModelCallCount(ctx, account2, model)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), count2)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_Empty() {
|
||||
t := s.T()
|
||||
rdb := testRedis(t)
|
||||
cache := &gatewayCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := cache.GetModelLoadBatch(ctx, []int64{}, "any-model")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Empty(t, result)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_NonExistent() {
|
||||
t := s.T()
|
||||
rdb := testRedis(t)
|
||||
cache := &gatewayCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
// 查询不存在的账号应返回零值
|
||||
result, err := cache.GetModelLoadBatch(ctx, []int64{9999, 9998}, "claude-sonnet-4-20250514")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result, 2)
|
||||
|
||||
require.Equal(t, int64(0), result[9999].CallCount)
|
||||
require.True(t, result[9999].LastUsedAt.IsZero())
|
||||
require.Equal(t, int64(0), result[9998].CallCount)
|
||||
require.True(t, result[9998].LastUsedAt.IsZero())
|
||||
}
|
||||
|
||||
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_AfterIncrement() {
|
||||
t := s.T()
|
||||
rdb := testRedis(t)
|
||||
cache := &gatewayCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
accountID := int64(789)
|
||||
model := "claude-sonnet-4-20250514"
|
||||
|
||||
// 先增加调用次数
|
||||
beforeIncr := time.Now()
|
||||
_, err := cache.IncrModelCallCount(ctx, accountID, model)
|
||||
require.NoError(t, err)
|
||||
_, err = cache.IncrModelCallCount(ctx, accountID, model)
|
||||
require.NoError(t, err)
|
||||
_, err = cache.IncrModelCallCount(ctx, accountID, model)
|
||||
require.NoError(t, err)
|
||||
afterIncr := time.Now()
|
||||
|
||||
// 获取负载信息
|
||||
result, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result, 1)
|
||||
|
||||
loadInfo := result[accountID]
|
||||
require.NotNil(t, loadInfo)
|
||||
require.Equal(t, int64(3), loadInfo.CallCount)
|
||||
require.False(t, loadInfo.LastUsedAt.IsZero())
|
||||
// LastUsedAt 应该在 beforeIncr 和 afterIncr 之间
|
||||
require.True(t, loadInfo.LastUsedAt.After(beforeIncr.Add(-time.Second)) || loadInfo.LastUsedAt.Equal(beforeIncr))
|
||||
require.True(t, loadInfo.LastUsedAt.Before(afterIncr.Add(time.Second)) || loadInfo.LastUsedAt.Equal(afterIncr))
|
||||
}
|
||||
|
||||
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_MultipleAccounts() {
|
||||
t := s.T()
|
||||
rdb := testRedis(t)
|
||||
cache := &gatewayCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
model := "claude-opus-4-5-20251101"
|
||||
account1 := int64(1001)
|
||||
account2 := int64(1002)
|
||||
account3 := int64(1003) // 不调用
|
||||
|
||||
// account1 调用 2 次
|
||||
_, err := cache.IncrModelCallCount(ctx, account1, model)
|
||||
require.NoError(t, err)
|
||||
_, err = cache.IncrModelCallCount(ctx, account1, model)
|
||||
require.NoError(t, err)
|
||||
|
||||
// account2 调用 5 次
|
||||
for i := 0; i < 5; i++ {
|
||||
_, err = cache.IncrModelCallCount(ctx, account2, model)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// 批量获取
|
||||
result, err := cache.GetModelLoadBatch(ctx, []int64{account1, account2, account3}, model)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result, 3)
|
||||
|
||||
require.Equal(t, int64(2), result[account1].CallCount)
|
||||
require.False(t, result[account1].LastUsedAt.IsZero())
|
||||
|
||||
require.Equal(t, int64(5), result[account2].CallCount)
|
||||
require.False(t, result[account2].LastUsedAt.IsZero())
|
||||
|
||||
require.Equal(t, int64(0), result[account3].CallCount)
|
||||
require.True(t, result[account3].LastUsedAt.IsZero())
|
||||
}
|
||||
|
||||
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_ModelIsolation() {
|
||||
t := s.T()
|
||||
rdb := testRedis(t)
|
||||
cache := &gatewayCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
accountID := int64(2001)
|
||||
model1 := "claude-sonnet-4-20250514"
|
||||
model2 := "gemini-2.5-pro"
|
||||
|
||||
// 对 model1 调用 3 次
|
||||
for i := 0; i < 3; i++ {
|
||||
_, err := cache.IncrModelCallCount(ctx, accountID, model1)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// 获取 model1 的负载
|
||||
result1, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(3), result1[accountID].CallCount)
|
||||
|
||||
// 获取 model2 的负载(应该为 0)
|
||||
result2, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(0), result2[accountID].CallCount)
|
||||
}
|
||||
|
||||
// ============ 辅助函数测试 ============
|
||||
|
||||
func (s *GatewayCacheModelLoadSuite) TestModelLoadKey_Format() {
|
||||
t := s.T()
|
||||
|
||||
key := modelLoadKey(123, "claude-sonnet-4")
|
||||
require.Equal(t, "ag:model_load:123:claude-sonnet-4", key)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheModelLoadSuite) TestModelLastUsedKey_Format() {
|
||||
t := s.T()
|
||||
|
||||
key := modelLastUsedKey(456, "gemini-2.5-pro")
|
||||
require.Equal(t, "ag:model_last_used:456:gemini-2.5-pro", key)
|
||||
}
|
||||
@@ -191,7 +191,7 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
|
||||
groups, err := q.
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Asc(group.FieldID)).
|
||||
Order(dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
@@ -218,7 +218,7 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
|
||||
func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, error) {
|
||||
groups, err := r.client.Group.Query().
|
||||
Where(group.StatusEQ(service.StatusActive)).
|
||||
Order(dbent.Asc(group.FieldID)).
|
||||
Order(dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -245,7 +245,7 @@ func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, erro
|
||||
func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
|
||||
groups, err := r.client.Group.Query().
|
||||
Where(group.StatusEQ(service.StatusActive), group.PlatformEQ(platform)).
|
||||
Order(dbent.Asc(group.FieldID)).
|
||||
Order(dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -497,3 +497,29 @@ func (r *groupRepository) BindAccountsToGroup(ctx context.Context, groupID int64
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateSortOrders 批量更新分组排序
|
||||
func (r *groupRepository) UpdateSortOrders(ctx context.Context, updates []service.GroupSortOrderUpdate) error {
|
||||
if len(updates) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 使用事务批量更新
|
||||
tx, err := r.client.Tx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
for _, u := range updates {
|
||||
if _, err := tx.Group.UpdateOneID(u.ID).SetSortOrder(u.SortOrder).Save(ctx); err != nil {
|
||||
return translatePersistenceError(err, service.ErrGroupNotFound, nil)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
@@ -106,7 +107,12 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin
|
||||
q = q.Where(redeemcode.StatusEQ(status))
|
||||
}
|
||||
if search != "" {
|
||||
q = q.Where(redeemcode.CodeContainsFold(search))
|
||||
q = q.Where(
|
||||
redeemcode.Or(
|
||||
redeemcode.CodeContainsFold(search),
|
||||
redeemcode.HasUserWith(user.EmailContainsFold(search)),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||
@@ -191,6 +192,7 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
|
||||
dbuser.EmailContainsFold(filters.Search),
|
||||
dbuser.UsernameContainsFold(filters.Search),
|
||||
dbuser.NotesContainsFold(filters.Search),
|
||||
dbuser.HasAPIKeysWith(apikey.KeyContainsFold(filters.Search)),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -896,6 +896,10 @@ func (stubGroupRepo) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (stubGroupRepo) UpdateSortOrders(ctx context.Context, updates []service.GroupSortOrderUpdate) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type stubAccountRepo struct {
|
||||
bulkUpdateIDs []int64
|
||||
}
|
||||
@@ -932,7 +936,7 @@ func (s *stubAccountRepo) List(ctx context.Context, params pagination.Pagination
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -1004,10 +1008,6 @@ func (s *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
@@ -1049,6 +1049,10 @@ func (s *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates s
|
||||
return int64(len(ids)), nil
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
type stubProxyRepo struct{}
|
||||
|
||||
func (stubProxyRepo) Create(ctx context.Context, proxy *service.Proxy) error {
|
||||
|
||||
@@ -192,6 +192,7 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
{
|
||||
groups.GET("", h.Admin.Group.List)
|
||||
groups.GET("/all", h.Admin.Group.GetAll)
|
||||
groups.PUT("/sort-order", h.Admin.Group.UpdateSortOrder)
|
||||
groups.GET("/:id", h.Admin.Group.GetByID)
|
||||
groups.POST("", h.Admin.Group.Create)
|
||||
groups.PUT("/:id", h.Admin.Group.Update)
|
||||
@@ -208,6 +209,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
accounts.GET("/:id", h.Admin.Account.GetByID)
|
||||
accounts.POST("", h.Admin.Account.Create)
|
||||
accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS)
|
||||
accounts.POST("/sync/crs/preview", h.Admin.Account.PreviewFromCRS)
|
||||
accounts.PUT("/:id", h.Admin.Account.Update)
|
||||
accounts.DELETE("/:id", h.Admin.Account.Delete)
|
||||
accounts.POST("/:id/test", h.Admin.Account.Test)
|
||||
@@ -279,6 +281,7 @@ func registerAntigravityOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers)
|
||||
{
|
||||
antigravity.POST("/oauth/auth-url", h.Admin.AntigravityOAuth.GenerateAuthURL)
|
||||
antigravity.POST("/oauth/exchange-code", h.Admin.AntigravityOAuth.ExchangeCode)
|
||||
antigravity.POST("/oauth/refresh-token", h.Admin.AntigravityOAuth.RefreshToken)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -425,6 +425,22 @@ func (a *Account) GetBaseURL() string {
|
||||
if baseURL == "" {
|
||||
return "https://api.anthropic.com"
|
||||
}
|
||||
if a.Platform == PlatformAntigravity {
|
||||
return strings.TrimRight(baseURL, "/") + "/antigravity"
|
||||
}
|
||||
return baseURL
|
||||
}
|
||||
|
||||
// GetGeminiBaseURL 返回 Gemini 兼容端点的 base URL。
|
||||
// Antigravity 平台的 APIKey 账号自动拼接 /antigravity。
|
||||
func (a *Account) GetGeminiBaseURL(defaultBaseURL string) string {
|
||||
baseURL := strings.TrimSpace(a.GetCredential("base_url"))
|
||||
if baseURL == "" {
|
||||
return defaultBaseURL
|
||||
}
|
||||
if a.Platform == PlatformAntigravity && a.Type == AccountTypeAPIKey {
|
||||
return strings.TrimRight(baseURL, "/") + "/antigravity"
|
||||
}
|
||||
return baseURL
|
||||
}
|
||||
|
||||
|
||||
160
backend/internal/service/account_base_url_test.go
Normal file
160
backend/internal/service/account_base_url_test.go
Normal file
@@ -0,0 +1,160 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetBaseURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account Account
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "non-apikey type returns empty",
|
||||
account: Account{
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAnthropic,
|
||||
},
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "apikey without base_url returns default anthropic",
|
||||
account: Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformAnthropic,
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
expected: "https://api.anthropic.com",
|
||||
},
|
||||
{
|
||||
name: "apikey with custom base_url",
|
||||
account: Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformAnthropic,
|
||||
Credentials: map[string]any{"base_url": "https://custom.example.com"},
|
||||
},
|
||||
expected: "https://custom.example.com",
|
||||
},
|
||||
{
|
||||
name: "antigravity apikey auto-appends /antigravity",
|
||||
account: Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{"base_url": "https://upstream.example.com"},
|
||||
},
|
||||
expected: "https://upstream.example.com/antigravity",
|
||||
},
|
||||
{
|
||||
name: "antigravity apikey trims trailing slash before appending",
|
||||
account: Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{"base_url": "https://upstream.example.com/"},
|
||||
},
|
||||
expected: "https://upstream.example.com/antigravity",
|
||||
},
|
||||
{
|
||||
name: "antigravity non-apikey returns empty",
|
||||
account: Account{
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{"base_url": "https://upstream.example.com"},
|
||||
},
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.account.GetBaseURL()
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetBaseURL() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetGeminiBaseURL(t *testing.T) {
|
||||
const defaultGeminiURL = "https://generativelanguage.googleapis.com"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
account Account
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "apikey without base_url returns default",
|
||||
account: Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformGemini,
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
expected: defaultGeminiURL,
|
||||
},
|
||||
{
|
||||
name: "apikey with custom base_url",
|
||||
account: Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformGemini,
|
||||
Credentials: map[string]any{"base_url": "https://custom-gemini.example.com"},
|
||||
},
|
||||
expected: "https://custom-gemini.example.com",
|
||||
},
|
||||
{
|
||||
name: "antigravity apikey auto-appends /antigravity",
|
||||
account: Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{"base_url": "https://upstream.example.com"},
|
||||
},
|
||||
expected: "https://upstream.example.com/antigravity",
|
||||
},
|
||||
{
|
||||
name: "antigravity apikey trims trailing slash",
|
||||
account: Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{"base_url": "https://upstream.example.com/"},
|
||||
},
|
||||
expected: "https://upstream.example.com/antigravity",
|
||||
},
|
||||
{
|
||||
name: "antigravity oauth does NOT append /antigravity",
|
||||
account: Account{
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{"base_url": "https://upstream.example.com"},
|
||||
},
|
||||
expected: "https://upstream.example.com",
|
||||
},
|
||||
{
|
||||
name: "oauth without base_url returns default",
|
||||
account: Account{
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
expected: defaultGeminiURL,
|
||||
},
|
||||
{
|
||||
name: "nil credentials returns default",
|
||||
account: Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformGemini,
|
||||
},
|
||||
expected: defaultGeminiURL,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.account.GetGeminiBaseURL(defaultGeminiURL)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetGeminiBaseURL() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -25,11 +25,14 @@ type AccountRepository interface {
|
||||
// GetByCRSAccountID finds an account previously synced from CRS.
|
||||
// Returns (nil, nil) if not found.
|
||||
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error)
|
||||
// ListCRSAccountIDs returns a map of crs_account_id -> local account ID
|
||||
// for all accounts that have been synced from CRS.
|
||||
ListCRSAccountIDs(ctx context.Context) (map[string]int64, error)
|
||||
Update(ctx context.Context, account *Account) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error)
|
||||
ListByGroup(ctx context.Context, groupID int64) ([]Account, error)
|
||||
ListActive(ctx context.Context) ([]Account, error)
|
||||
ListByPlatform(ctx context.Context, platform string) ([]Account, error)
|
||||
@@ -50,7 +53,6 @@ type AccountRepository interface {
|
||||
ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
|
||||
|
||||
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
|
||||
SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error
|
||||
SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error
|
||||
SetOverloaded(ctx context.Context, id int64, until time.Time) error
|
||||
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
|
||||
|
||||
@@ -54,6 +54,10 @@ func (s *accountRepoStub) GetByCRSAccountID(ctx context.Context, crsAccountID st
|
||||
panic("unexpected GetByCRSAccountID call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
|
||||
panic("unexpected ListCRSAccountIDs call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) Update(ctx context.Context, account *Account) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
@@ -71,7 +75,7 @@ func (s *accountRepoStub) List(ctx context.Context, params pagination.Pagination
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
|
||||
func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
@@ -143,10 +147,6 @@ func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt
|
||||
panic("unexpected SetRateLimited call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
|
||||
panic("unexpected SetAntigravityQuotaScopeLimit call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||
panic("unexpected SetModelRateLimit call")
|
||||
}
|
||||
|
||||
@@ -245,7 +245,6 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
||||
// Set common headers
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
req.Header.Set("anthropic-beta", claude.DefaultBetaHeader)
|
||||
|
||||
// Apply Claude Code client headers
|
||||
for key, value := range claude.DefaultHeaders {
|
||||
@@ -254,8 +253,10 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
||||
|
||||
// Set authentication header
|
||||
if useBearer {
|
||||
req.Header.Set("anthropic-beta", claude.DefaultBetaHeader)
|
||||
req.Header.Set("Authorization", "Bearer "+authToken)
|
||||
} else {
|
||||
req.Header.Set("anthropic-beta", claude.APIKeyBetaHeader)
|
||||
req.Header.Set("x-api-key", authToken)
|
||||
}
|
||||
|
||||
|
||||
@@ -36,9 +36,10 @@ type AdminService interface {
|
||||
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error)
|
||||
DeleteGroup(ctx context.Context, id int64) error
|
||||
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
|
||||
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
|
||||
|
||||
// Account management
|
||||
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
|
||||
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error)
|
||||
GetAccount(ctx context.Context, id int64) (*Account, error)
|
||||
GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error)
|
||||
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
|
||||
@@ -1015,10 +1016,14 @@ func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, p
|
||||
return keys, result.Total, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
|
||||
return s.groupRepo.UpdateSortOrders(ctx, updates)
|
||||
}
|
||||
|
||||
// Account management implementations
|
||||
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) {
|
||||
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search)
|
||||
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
@@ -172,6 +172,10 @@ func (s *groupRepoStub) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []
|
||||
panic("unexpected GetAccountIDsByGroupIDs call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) UpdateSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type proxyRepoStub struct {
|
||||
deleteErr error
|
||||
countErr error
|
||||
|
||||
@@ -116,6 +116,10 @@ func (s *groupRepoStubForAdmin) GetAccountIDsByGroupIDs(_ context.Context, _ []i
|
||||
panic("unexpected GetAccountIDsByGroupIDs call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) UpdateSortOrders(_ context.Context, _ []GroupSortOrderUpdate) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递
|
||||
func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) {
|
||||
repo := &groupRepoStubForAdmin{}
|
||||
@@ -395,6 +399,10 @@ func (s *groupRepoStubForFallbackCycle) GetAccountIDsByGroupIDs(_ context.Contex
|
||||
panic("unexpected GetAccountIDsByGroupIDs call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) UpdateSortOrders(_ context.Context, _ []GroupSortOrderUpdate) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type groupRepoStubForInvalidRequestFallback struct {
|
||||
groups map[int64]*Group
|
||||
created *Group
|
||||
@@ -466,6 +474,10 @@ func (s *groupRepoStubForInvalidRequestFallback) BindAccountsToGroup(_ context.C
|
||||
panic("unexpected BindAccountsToGroup call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) UpdateSortOrders(_ context.Context, _ []GroupSortOrderUpdate) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatform(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
|
||||
@@ -24,7 +24,7 @@ type accountRepoStubForAdminList struct {
|
||||
listWithFiltersErr error
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
|
||||
func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
|
||||
s.listWithFiltersCalls++
|
||||
s.listWithFiltersParams = params
|
||||
s.listWithFiltersPlatform = platform
|
||||
@@ -168,7 +168,7 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
|
||||
}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
|
||||
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc")
|
||||
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(10), total)
|
||||
require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts)
|
||||
|
||||
79
backend/internal/service/anthropic_session.go
Normal file
79
backend/internal/service/anthropic_session.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Anthropic 会话 Fallback 相关常量
|
||||
const (
|
||||
// anthropicSessionTTLSeconds Anthropic 会话缓存 TTL(5 分钟)
|
||||
anthropicSessionTTLSeconds = 300
|
||||
|
||||
// anthropicDigestSessionKeyPrefix Anthropic 摘要 fallback 会话 key 前缀
|
||||
anthropicDigestSessionKeyPrefix = "anthropic:digest:"
|
||||
)
|
||||
|
||||
// AnthropicSessionTTL 返回 Anthropic 会话缓存 TTL
|
||||
func AnthropicSessionTTL() time.Duration {
|
||||
return anthropicSessionTTLSeconds * time.Second
|
||||
}
|
||||
|
||||
// BuildAnthropicDigestChain 根据 Anthropic 请求生成摘要链
|
||||
// 格式: s:<hash>-u:<hash>-a:<hash>-u:<hash>-...
|
||||
// s = system, u = user, a = assistant
|
||||
func BuildAnthropicDigestChain(parsed *ParsedRequest) string {
|
||||
if parsed == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var parts []string
|
||||
|
||||
// 1. system prompt
|
||||
if parsed.System != nil {
|
||||
systemData, _ := json.Marshal(parsed.System)
|
||||
if len(systemData) > 0 && string(systemData) != "null" {
|
||||
parts = append(parts, "s:"+shortHash(systemData))
|
||||
}
|
||||
}
|
||||
|
||||
// 2. messages
|
||||
for _, msg := range parsed.Messages {
|
||||
msgMap, ok := msg.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
role, _ := msgMap["role"].(string)
|
||||
prefix := rolePrefix(role)
|
||||
content := msgMap["content"]
|
||||
contentData, _ := json.Marshal(content)
|
||||
parts = append(parts, prefix+":"+shortHash(contentData))
|
||||
}
|
||||
|
||||
return strings.Join(parts, "-")
|
||||
}
|
||||
|
||||
// rolePrefix 将 Anthropic 的 role 映射为单字符前缀
|
||||
func rolePrefix(role string) string {
|
||||
switch role {
|
||||
case "assistant":
|
||||
return "a"
|
||||
default:
|
||||
return "u"
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateAnthropicDigestSessionKey 生成 Anthropic 摘要 fallback 的 sessionKey
|
||||
// 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey
|
||||
func GenerateAnthropicDigestSessionKey(prefixHash, uuid string) string {
|
||||
prefix := prefixHash
|
||||
if len(prefixHash) >= 8 {
|
||||
prefix = prefixHash[:8]
|
||||
}
|
||||
uuidPart := uuid
|
||||
if len(uuid) >= 8 {
|
||||
uuidPart = uuid[:8]
|
||||
}
|
||||
return anthropicDigestSessionKeyPrefix + prefix + ":" + uuidPart
|
||||
}
|
||||
320
backend/internal/service/anthropic_session_test.go
Normal file
320
backend/internal/service/anthropic_session_test.go
Normal file
@@ -0,0 +1,320 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBuildAnthropicDigestChain_NilRequest(t *testing.T) {
|
||||
result := BuildAnthropicDigestChain(nil)
|
||||
if result != "" {
|
||||
t.Errorf("expected empty string for nil request, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAnthropicDigestChain_EmptyMessages(t *testing.T) {
|
||||
parsed := &ParsedRequest{
|
||||
Messages: []any{},
|
||||
}
|
||||
result := BuildAnthropicDigestChain(parsed)
|
||||
if result != "" {
|
||||
t.Errorf("expected empty string for empty messages, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAnthropicDigestChain_SingleUserMessage(t *testing.T) {
|
||||
parsed := &ParsedRequest{
|
||||
Messages: []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
}
|
||||
result := BuildAnthropicDigestChain(parsed)
|
||||
parts := splitChain(result)
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("expected 1 part, got %d: %s", len(parts), result)
|
||||
}
|
||||
if !strings.HasPrefix(parts[0], "u:") {
|
||||
t.Errorf("expected prefix 'u:', got: %s", parts[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAnthropicDigestChain_UserAndAssistant(t *testing.T) {
|
||||
parsed := &ParsedRequest{
|
||||
Messages: []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
map[string]any{"role": "assistant", "content": "hi there"},
|
||||
},
|
||||
}
|
||||
result := BuildAnthropicDigestChain(parsed)
|
||||
parts := splitChain(result)
|
||||
if len(parts) != 2 {
|
||||
t.Fatalf("expected 2 parts, got %d: %s", len(parts), result)
|
||||
}
|
||||
if !strings.HasPrefix(parts[0], "u:") {
|
||||
t.Errorf("part[0] expected prefix 'u:', got: %s", parts[0])
|
||||
}
|
||||
if !strings.HasPrefix(parts[1], "a:") {
|
||||
t.Errorf("part[1] expected prefix 'a:', got: %s", parts[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAnthropicDigestChain_WithSystemString(t *testing.T) {
|
||||
parsed := &ParsedRequest{
|
||||
System: "You are a helpful assistant",
|
||||
Messages: []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
}
|
||||
result := BuildAnthropicDigestChain(parsed)
|
||||
parts := splitChain(result)
|
||||
if len(parts) != 2 {
|
||||
t.Fatalf("expected 2 parts (s + u), got %d: %s", len(parts), result)
|
||||
}
|
||||
if !strings.HasPrefix(parts[0], "s:") {
|
||||
t.Errorf("part[0] expected prefix 's:', got: %s", parts[0])
|
||||
}
|
||||
if !strings.HasPrefix(parts[1], "u:") {
|
||||
t.Errorf("part[1] expected prefix 'u:', got: %s", parts[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAnthropicDigestChain_WithSystemContentBlocks(t *testing.T) {
|
||||
parsed := &ParsedRequest{
|
||||
System: []any{
|
||||
map[string]any{"type": "text", "text": "You are a helpful assistant"},
|
||||
},
|
||||
Messages: []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
}
|
||||
result := BuildAnthropicDigestChain(parsed)
|
||||
parts := splitChain(result)
|
||||
if len(parts) != 2 {
|
||||
t.Fatalf("expected 2 parts (s + u), got %d: %s", len(parts), result)
|
||||
}
|
||||
if !strings.HasPrefix(parts[0], "s:") {
|
||||
t.Errorf("part[0] expected prefix 's:', got: %s", parts[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAnthropicDigestChain_ConversationPrefixRelationship(t *testing.T) {
|
||||
// 核心测试:验证对话增长时链的前缀关系
|
||||
// 上一轮的完整链一定是下一轮链的前缀
|
||||
system := "You are a helpful assistant"
|
||||
|
||||
// 第 1 轮: system + user
|
||||
round1 := &ParsedRequest{
|
||||
System: system,
|
||||
Messages: []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
}
|
||||
chain1 := BuildAnthropicDigestChain(round1)
|
||||
|
||||
// 第 2 轮: system + user + assistant + user
|
||||
round2 := &ParsedRequest{
|
||||
System: system,
|
||||
Messages: []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
map[string]any{"role": "assistant", "content": "hi there"},
|
||||
map[string]any{"role": "user", "content": "how are you?"},
|
||||
},
|
||||
}
|
||||
chain2 := BuildAnthropicDigestChain(round2)
|
||||
|
||||
// 第 3 轮: system + user + assistant + user + assistant + user
|
||||
round3 := &ParsedRequest{
|
||||
System: system,
|
||||
Messages: []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
map[string]any{"role": "assistant", "content": "hi there"},
|
||||
map[string]any{"role": "user", "content": "how are you?"},
|
||||
map[string]any{"role": "assistant", "content": "I'm doing well"},
|
||||
map[string]any{"role": "user", "content": "great"},
|
||||
},
|
||||
}
|
||||
chain3 := BuildAnthropicDigestChain(round3)
|
||||
|
||||
t.Logf("Chain1: %s", chain1)
|
||||
t.Logf("Chain2: %s", chain2)
|
||||
t.Logf("Chain3: %s", chain3)
|
||||
|
||||
// chain1 是 chain2 的前缀
|
||||
if !strings.HasPrefix(chain2, chain1) {
|
||||
t.Errorf("chain1 should be prefix of chain2:\n chain1: %s\n chain2: %s", chain1, chain2)
|
||||
}
|
||||
|
||||
// chain2 是 chain3 的前缀
|
||||
if !strings.HasPrefix(chain3, chain2) {
|
||||
t.Errorf("chain2 should be prefix of chain3:\n chain2: %s\n chain3: %s", chain2, chain3)
|
||||
}
|
||||
|
||||
// chain1 也是 chain3 的前缀(传递性)
|
||||
if !strings.HasPrefix(chain3, chain1) {
|
||||
t.Errorf("chain1 should be prefix of chain3:\n chain1: %s\n chain3: %s", chain1, chain3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAnthropicDigestChain_DifferentSystemProducesDifferentChain(t *testing.T) {
|
||||
parsed1 := &ParsedRequest{
|
||||
System: "System A",
|
||||
Messages: []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
}
|
||||
parsed2 := &ParsedRequest{
|
||||
System: "System B",
|
||||
Messages: []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
}
|
||||
|
||||
chain1 := BuildAnthropicDigestChain(parsed1)
|
||||
chain2 := BuildAnthropicDigestChain(parsed2)
|
||||
|
||||
if chain1 == chain2 {
|
||||
t.Error("Different system prompts should produce different chains")
|
||||
}
|
||||
|
||||
// 但 user 部分的 hash 应该相同
|
||||
parts1 := splitChain(chain1)
|
||||
parts2 := splitChain(chain2)
|
||||
if parts1[1] != parts2[1] {
|
||||
t.Error("Same user message should produce same hash regardless of system")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAnthropicDigestChain_DifferentContentProducesDifferentChain(t *testing.T) {
|
||||
parsed1 := &ParsedRequest{
|
||||
Messages: []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
map[string]any{"role": "assistant", "content": "ORIGINAL reply"},
|
||||
map[string]any{"role": "user", "content": "next"},
|
||||
},
|
||||
}
|
||||
parsed2 := &ParsedRequest{
|
||||
Messages: []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
map[string]any{"role": "assistant", "content": "TAMPERED reply"},
|
||||
map[string]any{"role": "user", "content": "next"},
|
||||
},
|
||||
}
|
||||
|
||||
chain1 := BuildAnthropicDigestChain(parsed1)
|
||||
chain2 := BuildAnthropicDigestChain(parsed2)
|
||||
|
||||
if chain1 == chain2 {
|
||||
t.Error("Different content should produce different chains")
|
||||
}
|
||||
|
||||
parts1 := splitChain(chain1)
|
||||
parts2 := splitChain(chain2)
|
||||
// 第一个 user message hash 应该相同
|
||||
if parts1[0] != parts2[0] {
|
||||
t.Error("First user message hash should be the same")
|
||||
}
|
||||
// assistant reply hash 应该不同
|
||||
if parts1[1] == parts2[1] {
|
||||
t.Error("Assistant reply hash should differ")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAnthropicDigestChain_Deterministic(t *testing.T) {
|
||||
parsed := &ParsedRequest{
|
||||
System: "test system",
|
||||
Messages: []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
map[string]any{"role": "assistant", "content": "hi"},
|
||||
},
|
||||
}
|
||||
|
||||
chain1 := BuildAnthropicDigestChain(parsed)
|
||||
chain2 := BuildAnthropicDigestChain(parsed)
|
||||
|
||||
if chain1 != chain2 {
|
||||
t.Errorf("BuildAnthropicDigestChain not deterministic: %s vs %s", chain1, chain2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateAnthropicDigestSessionKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
prefixHash string
|
||||
uuid string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "normal 16 char hash with uuid",
|
||||
prefixHash: "abcdefgh12345678",
|
||||
uuid: "550e8400-e29b-41d4-a716-446655440000",
|
||||
want: "anthropic:digest:abcdefgh:550e8400",
|
||||
},
|
||||
{
|
||||
name: "exactly 8 chars",
|
||||
prefixHash: "12345678",
|
||||
uuid: "abcdefgh",
|
||||
want: "anthropic:digest:12345678:abcdefgh",
|
||||
},
|
||||
{
|
||||
name: "short values",
|
||||
prefixHash: "abc",
|
||||
uuid: "xyz",
|
||||
want: "anthropic:digest:abc:xyz",
|
||||
},
|
||||
{
|
||||
name: "empty values",
|
||||
prefixHash: "",
|
||||
uuid: "",
|
||||
want: "anthropic:digest::",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := GenerateAnthropicDigestSessionKey(tt.prefixHash, tt.uuid)
|
||||
if got != tt.want {
|
||||
t.Errorf("GenerateAnthropicDigestSessionKey(%q, %q) = %q, want %q", tt.prefixHash, tt.uuid, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// 验证不同 uuid 产生不同 sessionKey
|
||||
t.Run("different uuid different key", func(t *testing.T) {
|
||||
hash := "sameprefix123456"
|
||||
result1 := GenerateAnthropicDigestSessionKey(hash, "uuid0001-session-a")
|
||||
result2 := GenerateAnthropicDigestSessionKey(hash, "uuid0002-session-b")
|
||||
if result1 == result2 {
|
||||
t.Errorf("Different UUIDs should produce different session keys: %s vs %s", result1, result2)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAnthropicSessionTTL(t *testing.T) {
|
||||
ttl := AnthropicSessionTTL()
|
||||
if ttl.Seconds() != 300 {
|
||||
t.Errorf("expected 300 seconds, got: %v", ttl.Seconds())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAnthropicDigestChain_ContentBlocks(t *testing.T) {
|
||||
// 测试 content 为 content blocks 数组的情况
|
||||
parsed := &ParsedRequest{
|
||||
Messages: []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{"type": "text", "text": "describe this image"},
|
||||
map[string]any{"type": "image", "source": map[string]any{"type": "base64"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
result := BuildAnthropicDigestChain(parsed)
|
||||
parts := splitChain(result)
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("expected 1 part, got %d: %s", len(parts), result)
|
||||
}
|
||||
if !strings.HasPrefix(parts[0], "u:") {
|
||||
t.Errorf("expected prefix 'u:', got: %s", parts[0])
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,17 +4,42 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// antigravityFailingWriter 模拟客户端断开连接的 gin.ResponseWriter
|
||||
type antigravityFailingWriter struct {
|
||||
gin.ResponseWriter
|
||||
failAfter int // 允许成功写入的次数,之后所有写入返回错误
|
||||
writes int
|
||||
}
|
||||
|
||||
func (w *antigravityFailingWriter) Write(p []byte) (int, error) {
|
||||
if w.writes >= w.failAfter {
|
||||
return 0, errors.New("write failed: client disconnected")
|
||||
}
|
||||
w.writes++
|
||||
return w.ResponseWriter.Write(p)
|
||||
}
|
||||
|
||||
// newAntigravityTestService 创建用于流式测试的 AntigravityGatewayService
|
||||
func newAntigravityTestService(cfg *config.Config) *AntigravityGatewayService {
|
||||
return &AntigravityGatewayService{
|
||||
settingService: &SettingService{cfg: cfg},
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripSignatureSensitiveBlocksFromClaudeRequest(t *testing.T) {
|
||||
req := &antigravity.ClaudeRequest{
|
||||
Model: "claude-sonnet-4-5",
|
||||
@@ -337,8 +362,8 @@ func TestAntigravityGatewayService_Forward_StickySessionForceCacheBilling(t *tes
|
||||
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
|
||||
}
|
||||
|
||||
// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling
|
||||
// 验证:ForwardGemini 粘性会话切换时,UpstreamFailoverError.ForceCacheBilling 应为 true
|
||||
// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling verifies
|
||||
// that ForwardGemini sets ForceCacheBilling=true for sticky session switch.
|
||||
func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
writer := httptest.NewRecorder()
|
||||
@@ -391,3 +416,507 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(
|
||||
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
|
||||
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
|
||||
}
|
||||
|
||||
// --- 流式 happy path 测试 ---
|
||||
|
||||
// TestStreamUpstreamResponse_NormalComplete
|
||||
// 验证:正常流式转发完成时,数据正确透传、usage 正确收集、clientDisconnect=false
|
||||
func TestStreamUpstreamResponse_NormalComplete(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newAntigravityTestService(&config.Config{
|
||||
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
fmt.Fprintln(pw, `event: message_start`)
|
||||
fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":10}}}`)
|
||||
fmt.Fprintln(pw, "")
|
||||
fmt.Fprintln(pw, `event: content_block_delta`)
|
||||
fmt.Fprintln(pw, `data: {"type":"content_block_delta","delta":{"text":"hello"}}`)
|
||||
fmt.Fprintln(pw, "")
|
||||
fmt.Fprintln(pw, `event: message_delta`)
|
||||
fmt.Fprintln(pw, `data: {"type":"message_delta","usage":{"output_tokens":5}}`)
|
||||
fmt.Fprintln(pw, "")
|
||||
}()
|
||||
|
||||
result := svc.streamUpstreamResponse(c, resp, time.Now())
|
||||
_ = pr.Close()
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect")
|
||||
require.NotNil(t, result.usage)
|
||||
require.Equal(t, 5, result.usage.OutputTokens, "should collect output_tokens from message_delta")
|
||||
require.NotNil(t, result.firstTokenMs, "should record first token time")
|
||||
|
||||
// 验证数据被透传到客户端
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, "event: message_start")
|
||||
require.Contains(t, body, "content_block_delta")
|
||||
require.Contains(t, body, "message_delta")
|
||||
}
|
||||
|
||||
// TestHandleGeminiStreamingResponse_NormalComplete
|
||||
// 验证:正常 Gemini 流式转发,数据正确透传、usage 正确收集
|
||||
func TestHandleGeminiStreamingResponse_NormalComplete(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newAntigravityTestService(&config.Config{
|
||||
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
// 第一个 chunk(部分内容)
|
||||
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":3}}`)
|
||||
fmt.Fprintln(pw, "")
|
||||
// 第二个 chunk(最终内容+完整 usage)
|
||||
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":" world"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":8,"cachedContentTokenCount":2}}`)
|
||||
fmt.Fprintln(pw, "")
|
||||
}()
|
||||
|
||||
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
|
||||
_ = pr.Close()
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect")
|
||||
require.NotNil(t, result.usage)
|
||||
// Gemini usage: promptTokenCount=10, candidatesTokenCount=8, cachedContentTokenCount=2
|
||||
// → InputTokens=10-2=8, OutputTokens=8, CacheReadInputTokens=2
|
||||
require.Equal(t, 8, result.usage.InputTokens)
|
||||
require.Equal(t, 8, result.usage.OutputTokens)
|
||||
require.Equal(t, 2, result.usage.CacheReadInputTokens)
|
||||
require.NotNil(t, result.firstTokenMs, "should record first token time")
|
||||
|
||||
// 验证数据被透传到客户端
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, "Hello")
|
||||
require.Contains(t, body, "world")
|
||||
// 不应包含错误事件
|
||||
require.NotContains(t, body, "event: error")
|
||||
}
|
||||
|
||||
// TestHandleClaudeStreamingResponse_NormalComplete
|
||||
// 验证:正常 Claude 流式转发(Gemini→Claude 转换),数据正确转换并输出
|
||||
func TestHandleClaudeStreamingResponse_NormalComplete(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newAntigravityTestService(&config.Config{
|
||||
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
// v1internal 包装格式:Gemini 数据嵌套在 "response" 字段下
|
||||
// ProcessLine 先尝试反序列化为 V1InternalResponse,裸格式会导致 Response.UsageMetadata 为空
|
||||
fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"Hi there"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":3}}}`)
|
||||
fmt.Fprintln(pw, "")
|
||||
}()
|
||||
|
||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
||||
_ = pr.Close()
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect")
|
||||
require.NotNil(t, result.usage)
|
||||
// Gemini→Claude 转换的 usage:promptTokenCount=5→InputTokens=5, candidatesTokenCount=3→OutputTokens=3
|
||||
require.Equal(t, 5, result.usage.InputTokens)
|
||||
require.Equal(t, 3, result.usage.OutputTokens)
|
||||
require.NotNil(t, result.firstTokenMs, "should record first token time")
|
||||
|
||||
// 验证输出是 Claude SSE 格式(processor 会转换)
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, "event: message_start", "should contain Claude message_start event")
|
||||
require.Contains(t, body, "event: message_stop", "should contain Claude message_stop event")
|
||||
// 不应包含错误事件
|
||||
require.NotContains(t, body, "event: error")
|
||||
}
|
||||
|
||||
// TestHandleGeminiStreamingResponse_ThoughtsTokenCount
|
||||
// 验证:Gemini 流式转发时 thoughtsTokenCount 被计入 OutputTokens
|
||||
func TestHandleGeminiStreamingResponse_ThoughtsTokenCount(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newAntigravityTestService(&config.Config{
|
||||
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}],"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":20,"thoughtsTokenCount":50}}`)
|
||||
fmt.Fprintln(pw, "")
|
||||
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":" world"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":30,"thoughtsTokenCount":80,"cachedContentTokenCount":10}}`)
|
||||
fmt.Fprintln(pw, "")
|
||||
}()
|
||||
|
||||
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
|
||||
_ = pr.Close()
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.usage)
|
||||
// promptTokenCount=100, cachedContentTokenCount=10 → InputTokens=90
|
||||
require.Equal(t, 90, result.usage.InputTokens)
|
||||
// candidatesTokenCount=30 + thoughtsTokenCount=80 → OutputTokens=110
|
||||
require.Equal(t, 110, result.usage.OutputTokens)
|
||||
require.Equal(t, 10, result.usage.CacheReadInputTokens)
|
||||
}
|
||||
|
||||
// TestHandleClaudeStreamingResponse_ThoughtsTokenCount
|
||||
// 验证:Gemini→Claude 流式转换时 thoughtsTokenCount 被计入 OutputTokens
|
||||
func TestHandleClaudeStreamingResponse_ThoughtsTokenCount(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newAntigravityTestService(&config.Config{
|
||||
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"Hi"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":50,"candidatesTokenCount":10,"thoughtsTokenCount":25}}}`)
|
||||
fmt.Fprintln(pw, "")
|
||||
}()
|
||||
|
||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "gemini-2.5-pro")
|
||||
_ = pr.Close()
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.usage)
|
||||
// promptTokenCount=50 → InputTokens=50
|
||||
require.Equal(t, 50, result.usage.InputTokens)
|
||||
// candidatesTokenCount=10 + thoughtsTokenCount=25 → OutputTokens=35
|
||||
require.Equal(t, 35, result.usage.OutputTokens)
|
||||
}
|
||||
|
||||
// --- 流式客户端断开检测测试 ---
|
||||
|
||||
// TestStreamUpstreamResponse_ClientDisconnectDrainsUsage
|
||||
// 验证:客户端写入失败后,streamUpstreamResponse 继续读取上游以收集 usage
|
||||
func TestStreamUpstreamResponse_ClientDisconnectDrainsUsage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newAntigravityTestService(&config.Config{
|
||||
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
fmt.Fprintln(pw, `event: message_start`)
|
||||
fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":10}}}`)
|
||||
fmt.Fprintln(pw, "")
|
||||
fmt.Fprintln(pw, `event: message_delta`)
|
||||
fmt.Fprintln(pw, `data: {"type":"message_delta","usage":{"output_tokens":20}}`)
|
||||
fmt.Fprintln(pw, "")
|
||||
}()
|
||||
|
||||
result := svc.streamUpstreamResponse(c, resp, time.Now())
|
||||
_ = pr.Close()
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.clientDisconnect)
|
||||
require.NotNil(t, result.usage)
|
||||
require.Equal(t, 20, result.usage.OutputTokens)
|
||||
}
|
||||
|
||||
// TestStreamUpstreamResponse_ContextCanceled
|
||||
// 验证:context 取消时返回 usage 且标记 clientDisconnect
|
||||
func TestStreamUpstreamResponse_ContextCanceled(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newAntigravityTestService(&config.Config{
|
||||
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
|
||||
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
|
||||
|
||||
result := svc.streamUpstreamResponse(c, resp, time.Now())
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.clientDisconnect)
|
||||
require.NotContains(t, rec.Body.String(), "event: error")
|
||||
}
|
||||
|
||||
// TestStreamUpstreamResponse_Timeout
|
||||
// 验证:上游超时时返回已收集的 usage
|
||||
func TestStreamUpstreamResponse_Timeout(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newAntigravityTestService(&config.Config{
|
||||
Gateway: config.GatewayConfig{StreamDataIntervalTimeout: 1, MaxLineSize: defaultMaxLineSize},
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||
|
||||
result := svc.streamUpstreamResponse(c, resp, time.Now())
|
||||
_ = pw.Close()
|
||||
_ = pr.Close()
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.False(t, result.clientDisconnect)
|
||||
}
|
||||
|
||||
// TestStreamUpstreamResponse_TimeoutAfterClientDisconnect
|
||||
// 验证:客户端断开后上游超时,返回 usage 并标记 clientDisconnect
|
||||
func TestStreamUpstreamResponse_TimeoutAfterClientDisconnect(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newAntigravityTestService(&config.Config{
|
||||
Gateway: config.GatewayConfig{StreamDataIntervalTimeout: 1, MaxLineSize: defaultMaxLineSize},
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||
|
||||
go func() {
|
||||
fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":5}}}`)
|
||||
fmt.Fprintln(pw, "")
|
||||
// 不关闭 pw → 等待超时
|
||||
}()
|
||||
|
||||
result := svc.streamUpstreamResponse(c, resp, time.Now())
|
||||
_ = pw.Close()
|
||||
_ = pr.Close()
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.clientDisconnect)
|
||||
}
|
||||
|
||||
// TestHandleGeminiStreamingResponse_ClientDisconnect
|
||||
// 验证:Gemini 流式转发中客户端断开后继续 drain 上游
|
||||
func TestHandleGeminiStreamingResponse_ClientDisconnect(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newAntigravityTestService(&config.Config{
|
||||
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"hi"}]}}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":10}}`)
|
||||
fmt.Fprintln(pw, "")
|
||||
}()
|
||||
|
||||
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
|
||||
_ = pr.Close()
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.clientDisconnect)
|
||||
require.NotContains(t, rec.Body.String(), "write_failed")
|
||||
}
|
||||
|
||||
// TestHandleGeminiStreamingResponse_ContextCanceled
|
||||
// 验证:context 取消时不注入错误事件
|
||||
func TestHandleGeminiStreamingResponse_ContextCanceled(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newAntigravityTestService(&config.Config{
|
||||
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
|
||||
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
|
||||
|
||||
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.clientDisconnect)
|
||||
require.NotContains(t, rec.Body.String(), "event: error")
|
||||
}
|
||||
|
||||
// TestHandleClaudeStreamingResponse_ClientDisconnect
|
||||
// 验证:Claude 流式转发中客户端断开后继续 drain 上游
|
||||
func TestHandleClaudeStreamingResponse_ClientDisconnect(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newAntigravityTestService(&config.Config{
|
||||
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
// v1internal 包装格式
|
||||
fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"hello"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":8,"candidatesTokenCount":15}}}`)
|
||||
fmt.Fprintln(pw, "")
|
||||
}()
|
||||
|
||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
||||
_ = pr.Close()
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.clientDisconnect)
|
||||
}
|
||||
|
||||
// TestHandleClaudeStreamingResponse_ContextCanceled
|
||||
// 验证:context 取消时不注入错误事件
|
||||
func TestHandleClaudeStreamingResponse_ContextCanceled(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := newAntigravityTestService(&config.Config{
|
||||
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
|
||||
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
|
||||
|
||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.clientDisconnect)
|
||||
require.NotContains(t, rec.Body.String(), "event: error")
|
||||
}
|
||||
|
||||
// TestExtractSSEUsage 验证 extractSSEUsage 从 SSE data 行正确提取 usage
|
||||
func TestExtractSSEUsage(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
tests := []struct {
|
||||
name string
|
||||
line string
|
||||
expected ClaudeUsage
|
||||
}{
|
||||
{
|
||||
name: "message_delta with output_tokens",
|
||||
line: `data: {"type":"message_delta","usage":{"output_tokens":42}}`,
|
||||
expected: ClaudeUsage{OutputTokens: 42},
|
||||
},
|
||||
{
|
||||
name: "non-data line ignored",
|
||||
line: `event: message_start`,
|
||||
expected: ClaudeUsage{},
|
||||
},
|
||||
{
|
||||
name: "top-level usage with all fields",
|
||||
line: `data: {"usage":{"input_tokens":10,"output_tokens":20,"cache_read_input_tokens":5,"cache_creation_input_tokens":3}}`,
|
||||
expected: ClaudeUsage{InputTokens: 10, OutputTokens: 20, CacheReadInputTokens: 5, CacheCreationInputTokens: 3},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
usage := &ClaudeUsage{}
|
||||
svc.extractSSEUsage(tt.line, usage)
|
||||
require.Equal(t, tt.expected, *usage)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAntigravityClientWriter 验证 antigravityClientWriter 的断开检测
|
||||
func TestAntigravityClientWriter(t *testing.T) {
|
||||
t.Run("normal write succeeds", func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
flusher, _ := c.Writer.(http.Flusher)
|
||||
cw := newAntigravityClientWriter(c.Writer, flusher, "test")
|
||||
|
||||
ok := cw.Write([]byte("hello"))
|
||||
require.True(t, ok)
|
||||
require.False(t, cw.Disconnected())
|
||||
require.Contains(t, rec.Body.String(), "hello")
|
||||
})
|
||||
|
||||
t.Run("write failure marks disconnected", func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
fw := &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||
flusher, _ := c.Writer.(http.Flusher)
|
||||
cw := newAntigravityClientWriter(fw, flusher, "test")
|
||||
|
||||
ok := cw.Write([]byte("hello"))
|
||||
require.False(t, ok)
|
||||
require.True(t, cw.Disconnected())
|
||||
})
|
||||
|
||||
t.Run("subsequent writes are no-op", func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
fw := &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||
flusher, _ := c.Writer.(http.Flusher)
|
||||
cw := newAntigravityClientWriter(fw, flusher, "test")
|
||||
|
||||
cw.Write([]byte("first"))
|
||||
ok := cw.Fprintf("second %d", 2)
|
||||
require.False(t, ok)
|
||||
require.True(t, cw.Disconnected())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -192,6 +192,43 @@ func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken
|
||||
return nil, fmt.Errorf("token 刷新失败 (重试后): %w", lastErr)
|
||||
}
|
||||
|
||||
// ValidateRefreshToken 用 refresh token 验证并获取完整的 token 信息(含 email 和 project_id)
|
||||
func (s *AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refreshToken string, proxyID *int64) (*AntigravityTokenInfo, error) {
|
||||
var proxyURL string
|
||||
if proxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
// 刷新 token
|
||||
tokenInfo, err := s.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取用户信息(email)
|
||||
client := antigravity.NewClient(proxyURL)
|
||||
userInfo, err := client.GetUserInfo(ctx, tokenInfo.AccessToken)
|
||||
if err != nil {
|
||||
fmt.Printf("[AntigravityOAuth] 警告: 获取用户信息失败: %v\n", err)
|
||||
} else {
|
||||
tokenInfo.Email = userInfo.Email
|
||||
}
|
||||
|
||||
// 获取 project_id(容错,失败不阻塞)
|
||||
projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3)
|
||||
if loadErr != nil {
|
||||
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败(重试后): %v\n", loadErr)
|
||||
tokenInfo.ProjectIDMissing = true
|
||||
} else {
|
||||
tokenInfo.ProjectID = projectID
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
func isNonRetryableAntigravityOAuthError(err error) bool {
|
||||
msg := err.Error()
|
||||
nonRetryable := []string{
|
||||
@@ -273,12 +310,21 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac
|
||||
}
|
||||
|
||||
client := antigravity.NewClient(proxyURL)
|
||||
loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
|
||||
loadResp, loadRaw, err := client.LoadCodeAssist(ctx, accessToken)
|
||||
|
||||
if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" {
|
||||
return loadResp.CloudAICompanionProject, nil
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
if projectID, onboardErr := tryOnboardProjectID(ctx, client, accessToken, loadRaw); onboardErr == nil && projectID != "" {
|
||||
return projectID, nil
|
||||
} else if onboardErr != nil {
|
||||
lastErr = onboardErr
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// 记录错误
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
@@ -292,6 +338,65 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac
|
||||
return "", fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
func tryOnboardProjectID(ctx context.Context, client *antigravity.Client, accessToken string, loadRaw map[string]any) (string, error) {
|
||||
tierID := resolveDefaultTierID(loadRaw)
|
||||
if tierID == "" {
|
||||
return "", fmt.Errorf("loadCodeAssist 未返回可用的默认 tier")
|
||||
}
|
||||
|
||||
projectID, err := client.OnboardUser(ctx, accessToken, tierID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("onboardUser 失败 (tier=%s): %w", tierID, err)
|
||||
}
|
||||
return projectID, nil
|
||||
}
|
||||
|
||||
func resolveDefaultTierID(loadRaw map[string]any) string {
|
||||
if len(loadRaw) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
rawTiers, ok := loadRaw["allowedTiers"]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
tiers, ok := rawTiers.([]any)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
for _, rawTier := range tiers {
|
||||
tier, ok := rawTier.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if isDefault, _ := tier["isDefault"].(bool); !isDefault {
|
||||
continue
|
||||
}
|
||||
if id, ok := tier["id"].(string); ok {
|
||||
id = strings.TrimSpace(id)
|
||||
if id != "" {
|
||||
return id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// FillProjectID 仅获取 project_id,不刷新 OAuth token
|
||||
func (s *AntigravityOAuthService) FillProjectID(ctx context.Context, account *Account, accessToken string) (string, error) {
|
||||
var proxyURL string
|
||||
if account.ProxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
return s.loadProjectIDWithRetry(ctx, accessToken, proxyURL, 3)
|
||||
}
|
||||
|
||||
// BuildAccountCredentials 构建账户凭证
|
||||
func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any {
|
||||
creds := map[string]any{
|
||||
|
||||
82
backend/internal/service/antigravity_oauth_service_test.go
Normal file
82
backend/internal/service/antigravity_oauth_service_test.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestResolveDefaultTierID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
loadRaw map[string]any
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "nil loadRaw",
|
||||
loadRaw: nil,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "missing allowedTiers",
|
||||
loadRaw: map[string]any{
|
||||
"paidTier": map[string]any{"id": "g1-pro-tier"},
|
||||
},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "empty allowedTiers",
|
||||
loadRaw: map[string]any{"allowedTiers": []any{}},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "tier missing id field",
|
||||
loadRaw: map[string]any{
|
||||
"allowedTiers": []any{
|
||||
map[string]any{"isDefault": true},
|
||||
},
|
||||
},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "allowedTiers but no default",
|
||||
loadRaw: map[string]any{
|
||||
"allowedTiers": []any{
|
||||
map[string]any{"id": "free-tier", "isDefault": false},
|
||||
map[string]any{"id": "standard-tier", "isDefault": false},
|
||||
},
|
||||
},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "default tier found",
|
||||
loadRaw: map[string]any{
|
||||
"allowedTiers": []any{
|
||||
map[string]any{"id": "free-tier", "isDefault": true},
|
||||
map[string]any{"id": "standard-tier", "isDefault": false},
|
||||
},
|
||||
},
|
||||
want: "free-tier",
|
||||
},
|
||||
{
|
||||
name: "default tier id with spaces",
|
||||
loadRaw: map[string]any{
|
||||
"allowedTiers": []any{
|
||||
map[string]any{"id": " standard-tier ", "isDefault": true},
|
||||
},
|
||||
},
|
||||
want: "standard-tier",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := resolveDefaultTierID(tc.loadRaw)
|
||||
if got != tc.want {
|
||||
t.Fatalf("resolveDefaultTierID() = %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2,63 +2,23 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const antigravityQuotaScopesKey = "antigravity_quota_scopes"
|
||||
|
||||
// AntigravityQuotaScope 表示 Antigravity 的配额域
|
||||
type AntigravityQuotaScope string
|
||||
|
||||
const (
|
||||
AntigravityQuotaScopeClaude AntigravityQuotaScope = "claude"
|
||||
AntigravityQuotaScopeGeminiText AntigravityQuotaScope = "gemini_text"
|
||||
AntigravityQuotaScopeGeminiImage AntigravityQuotaScope = "gemini_image"
|
||||
)
|
||||
|
||||
// IsScopeSupported 检查给定的 scope 是否在分组支持的 scope 列表中
|
||||
func IsScopeSupported(supportedScopes []string, scope AntigravityQuotaScope) bool {
|
||||
if len(supportedScopes) == 0 {
|
||||
// 未配置时默认全部支持
|
||||
return true
|
||||
}
|
||||
supported := slices.Contains(supportedScopes, string(scope))
|
||||
return supported
|
||||
}
|
||||
|
||||
// ResolveAntigravityQuotaScope 根据模型名称解析配额域(导出版本)
|
||||
func ResolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) {
|
||||
return resolveAntigravityQuotaScope(requestedModel)
|
||||
}
|
||||
|
||||
// resolveAntigravityQuotaScope 根据模型名称解析配额域
|
||||
func resolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) {
|
||||
model := normalizeAntigravityModelName(requestedModel)
|
||||
if model == "" {
|
||||
return "", false
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(model, "claude-"):
|
||||
return AntigravityQuotaScopeClaude, true
|
||||
case strings.HasPrefix(model, "gemini-"):
|
||||
if isImageGenerationModel(model) {
|
||||
return AntigravityQuotaScopeGeminiImage, true
|
||||
}
|
||||
return AntigravityQuotaScopeGeminiText, true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeAntigravityModelName(model string) string {
|
||||
normalized := strings.ToLower(strings.TrimSpace(model))
|
||||
normalized = strings.TrimPrefix(normalized, "models/")
|
||||
return normalized
|
||||
}
|
||||
|
||||
// IsSchedulableForModel 结合 Antigravity 配额域限流判断是否可调度。
|
||||
// resolveAntigravityModelKey 根据请求的模型名解析限流 key
|
||||
// 返回空字符串表示无法解析
|
||||
func resolveAntigravityModelKey(requestedModel string) string {
|
||||
return normalizeAntigravityModelName(requestedModel)
|
||||
}
|
||||
|
||||
// IsSchedulableForModel 结合模型级限流判断是否可调度。
|
||||
// 保持旧签名以兼容既有调用方;默认使用 context.Background()。
|
||||
func (a *Account) IsSchedulableForModel(requestedModel string) bool {
|
||||
return a.IsSchedulableForModelWithContext(context.Background(), requestedModel)
|
||||
@@ -74,107 +34,20 @@ func (a *Account) IsSchedulableForModelWithContext(ctx context.Context, requeste
|
||||
if a.isModelRateLimitedWithContext(ctx, requestedModel) {
|
||||
return false
|
||||
}
|
||||
if a.Platform != PlatformAntigravity {
|
||||
return true
|
||||
}
|
||||
scope, ok := resolveAntigravityQuotaScope(requestedModel)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
resetAt := a.antigravityQuotaScopeResetAt(scope)
|
||||
if resetAt == nil {
|
||||
return true
|
||||
}
|
||||
now := time.Now()
|
||||
return !now.Before(*resetAt)
|
||||
return true
|
||||
}
|
||||
|
||||
func (a *Account) antigravityQuotaScopeResetAt(scope AntigravityQuotaScope) *time.Time {
|
||||
if a == nil || a.Extra == nil || scope == "" {
|
||||
return nil
|
||||
}
|
||||
rawScopes, ok := a.Extra[antigravityQuotaScopesKey].(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
rawScope, ok := rawScopes[string(scope)].(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
resetAtRaw, ok := rawScope["rate_limit_reset_at"].(string)
|
||||
if !ok || strings.TrimSpace(resetAtRaw) == "" {
|
||||
return nil
|
||||
}
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtRaw)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return &resetAt
|
||||
}
|
||||
|
||||
var antigravityAllScopes = []AntigravityQuotaScope{
|
||||
AntigravityQuotaScopeClaude,
|
||||
AntigravityQuotaScopeGeminiText,
|
||||
AntigravityQuotaScopeGeminiImage,
|
||||
}
|
||||
|
||||
func (a *Account) GetAntigravityScopeRateLimits() map[string]int64 {
|
||||
if a == nil || a.Platform != PlatformAntigravity {
|
||||
return nil
|
||||
}
|
||||
now := time.Now()
|
||||
result := make(map[string]int64)
|
||||
for _, scope := range antigravityAllScopes {
|
||||
resetAt := a.antigravityQuotaScopeResetAt(scope)
|
||||
if resetAt != nil && now.Before(*resetAt) {
|
||||
remainingSec := int64(time.Until(*resetAt).Seconds())
|
||||
if remainingSec > 0 {
|
||||
result[string(scope)] = remainingSec
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetQuotaScopeRateLimitRemainingTime 获取模型域限流剩余时间
|
||||
// 返回 0 表示未限流或已过期
|
||||
func (a *Account) GetQuotaScopeRateLimitRemainingTime(requestedModel string) time.Duration {
|
||||
if a == nil || a.Platform != PlatformAntigravity {
|
||||
return 0
|
||||
}
|
||||
scope, ok := resolveAntigravityQuotaScope(requestedModel)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
resetAt := a.antigravityQuotaScopeResetAt(scope)
|
||||
if resetAt == nil {
|
||||
return 0
|
||||
}
|
||||
if remaining := time.Until(*resetAt); remaining > 0 {
|
||||
return remaining
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetRateLimitRemainingTime 获取限流剩余时间(模型限流和模型域限流取最大值)
|
||||
// GetRateLimitRemainingTime 获取限流剩余时间(模型级限流)
|
||||
// 返回 0 表示未限流或已过期
|
||||
func (a *Account) GetRateLimitRemainingTime(requestedModel string) time.Duration {
|
||||
return a.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel)
|
||||
}
|
||||
|
||||
// GetRateLimitRemainingTimeWithContext 获取限流剩余时间(模型限流和模型域限流取最大值)
|
||||
// GetRateLimitRemainingTimeWithContext 获取限流剩余时间(模型级限流)
|
||||
// 返回 0 表示未限流或已过期
|
||||
func (a *Account) GetRateLimitRemainingTimeWithContext(ctx context.Context, requestedModel string) time.Duration {
|
||||
if a == nil {
|
||||
return 0
|
||||
}
|
||||
modelRemaining := a.GetModelRateLimitRemainingTimeWithContext(ctx, requestedModel)
|
||||
scopeRemaining := a.GetQuotaScopeRateLimitRemainingTime(requestedModel)
|
||||
if modelRemaining > scopeRemaining {
|
||||
return modelRemaining
|
||||
}
|
||||
return scopeRemaining
|
||||
return a.GetModelRateLimitRemainingTimeWithContext(ctx, requestedModel)
|
||||
}
|
||||
|
||||
@@ -59,12 +59,6 @@ func (s *stubAntigravityUpstream) DoWithTLS(req *http.Request, proxyURL string,
|
||||
return s.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
}
|
||||
|
||||
type scopeLimitCall struct {
|
||||
accountID int64
|
||||
scope AntigravityQuotaScope
|
||||
resetAt time.Time
|
||||
}
|
||||
|
||||
type rateLimitCall struct {
|
||||
accountID int64
|
||||
resetAt time.Time
|
||||
@@ -78,16 +72,10 @@ type modelRateLimitCall struct {
|
||||
|
||||
type stubAntigravityAccountRepo struct {
|
||||
AccountRepository
|
||||
scopeCalls []scopeLimitCall
|
||||
rateCalls []rateLimitCall
|
||||
modelRateLimitCalls []modelRateLimitCall
|
||||
}
|
||||
|
||||
func (s *stubAntigravityAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
|
||||
s.scopeCalls = append(s.scopeCalls, scopeLimitCall{accountID: id, scope: scope, resetAt: resetAt})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
s.rateCalls = append(s.rateCalls, rateLimitCall{accountID: id, resetAt: resetAt})
|
||||
return nil
|
||||
@@ -98,7 +86,9 @@ func (s *stubAntigravityAccountRepo) SetModelRateLimit(ctx context.Context, id i
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
|
||||
func TestAntigravityRetryLoop_NoURLFallback_UsesConfiguredBaseURL(t *testing.T) {
|
||||
t.Setenv(antigravityForwardBaseURLEnv, "")
|
||||
|
||||
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
|
||||
oldAvailability := antigravity.DefaultURLAvailability
|
||||
defer func() {
|
||||
@@ -131,10 +121,9 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
quotaScope: AntigravityQuotaScopeClaude,
|
||||
httpUpstream: upstream,
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleErrorCalled = true
|
||||
return nil
|
||||
},
|
||||
@@ -144,32 +133,16 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.resp)
|
||||
defer func() { _ = result.resp.Body.Close() }()
|
||||
require.Equal(t, http.StatusOK, result.resp.StatusCode)
|
||||
require.False(t, handleErrorCalled)
|
||||
require.Len(t, upstream.calls, 2)
|
||||
require.True(t, strings.HasPrefix(upstream.calls[0], base1))
|
||||
require.True(t, strings.HasPrefix(upstream.calls[1], base2))
|
||||
require.Equal(t, http.StatusTooManyRequests, result.resp.StatusCode)
|
||||
require.True(t, handleErrorCalled)
|
||||
require.Len(t, upstream.calls, antigravityMaxRetries)
|
||||
for _, callURL := range upstream.calls {
|
||||
require.True(t, strings.HasPrefix(callURL, base1))
|
||||
}
|
||||
|
||||
available := antigravity.DefaultURLAvailability.GetAvailableURLs()
|
||||
require.NotEmpty(t, available)
|
||||
require.Equal(t, base2, available[0])
|
||||
}
|
||||
|
||||
func TestAntigravityHandleUpstreamError_UsesScopeLimit(t *testing.T) {
|
||||
// 分区限流始终开启,不再支持通过环境变量关闭
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
account := &Account{ID: 9, Name: "acc-9", Platform: PlatformAntigravity}
|
||||
|
||||
body := buildGeminiRateLimitBody("3s")
|
||||
svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false)
|
||||
|
||||
require.Len(t, repo.scopeCalls, 1)
|
||||
require.Empty(t, repo.rateCalls)
|
||||
call := repo.scopeCalls[0]
|
||||
require.Equal(t, account.ID, call.accountID)
|
||||
require.Equal(t, AntigravityQuotaScopeClaude, call.scope)
|
||||
require.WithinDuration(t, time.Now().Add(3*time.Second), call.resetAt, 2*time.Second)
|
||||
require.Equal(t, base1, available[0])
|
||||
}
|
||||
|
||||
// TestHandleUpstreamError_429_ModelRateLimit 测试 429 模型限流场景
|
||||
@@ -189,7 +162,7 @@ func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) {
|
||||
}
|
||||
}`)
|
||||
|
||||
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false)
|
||||
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, "claude-sonnet-4-5", 0, "", false)
|
||||
|
||||
// 应该触发模型限流
|
||||
require.NotNil(t, result)
|
||||
@@ -200,31 +173,32 @@ func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) {
|
||||
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
|
||||
}
|
||||
|
||||
// TestHandleUpstreamError_429_NonModelRateLimit 测试 429 非模型限流场景(走 scope 限流)
|
||||
// TestHandleUpstreamError_429_NonModelRateLimit 测试 429 非模型限流场景(走模型级限流兜底)
|
||||
func TestHandleUpstreamError_429_NonModelRateLimit(t *testing.T) {
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
account := &Account{ID: 2, Name: "acc-2", Platform: PlatformAntigravity}
|
||||
|
||||
// 429 + 普通限流响应(无 RATE_LIMIT_EXCEEDED reason)→ scope 限流
|
||||
// 429 + 普通限流响应(无 RATE_LIMIT_EXCEEDED reason)→ 走模型级限流兜底
|
||||
body := buildGeminiRateLimitBody("5s")
|
||||
|
||||
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false)
|
||||
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, "claude-sonnet-4-5", 0, "", false)
|
||||
|
||||
// 不应该触发模型限流,应该走 scope 限流
|
||||
// handleModelRateLimit 不会处理(因为没有 RATE_LIMIT_EXCEEDED),
|
||||
// 但 429 兜底逻辑会使用 requestedModel 设置模型级限流
|
||||
require.Nil(t, result)
|
||||
require.Empty(t, repo.modelRateLimitCalls)
|
||||
require.Len(t, repo.scopeCalls, 1)
|
||||
require.Equal(t, AntigravityQuotaScopeClaude, repo.scopeCalls[0].scope)
|
||||
require.Len(t, repo.modelRateLimitCalls, 1)
|
||||
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
|
||||
}
|
||||
|
||||
// TestHandleUpstreamError_503_ModelRateLimit 测试 503 模型限流场景
|
||||
func TestHandleUpstreamError_503_ModelRateLimit(t *testing.T) {
|
||||
// TestHandleUpstreamError_503_ModelCapacityExhausted 测试 503 模型容量不足场景
|
||||
// MODEL_CAPACITY_EXHAUSTED 时应等待重试,不切换账号
|
||||
func TestHandleUpstreamError_503_ModelCapacityExhausted(t *testing.T) {
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
account := &Account{ID: 3, Name: "acc-3", Platform: PlatformAntigravity}
|
||||
|
||||
// 503 + MODEL_CAPACITY_EXHAUSTED → 模型限流
|
||||
// 503 + MODEL_CAPACITY_EXHAUSTED → 等待重试,不切换账号
|
||||
body := []byte(`{
|
||||
"error": {
|
||||
"status": "UNAVAILABLE",
|
||||
@@ -235,15 +209,15 @@ func TestHandleUpstreamError_503_ModelRateLimit(t *testing.T) {
|
||||
}
|
||||
}`)
|
||||
|
||||
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false)
|
||||
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false)
|
||||
|
||||
// 应该触发模型限流
|
||||
// MODEL_CAPACITY_EXHAUSTED 应该标记为已处理,不切换账号,不设置模型限流
|
||||
// 实际重试由 handleSmartRetry 处理
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.Handled)
|
||||
require.NotNil(t, result.SwitchError)
|
||||
require.Equal(t, "gemini-3-pro-high", result.SwitchError.RateLimitedModel)
|
||||
require.Len(t, repo.modelRateLimitCalls, 1)
|
||||
require.Equal(t, "gemini-3-pro-high", repo.modelRateLimitCalls[0].modelKey)
|
||||
require.False(t, result.ShouldRetry, "MODEL_CAPACITY_EXHAUSTED should not trigger retry from handleModelRateLimit path")
|
||||
require.Nil(t, result.SwitchError, "MODEL_CAPACITY_EXHAUSTED should not trigger account switch")
|
||||
require.Empty(t, repo.modelRateLimitCalls, "MODEL_CAPACITY_EXHAUSTED should not set model rate limit")
|
||||
}
|
||||
|
||||
// TestHandleUpstreamError_503_NonModelRateLimit 测试 503 非模型限流场景(不处理)
|
||||
@@ -263,12 +237,11 @@ func TestHandleUpstreamError_503_NonModelRateLimit(t *testing.T) {
|
||||
}
|
||||
}`)
|
||||
|
||||
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false)
|
||||
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false)
|
||||
|
||||
// 503 非模型限流不应该做任何处理
|
||||
require.Nil(t, result)
|
||||
require.Empty(t, repo.modelRateLimitCalls, "503 non-model rate limit should not trigger model rate limit")
|
||||
require.Empty(t, repo.scopeCalls, "503 non-model rate limit should not trigger scope rate limit")
|
||||
require.Empty(t, repo.rateCalls, "503 non-model rate limit should not trigger account rate limit")
|
||||
}
|
||||
|
||||
@@ -281,12 +254,11 @@ func TestHandleUpstreamError_503_EmptyBody(t *testing.T) {
|
||||
// 503 + 空响应体 → 不做任何处理
|
||||
body := []byte(`{}`)
|
||||
|
||||
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false)
|
||||
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false)
|
||||
|
||||
// 503 空响应不应该做任何处理
|
||||
require.Nil(t, result)
|
||||
require.Empty(t, repo.modelRateLimitCalls)
|
||||
require.Empty(t, repo.scopeCalls)
|
||||
require.Empty(t, repo.rateCalls)
|
||||
}
|
||||
|
||||
@@ -307,15 +279,7 @@ func TestAccountIsSchedulableForModel_AntigravityRateLimits(t *testing.T) {
|
||||
require.False(t, account.IsSchedulableForModel("gemini-3-flash"))
|
||||
|
||||
account.RateLimitResetAt = nil
|
||||
account.Extra = map[string]any{
|
||||
antigravityQuotaScopesKey: map[string]any{
|
||||
"claude": map[string]any{
|
||||
"rate_limit_reset_at": future.Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
require.False(t, account.IsSchedulableForModel("claude-sonnet-4-5"))
|
||||
require.True(t, account.IsSchedulableForModel("claude-sonnet-4-5"))
|
||||
require.True(t, account.IsSchedulableForModel("gemini-3-flash"))
|
||||
}
|
||||
|
||||
@@ -341,11 +305,12 @@ func TestParseGeminiRateLimitResetTime_QuotaResetDelay_RoundsUp(t *testing.T) {
|
||||
|
||||
func TestParseAntigravitySmartRetryInfo(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
expectedDelay time.Duration
|
||||
expectedModel string
|
||||
expectedNil bool
|
||||
name string
|
||||
body string
|
||||
expectedDelay time.Duration
|
||||
expectedModel string
|
||||
expectedNil bool
|
||||
expectedIsModelCapacityExhausted bool
|
||||
}{
|
||||
{
|
||||
name: "valid complete response with RATE_LIMIT_EXCEEDED",
|
||||
@@ -408,8 +373,9 @@ func TestParseAntigravitySmartRetryInfo(t *testing.T) {
|
||||
"message": "No capacity available for model gemini-3-pro-high on the server"
|
||||
}
|
||||
}`,
|
||||
expectedDelay: 39 * time.Second,
|
||||
expectedModel: "gemini-3-pro-high",
|
||||
expectedDelay: 39 * time.Second,
|
||||
expectedModel: "gemini-3-pro-high",
|
||||
expectedIsModelCapacityExhausted: true,
|
||||
},
|
||||
{
|
||||
name: "503 UNAVAILABLE without MODEL_CAPACITY_EXHAUSTED - should return nil",
|
||||
@@ -520,6 +486,9 @@ func TestParseAntigravitySmartRetryInfo(t *testing.T) {
|
||||
if result.ModelName != tt.expectedModel {
|
||||
t.Errorf("ModelName = %q, want %q", result.ModelName, tt.expectedModel)
|
||||
}
|
||||
if result.IsModelCapacityExhausted != tt.expectedIsModelCapacityExhausted {
|
||||
t.Errorf("IsModelCapacityExhausted = %v, want %v", result.IsModelCapacityExhausted, tt.expectedIsModelCapacityExhausted)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -531,13 +500,14 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
|
||||
apiKeyAccount := &Account{Type: AccountTypeAPIKey}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
body string
|
||||
expectedShouldRetry bool
|
||||
expectedShouldRateLimit bool
|
||||
minWait time.Duration
|
||||
modelName string
|
||||
name string
|
||||
account *Account
|
||||
body string
|
||||
expectedShouldRetry bool
|
||||
expectedShouldRateLimit bool
|
||||
expectedIsModelCapacityExhausted bool
|
||||
minWait time.Duration
|
||||
modelName string
|
||||
}{
|
||||
{
|
||||
name: "OAuth account with short delay (< 7s) - smart retry",
|
||||
@@ -635,6 +605,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
|
||||
}`,
|
||||
expectedShouldRetry: false,
|
||||
expectedShouldRateLimit: true,
|
||||
minWait: 7 * time.Second,
|
||||
modelName: "gemini-pro",
|
||||
},
|
||||
{
|
||||
@@ -650,12 +621,14 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
|
||||
]
|
||||
}
|
||||
}`,
|
||||
expectedShouldRetry: false,
|
||||
expectedShouldRateLimit: true,
|
||||
modelName: "gemini-3-pro-high",
|
||||
expectedShouldRetry: true,
|
||||
expectedShouldRateLimit: false,
|
||||
expectedIsModelCapacityExhausted: true,
|
||||
minWait: 1 * time.Second,
|
||||
modelName: "gemini-3-pro-high",
|
||||
},
|
||||
{
|
||||
name: "503 UNAVAILABLE with MODEL_CAPACITY_EXHAUSTED - no retryDelay - use default rate limit",
|
||||
name: "503 UNAVAILABLE with MODEL_CAPACITY_EXHAUSTED - no retryDelay - use fixed wait",
|
||||
account: oauthAccount,
|
||||
body: `{
|
||||
"error": {
|
||||
@@ -667,9 +640,11 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
|
||||
"message": "No capacity available for model gemini-2.5-flash on the server"
|
||||
}
|
||||
}`,
|
||||
expectedShouldRetry: false,
|
||||
expectedShouldRateLimit: true,
|
||||
modelName: "gemini-2.5-flash",
|
||||
expectedShouldRetry: true,
|
||||
expectedShouldRateLimit: false,
|
||||
expectedIsModelCapacityExhausted: true,
|
||||
minWait: 1 * time.Second,
|
||||
modelName: "gemini-2.5-flash",
|
||||
},
|
||||
{
|
||||
name: "429 RESOURCE_EXHAUSTED with RATE_LIMIT_EXCEEDED - no retryDelay - use default rate limit",
|
||||
@@ -686,24 +661,33 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
|
||||
}`,
|
||||
expectedShouldRetry: false,
|
||||
expectedShouldRateLimit: true,
|
||||
minWait: 30 * time.Second,
|
||||
modelName: "claude-sonnet-4-5",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
shouldRetry, shouldRateLimit, wait, model := shouldTriggerAntigravitySmartRetry(tt.account, []byte(tt.body))
|
||||
shouldRetry, shouldRateLimit, wait, model, isModelCapacityExhausted := shouldTriggerAntigravitySmartRetry(tt.account, []byte(tt.body))
|
||||
if shouldRetry != tt.expectedShouldRetry {
|
||||
t.Errorf("shouldRetry = %v, want %v", shouldRetry, tt.expectedShouldRetry)
|
||||
}
|
||||
if shouldRateLimit != tt.expectedShouldRateLimit {
|
||||
t.Errorf("shouldRateLimit = %v, want %v", shouldRateLimit, tt.expectedShouldRateLimit)
|
||||
}
|
||||
if isModelCapacityExhausted != tt.expectedIsModelCapacityExhausted {
|
||||
t.Errorf("isModelCapacityExhausted = %v, want %v", isModelCapacityExhausted, tt.expectedIsModelCapacityExhausted)
|
||||
}
|
||||
if shouldRetry {
|
||||
if wait < tt.minWait {
|
||||
t.Errorf("wait = %v, want >= %v", wait, tt.minWait)
|
||||
}
|
||||
}
|
||||
if shouldRateLimit && tt.minWait > 0 {
|
||||
if wait < tt.minWait {
|
||||
t.Errorf("rate limit wait = %v, want >= %v", wait, tt.minWait)
|
||||
}
|
||||
}
|
||||
if (shouldRetry || shouldRateLimit) && model != tt.modelName {
|
||||
t.Errorf("modelName = %q, want %q", model, tt.modelName)
|
||||
}
|
||||
@@ -803,7 +787,7 @@ func TestSetModelRateLimitByModelName_NotConvertToScope(t *testing.T) {
|
||||
require.NotEqual(t, "claude_sonnet", call.modelKey, "should NOT be scope")
|
||||
}
|
||||
|
||||
func TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold(t *testing.T) {
|
||||
func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRateLimited(t *testing.T) {
|
||||
upstream := &recordingOKUpstream{}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
@@ -815,19 +799,15 @@ func TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold(t *testi
|
||||
Extra: map[string]any{
|
||||
modelRateLimitsKey: map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
// RFC3339 here is second-precision; keep it safely in the future.
|
||||
"rate_limit_reset_at": time.Now().Add(2 * time.Second).Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||
ctx: ctx,
|
||||
ctx: context.Background(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
@@ -836,17 +816,21 @@ func TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold(t *testi
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
httpUpstream: upstream,
|
||||
isStickySession: true,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
require.Nil(t, result)
|
||||
require.Equal(t, 0, upstream.calls, "should not call upstream while waiting on pre-check")
|
||||
var switchErr *AntigravityAccountSwitchError
|
||||
require.ErrorAs(t, err, &switchErr)
|
||||
require.Equal(t, account.ID, switchErr.OriginalAccountID)
|
||||
require.Equal(t, "claude-sonnet-4-5", switchErr.RateLimitedModel)
|
||||
require.True(t, switchErr.IsStickySession)
|
||||
require.Equal(t, 0, upstream.calls, "should not call upstream when switching on pre-check")
|
||||
}
|
||||
|
||||
func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemainingAtOrAboveThreshold(t *testing.T) {
|
||||
func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemainingLong(t *testing.T) {
|
||||
upstream := &recordingOKUpstream{}
|
||||
account := &Account{
|
||||
ID: 2,
|
||||
@@ -875,7 +859,7 @@ func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemainingAtOrAboveThreshold(t
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
httpUpstream: upstream,
|
||||
isStickySession: true,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
@@ -946,6 +930,22 @@ func TestIsAntigravityAccountSwitchError(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveAntigravityForwardBaseURL_DefaultDaily(t *testing.T) {
|
||||
t.Setenv(antigravityForwardBaseURLEnv, "")
|
||||
|
||||
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
|
||||
defer func() {
|
||||
antigravity.BaseURLs = oldBaseURLs
|
||||
}()
|
||||
|
||||
prodURL := "https://prod.test"
|
||||
dailyURL := "https://daily.test"
|
||||
antigravity.BaseURLs = []string{dailyURL, prodURL}
|
||||
|
||||
resolved := resolveAntigravityForwardBaseURL()
|
||||
require.Equal(t, dailyURL, resolved)
|
||||
}
|
||||
|
||||
func TestAntigravityAccountSwitchError_Error(t *testing.T) {
|
||||
err := &AntigravityAccountSwitchError{
|
||||
OriginalAccountID: 789,
|
||||
|
||||
@@ -0,0 +1,904 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 辅助函数:构造带 SingleAccountRetry 标记的 context
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func ctxWithSingleAccountRetry() context.Context {
|
||||
return context.WithValue(context.Background(), ctxkey.SingleAccountRetry, true)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 1. isSingleAccountRetry 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestIsSingleAccountRetry_True(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), ctxkey.SingleAccountRetry, true)
|
||||
require.True(t, isSingleAccountRetry(ctx))
|
||||
}
|
||||
|
||||
func TestIsSingleAccountRetry_False_NoValue(t *testing.T) {
|
||||
require.False(t, isSingleAccountRetry(context.Background()))
|
||||
}
|
||||
|
||||
func TestIsSingleAccountRetry_False_ExplicitFalse(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), ctxkey.SingleAccountRetry, false)
|
||||
require.False(t, isSingleAccountRetry(ctx))
|
||||
}
|
||||
|
||||
func TestIsSingleAccountRetry_False_WrongType(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), ctxkey.SingleAccountRetry, "true")
|
||||
require.False(t, isSingleAccountRetry(ctx))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 2. 常量验证
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestSingleAccountRetryConstants(t *testing.T) {
|
||||
require.Equal(t, 3, antigravitySingleAccountSmartRetryMaxAttempts,
|
||||
"单账号原地重试最多 3 次")
|
||||
require.Equal(t, 15*time.Second, antigravitySingleAccountSmartRetryMaxWait,
|
||||
"单次最大等待 15s")
|
||||
require.Equal(t, 30*time.Second, antigravitySingleAccountSmartRetryTotalMaxWait,
|
||||
"总累计等待不超过 30s")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 3. handleSmartRetry + 503 + SingleAccountRetry → 走 handleSingleAccountRetryInPlace
|
||||
// (而非设模型限流 + 切换账号)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TestHandleSmartRetry_503_LongDelay_SingleAccountRetry_RetryInPlace
|
||||
// 核心场景:503 + retryDelay >= 7s + SingleAccountRetry 标记
|
||||
// → 不设模型限流、不切换账号,改为原地重试
|
||||
func TestHandleSmartRetry_503_LongDelay_SingleAccountRetry_RetryInPlace(t *testing.T) {
|
||||
// 原地重试成功
|
||||
successResp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
|
||||
}
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: []*http.Response{successResp},
|
||||
errors: []error{nil},
|
||||
}
|
||||
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "acc-single",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Concurrency: 1,
|
||||
}
|
||||
|
||||
// 503 + 39s >= 7s 阈值 + MODEL_CAPACITY_EXHAUSTED
|
||||
respBody := []byte(`{
|
||||
"error": {
|
||||
"code": 503,
|
||||
"status": "UNAVAILABLE",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"}
|
||||
],
|
||||
"message": "No capacity available for model gemini-3-pro-high on the server"
|
||||
}
|
||||
}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: ctxWithSingleAccountRetry(), // 关键:设置单账号标记
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
accountRepo: repo,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
availableURLs := []string{"https://ag-1.test"}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
// 关键断言:返回 resp(原地重试成功),而非 switchError(切换账号)
|
||||
require.NotNil(t, result.resp, "should return successful response from in-place retry")
|
||||
require.Equal(t, http.StatusOK, result.resp.StatusCode)
|
||||
require.Nil(t, result.switchError, "should NOT return switchError in single account mode")
|
||||
require.Nil(t, result.err)
|
||||
|
||||
// 验证未设模型限流(单账号模式不应设限流)
|
||||
require.Len(t, repo.modelRateLimitCalls, 0,
|
||||
"should NOT set model rate limit in single account retry mode")
|
||||
|
||||
// 验证确实调用了 upstream(原地重试)
|
||||
require.GreaterOrEqual(t, len(upstream.calls), 1, "should have made at least one retry call")
|
||||
}
|
||||
|
||||
// TestHandleSmartRetry_503_LongDelay_NoSingleAccountRetry_StillSwitches
|
||||
// 对照组:503 + retryDelay >= 7s + 无 SingleAccountRetry 标记
|
||||
// → 照常设模型限流 + 切换账号
|
||||
func TestHandleSmartRetry_503_LongDelay_NoSingleAccountRetry_StillSwitches(t *testing.T) {
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
account := &Account{
|
||||
ID: 2,
|
||||
Name: "acc-multi",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
}
|
||||
|
||||
// 503 + 39s >= 7s 阈值(使用 RATE_LIMIT_EXCEEDED 而非 MODEL_CAPACITY_EXHAUSTED,
|
||||
// 因为 MODEL_CAPACITY_EXHAUSTED 走独立的重试路径,不触发 shouldRateLimitModel)
|
||||
respBody := []byte(`{
|
||||
"error": {
|
||||
"code": 503,
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: context.Background(), // 关键:无单账号标记
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
accountRepo: repo,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
availableURLs := []string{"https://ag-1.test"}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
// 对照:多账号模式返回 switchError
|
||||
require.NotNil(t, result.switchError, "multi-account mode should return switchError for 503")
|
||||
require.Nil(t, result.resp, "should not return resp when switchError is set")
|
||||
|
||||
// 对照:多账号模式应设模型限流
|
||||
require.Len(t, repo.modelRateLimitCalls, 1,
|
||||
"multi-account mode SHOULD set model rate limit")
|
||||
}
|
||||
|
||||
// TestHandleSmartRetry_429_LongDelay_SingleAccountRetry_StillSwitches
|
||||
// 边界情况:429(非 503)+ SingleAccountRetry 标记
|
||||
// → 单账号原地重试仅针对 503,429 依然走切换账号逻辑
|
||||
func TestHandleSmartRetry_429_LongDelay_SingleAccountRetry_StillSwitches(t *testing.T) {
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
account := &Account{
|
||||
ID: 3,
|
||||
Name: "acc-429",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
}
|
||||
|
||||
// 429 + 15s >= 7s 阈值
|
||||
respBody := []byte(`{
|
||||
"error": {
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "15s"}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusTooManyRequests, // 429,不是 503
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: ctxWithSingleAccountRetry(), // 有单账号标记
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
accountRepo: repo,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
availableURLs := []string{"https://ag-1.test"}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
// 429 即使有单账号标记,也应走切换账号
|
||||
require.NotNil(t, result.switchError, "429 should still return switchError even with SingleAccountRetry")
|
||||
require.Len(t, repo.modelRateLimitCalls, 1,
|
||||
"429 should still set model rate limit even with SingleAccountRetry")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 4. handleSmartRetry + 503 + 短延迟 + SingleAccountRetry → 智能重试耗尽后不设限流
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit
|
||||
// 503 + retryDelay < 7s + SingleAccountRetry → 智能重试耗尽后直接返回 503,不设限流
|
||||
func TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit(t *testing.T) {
|
||||
// 智能重试也返回 503
|
||||
failRespBody := `{
|
||||
"error": {
|
||||
"code": 503,
|
||||
"status": "UNAVAILABLE",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||
]
|
||||
}
|
||||
}`
|
||||
failResp := &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(failRespBody)),
|
||||
}
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: []*http.Response{failResp},
|
||||
errors: []error{nil},
|
||||
}
|
||||
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
account := &Account{
|
||||
ID: 4,
|
||||
Name: "acc-short-503",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
}
|
||||
|
||||
// 0.1s < 7s 阈值
|
||||
respBody := []byte(`{
|
||||
"error": {
|
||||
"code": 503,
|
||||
"status": "UNAVAILABLE",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: ctxWithSingleAccountRetry(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
accountRepo: repo,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
availableURLs := []string{"https://ag-1.test"}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
// 关键断言:单账号 503 模式下,智能重试耗尽后直接返回 503 响应,不切换
|
||||
require.NotNil(t, result.resp, "should return 503 response directly for single account mode")
|
||||
require.Equal(t, http.StatusServiceUnavailable, result.resp.StatusCode)
|
||||
require.Nil(t, result.switchError, "should NOT switch account in single account mode")
|
||||
|
||||
// 关键断言:不设模型限流
|
||||
require.Len(t, repo.modelRateLimitCalls, 0,
|
||||
"should NOT set model rate limit for 503 in single account mode")
|
||||
}
|
||||
|
||||
// TestHandleSmartRetry_503_ShortDelay_NoSingleAccountRetry_SetsRateLimit
|
||||
// 对照组:503 + retryDelay < 7s + 无 SingleAccountRetry → 智能重试耗尽后照常设限流
|
||||
// 使用 RATE_LIMIT_EXCEEDED 而非 MODEL_CAPACITY_EXHAUSTED,因为后者走独立的 60 次重试路径
|
||||
func TestHandleSmartRetry_503_ShortDelay_NoSingleAccountRetry_SetsRateLimit(t *testing.T) {
|
||||
failRespBody := `{
|
||||
"error": {
|
||||
"code": 503,
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||
]
|
||||
}
|
||||
}`
|
||||
failResp := &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(failRespBody)),
|
||||
}
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: []*http.Response{failResp},
|
||||
errors: []error{nil},
|
||||
}
|
||||
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
account := &Account{
|
||||
ID: 5,
|
||||
Name: "acc-multi-503",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
}
|
||||
|
||||
respBody := []byte(`{
|
||||
"error": {
|
||||
"code": 503,
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: context.Background(), // 无单账号标记
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
accountRepo: repo,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
availableURLs := []string{"https://ag-1.test"}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
// 对照:多账号模式应返回 switchError
|
||||
require.NotNil(t, result.switchError, "multi-account mode should return switchError for 503")
|
||||
// 对照:多账号模式应设模型限流
|
||||
require.Len(t, repo.modelRateLimitCalls, 1,
|
||||
"multi-account mode should set model rate limit")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 5. handleSingleAccountRetryInPlace 直接测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TestHandleSingleAccountRetryInPlace_Success 原地重试成功
|
||||
func TestHandleSingleAccountRetryInPlace_Success(t *testing.T) {
|
||||
successResp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
|
||||
}
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: []*http.Response{successResp},
|
||||
errors: []error{nil},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 10,
|
||||
Name: "acc-inplace-ok",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Concurrency: 1,
|
||||
}
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: ctxWithSingleAccountRetry(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 1*time.Second, "gemini-3-pro")
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
require.NotNil(t, result.resp, "should return successful response")
|
||||
require.Equal(t, http.StatusOK, result.resp.StatusCode)
|
||||
require.Nil(t, result.switchError, "should not switch account on success")
|
||||
require.Nil(t, result.err)
|
||||
}
|
||||
|
||||
// TestHandleSingleAccountRetryInPlace_AllRetriesFail 所有重试都失败,返回 503(不设限流)
|
||||
func TestHandleSingleAccountRetryInPlace_AllRetriesFail(t *testing.T) {
|
||||
// 构造 3 个 503 响应(对应 3 次原地重试)
|
||||
var responses []*http.Response
|
||||
var errors []error
|
||||
for i := 0; i < antigravitySingleAccountSmartRetryMaxAttempts; i++ {
|
||||
responses = append(responses, &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(`{
|
||||
"error": {
|
||||
"code": 503,
|
||||
"status": "UNAVAILABLE",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||
]
|
||||
}
|
||||
}`)),
|
||||
})
|
||||
errors = append(errors, nil)
|
||||
}
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: responses,
|
||||
errors: errors,
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 11,
|
||||
Name: "acc-inplace-fail",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Concurrency: 1,
|
||||
}
|
||||
|
||||
origBody := []byte(`{"error":{"code":503,"status":"UNAVAILABLE"}}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Header: http.Header{"X-Test": {"original"}},
|
||||
}
|
||||
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: ctxWithSingleAccountRetry(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result := svc.handleSingleAccountRetryInPlace(params, resp, origBody, "https://ag-1.test", 1*time.Second, "gemini-3-pro")
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
// 关键:返回 503 resp,不返回 switchError
|
||||
require.NotNil(t, result.resp, "should return 503 response directly")
|
||||
require.Equal(t, http.StatusServiceUnavailable, result.resp.StatusCode)
|
||||
require.Nil(t, result.switchError, "should NOT return switchError - let Handler handle it")
|
||||
require.Nil(t, result.err)
|
||||
|
||||
// 验证确实重试了指定次数
|
||||
require.Len(t, upstream.calls, antigravitySingleAccountSmartRetryMaxAttempts,
|
||||
"should have made exactly maxAttempts retry calls")
|
||||
}
|
||||
|
||||
// TestHandleSingleAccountRetryInPlace_WaitDurationClamped 等待时间被限制在 [min, max] 范围
|
||||
func TestHandleSingleAccountRetryInPlace_WaitDurationClamped(t *testing.T) {
|
||||
// 用短延迟的成功响应,只验证不 panic
|
||||
successResp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
|
||||
}
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: []*http.Response{successResp},
|
||||
errors: []error{nil},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 12,
|
||||
Name: "acc-clamp",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Concurrency: 1,
|
||||
}
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: ctxWithSingleAccountRetry(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
// 等待时间过大应被 clamp 到 antigravitySingleAccountSmartRetryMaxWait
|
||||
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 999*time.Second, "gemini-3-pro")
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
require.NotNil(t, result.resp)
|
||||
require.Equal(t, http.StatusOK, result.resp.StatusCode)
|
||||
}
|
||||
|
||||
// TestHandleSingleAccountRetryInPlace_ContextCanceled context 取消时立即返回
|
||||
func TestHandleSingleAccountRetryInPlace_ContextCanceled(t *testing.T) {
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: []*http.Response{nil},
|
||||
errors: []error{nil},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 13,
|
||||
Name: "acc-cancel",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Concurrency: 1,
|
||||
}
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx = context.WithValue(ctx, ctxkey.SingleAccountRetry, true)
|
||||
cancel() // 立即取消
|
||||
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: ctx,
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 1*time.Second, "gemini-3-pro")
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
require.Error(t, result.err, "should return context error")
|
||||
// 不应调用 upstream(因为在等待阶段就被取消了)
|
||||
require.Len(t, upstream.calls, 0, "should not call upstream when context is canceled")
|
||||
}
|
||||
|
||||
// TestHandleSingleAccountRetryInPlace_NetworkError_ContinuesRetry 网络错误时继续重试
|
||||
func TestHandleSingleAccountRetryInPlace_NetworkError_ContinuesRetry(t *testing.T) {
|
||||
successResp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
|
||||
}
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
// 第1次网络错误(nil resp),第2次成功
|
||||
responses: []*http.Response{nil, successResp},
|
||||
errors: []error{nil, nil},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 14,
|
||||
Name: "acc-net-retry",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Concurrency: 1,
|
||||
}
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: ctxWithSingleAccountRetry(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 1*time.Second, "gemini-3-pro")
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
require.NotNil(t, result.resp, "should return successful response after network error recovery")
|
||||
require.Equal(t, http.StatusOK, result.resp.StatusCode)
|
||||
require.Len(t, upstream.calls, 2, "first call fails (network error), second succeeds")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 6. antigravityRetryLoop 预检查:单账号模式跳过限流
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TestAntigravityRetryLoop_PreCheck_SingleAccountRetry_SkipsRateLimit
|
||||
// 预检查中,如果有 SingleAccountRetry 标记,即使账号已限流也跳过直接发请求
|
||||
func TestAntigravityRetryLoop_PreCheck_SingleAccountRetry_SkipsRateLimit(t *testing.T) {
|
||||
// 创建一个已设模型限流的账号
|
||||
upstream := &recordingOKUpstream{}
|
||||
account := &Account{
|
||||
ID: 20,
|
||||
Name: "acc-rate-limited",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Schedulable: true,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
Extra: map[string]any{
|
||||
modelRateLimitsKey: map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
"rate_limit_reset_at": time.Now().Add(30 * time.Second).Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||
ctx: ctxWithSingleAccountRetry(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err, "should not return error")
|
||||
require.NotNil(t, result, "should return result")
|
||||
require.NotNil(t, result.resp, "should have response")
|
||||
require.Equal(t, http.StatusOK, result.resp.StatusCode)
|
||||
// 关键:尽管限流了,有 SingleAccountRetry 标记时仍然到达了 upstream
|
||||
require.Equal(t, 1, upstream.calls, "should have reached upstream despite rate limit")
|
||||
}
|
||||
|
||||
// TestAntigravityRetryLoop_PreCheck_NoSingleAccountRetry_SwitchesOnRateLimit
|
||||
// 对照组:无 SingleAccountRetry + 已限流 → 预检查返回 switchError
|
||||
func TestAntigravityRetryLoop_PreCheck_NoSingleAccountRetry_SwitchesOnRateLimit(t *testing.T) {
|
||||
upstream := &recordingOKUpstream{}
|
||||
account := &Account{
|
||||
ID: 21,
|
||||
Name: "acc-rate-limited-multi",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Schedulable: true,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
Extra: map[string]any{
|
||||
modelRateLimitsKey: map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
"rate_limit_reset_at": time.Now().Add(30 * time.Second).Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||
ctx: context.Background(), // 无单账号标记
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
require.Nil(t, result, "should not return result on rate limit switch")
|
||||
require.NotNil(t, err, "should return error")
|
||||
|
||||
var switchErr *AntigravityAccountSwitchError
|
||||
require.ErrorAs(t, err, &switchErr, "should return AntigravityAccountSwitchError")
|
||||
require.Equal(t, account.ID, switchErr.OriginalAccountID)
|
||||
require.Equal(t, "claude-sonnet-4-5", switchErr.RateLimitedModel)
|
||||
|
||||
// upstream 不应被调用(预检查就短路了)
|
||||
require.Equal(t, 0, upstream.calls, "upstream should NOT be called when pre-check blocks")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 7. 端到端集成场景测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TestHandleSmartRetry_503_SingleAccount_RetryInPlace_ThenSuccess_E2E
|
||||
// 端到端场景:503 + 单账号 + 原地重试第2次成功
|
||||
func TestHandleSmartRetry_503_SingleAccount_RetryInPlace_ThenSuccess_E2E(t *testing.T) {
|
||||
// 第1次原地重试仍返回 503,第2次成功
|
||||
fail503Body := `{
|
||||
"error": {
|
||||
"code": 503,
|
||||
"status": "UNAVAILABLE",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||
]
|
||||
}
|
||||
}`
|
||||
resp503 := &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(fail503Body)),
|
||||
}
|
||||
successResp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
|
||||
}
|
||||
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: []*http.Response{resp503, successResp},
|
||||
errors: []error{nil, nil},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 30,
|
||||
Name: "acc-e2e",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Concurrency: 1,
|
||||
}
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: ctxWithSingleAccountRetry(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 1*time.Second, "gemini-3-pro")
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
require.NotNil(t, result.resp, "should return successful response after 2nd attempt")
|
||||
require.Equal(t, http.StatusOK, result.resp.StatusCode)
|
||||
require.Nil(t, result.switchError)
|
||||
require.Len(t, upstream.calls, 2, "first 503, second OK")
|
||||
}
|
||||
|
||||
// TestAntigravityRetryLoop_503_SingleAccount_InPlaceRetryUsed_E2E
|
||||
// 通过 antigravityRetryLoop → handleSmartRetry → handleSingleAccountRetryInPlace 完整链路
|
||||
func TestAntigravityRetryLoop_503_SingleAccount_InPlaceRetryUsed_E2E(t *testing.T) {
|
||||
// 初始请求返回 503 + 长延迟
|
||||
initial503Body := []byte(`{
|
||||
"error": {
|
||||
"code": 503,
|
||||
"status": "UNAVAILABLE",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "10s"}
|
||||
],
|
||||
"message": "No capacity available"
|
||||
}
|
||||
}`)
|
||||
initial503Resp := &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader(initial503Body)),
|
||||
}
|
||||
|
||||
// 原地重试成功
|
||||
successResp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
|
||||
}
|
||||
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
// 第1次调用(retryLoop 主循环)返回 503
|
||||
// 第2次调用(handleSingleAccountRetryInPlace 原地重试)返回 200
|
||||
responses: []*http.Response{initial503Resp, successResp},
|
||||
errors: []error{nil, nil},
|
||||
}
|
||||
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
account := &Account{
|
||||
ID: 31,
|
||||
Name: "acc-e2e-loop",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Schedulable: true,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||
ctx: ctxWithSingleAccountRetry(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
accountRepo: repo,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err, "should not return error on successful retry")
|
||||
require.NotNil(t, result, "should return result")
|
||||
require.NotNil(t, result.resp, "should return response")
|
||||
require.Equal(t, http.StatusOK, result.resp.StatusCode)
|
||||
|
||||
// 验证未设模型限流
|
||||
require.Len(t, repo.modelRateLimitCalls, 0,
|
||||
"should NOT set model rate limit in single account retry mode")
|
||||
}
|
||||
@@ -13,6 +13,23 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// stubSmartRetryCache 用于 handleSmartRetry 测试的 GatewayCache mock
|
||||
// 仅关注 DeleteSessionAccountID 的调用记录
|
||||
type stubSmartRetryCache struct {
|
||||
GatewayCache // 嵌入接口,未实现的方法 panic(确保只调用预期方法)
|
||||
deleteCalls []deleteSessionCall
|
||||
}
|
||||
|
||||
type deleteSessionCall struct {
|
||||
groupID int64
|
||||
sessionHash string
|
||||
}
|
||||
|
||||
func (c *stubSmartRetryCache) DeleteSessionAccountID(_ context.Context, groupID int64, sessionHash string) error {
|
||||
c.deleteCalls = append(c.deleteCalls, deleteSessionCall{groupID: groupID, sessionHash: sessionHash})
|
||||
return nil
|
||||
}
|
||||
|
||||
// mockSmartRetryUpstream 用于 handleSmartRetry 测试的 mock upstream
|
||||
type mockSmartRetryUpstream struct {
|
||||
responses []*http.Response
|
||||
@@ -58,7 +75,7 @@ func TestHandleSmartRetry_URLLevelRateLimit(t *testing.T) {
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
@@ -110,7 +127,7 @@ func TestHandleSmartRetry_LongDelay_ReturnsSwitchError(t *testing.T) {
|
||||
body: []byte(`{"input":"test"}`),
|
||||
accountRepo: repo,
|
||||
isStickySession: true,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
@@ -177,7 +194,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetrySuccess(t *testing.T) {
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
@@ -198,7 +215,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetrySuccess(t *testing.T) {
|
||||
|
||||
// TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError 测试智能重试失败后返回 switchError
|
||||
func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *testing.T) {
|
||||
// 智能重试后仍然返回 429(需要提供 3 个响应,因为智能重试最多 3 次)
|
||||
// 智能重试后仍然返回 429(需要提供 1 个响应,因为智能重试最多 1 次)
|
||||
failRespBody := `{
|
||||
"error": {
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
@@ -213,19 +230,9 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(failRespBody)),
|
||||
}
|
||||
failResp2 := &http.Response{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(failRespBody)),
|
||||
}
|
||||
failResp3 := &http.Response{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(failRespBody)),
|
||||
}
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: []*http.Response{failResp1, failResp2, failResp3},
|
||||
errors: []error{nil, nil, nil},
|
||||
responses: []*http.Response{failResp1},
|
||||
errors: []error{nil},
|
||||
}
|
||||
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
@@ -236,7 +243,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test
|
||||
Platform: PlatformAntigravity,
|
||||
}
|
||||
|
||||
// 3s < 7s 阈值,应该触发智能重试(最多 3 次)
|
||||
// 3s < 7s 阈值,应该触发智能重试(最多 1 次)
|
||||
respBody := []byte(`{
|
||||
"error": {
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
@@ -262,7 +269,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test
|
||||
httpUpstream: upstream,
|
||||
accountRepo: repo,
|
||||
isStickySession: false,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
@@ -284,11 +291,12 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test
|
||||
// 验证模型限流已设置
|
||||
require.Len(t, repo.modelRateLimitCalls, 1)
|
||||
require.Equal(t, "gemini-3-flash", repo.modelRateLimitCalls[0].modelKey)
|
||||
require.Len(t, upstream.calls, 3, "should have made three retry calls (max attempts)")
|
||||
require.Len(t, upstream.calls, 1, "should have made one retry call (max attempts)")
|
||||
}
|
||||
|
||||
// TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError 测试 503 MODEL_CAPACITY_EXHAUSTED 返回 switchError
|
||||
func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testing.T) {
|
||||
// TestHandleSmartRetry_503_ModelCapacityExhausted_RetrySuccess 测试 503 MODEL_CAPACITY_EXHAUSTED 重试成功
|
||||
// MODEL_CAPACITY_EXHAUSTED 使用固定 1s 间隔重试,不切换账号
|
||||
func TestHandleSmartRetry_503_ModelCapacityExhausted_RetrySuccess(t *testing.T) {
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
account := &Account{
|
||||
ID: 3,
|
||||
@@ -297,7 +305,7 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi
|
||||
Platform: PlatformAntigravity,
|
||||
}
|
||||
|
||||
// 503 + MODEL_CAPACITY_EXHAUSTED + 39s >= 7s 阈值
|
||||
// 503 + MODEL_CAPACITY_EXHAUSTED + 39s(上游 retryDelay 应被忽略,使用固定 1s)
|
||||
respBody := []byte(`{
|
||||
"error": {
|
||||
"code": 503,
|
||||
@@ -315,6 +323,14 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
|
||||
// mock: 第 1 次重试返回 200 成功
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: []*http.Response{
|
||||
{StatusCode: http.StatusOK, Header: http.Header{}, Body: io.NopCloser(strings.NewReader(`{"ok":true}`))},
|
||||
},
|
||||
errors: []error{nil},
|
||||
}
|
||||
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: context.Background(),
|
||||
prefix: "[test]",
|
||||
@@ -323,8 +339,9 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
accountRepo: repo,
|
||||
httpUpstream: upstream,
|
||||
isStickySession: true,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
@@ -336,16 +353,67 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
require.Nil(t, result.resp)
|
||||
require.NotNil(t, result.resp, "should return successful response")
|
||||
require.Equal(t, http.StatusOK, result.resp.StatusCode)
|
||||
require.Nil(t, result.err)
|
||||
require.NotNil(t, result.switchError, "should return switchError for 503 model capacity exhausted")
|
||||
require.Equal(t, account.ID, result.switchError.OriginalAccountID)
|
||||
require.Equal(t, "gemini-3-pro-high", result.switchError.RateLimitedModel)
|
||||
require.True(t, result.switchError.IsStickySession)
|
||||
require.Nil(t, result.switchError, "MODEL_CAPACITY_EXHAUSTED should not return switchError")
|
||||
|
||||
// 验证模型限流已设置
|
||||
require.Len(t, repo.modelRateLimitCalls, 1)
|
||||
require.Equal(t, "gemini-3-pro-high", repo.modelRateLimitCalls[0].modelKey)
|
||||
// 不应设置模型限流
|
||||
require.Empty(t, repo.modelRateLimitCalls, "MODEL_CAPACITY_EXHAUSTED should not set model rate limit")
|
||||
require.Len(t, upstream.calls, 1, "should have made one retry call before success")
|
||||
}
|
||||
|
||||
// TestHandleSmartRetry_503_ModelCapacityExhausted_ContextCancel 测试 MODEL_CAPACITY_EXHAUSTED 上下文取消
|
||||
func TestHandleSmartRetry_503_ModelCapacityExhausted_ContextCancel(t *testing.T) {
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
account := &Account{
|
||||
ID: 3,
|
||||
Name: "acc-3",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
}
|
||||
|
||||
respBody := []byte(`{
|
||||
"error": {
|
||||
"code": 503,
|
||||
"status": "UNAVAILABLE",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
|
||||
// 立即取消上下文,验证重试循环能正确退出
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: ctx,
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
accountRepo: repo,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, []string{"https://ag-1.test"})
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
require.Error(t, result.err, "should return context error")
|
||||
require.Nil(t, result.switchError, "should not return switchError on context cancel")
|
||||
require.Empty(t, repo.modelRateLimitCalls, "should not set model rate limit on context cancel")
|
||||
}
|
||||
|
||||
// TestHandleSmartRetry_NonAntigravityAccount_ContinuesDefaultLogic 测试非 Antigravity 平台账号走默认逻辑
|
||||
@@ -380,7 +448,7 @@ func TestHandleSmartRetry_NonAntigravityAccount_ContinuesDefaultLogic(t *testing
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
@@ -429,7 +497,7 @@ func TestHandleSmartRetry_NonModelRateLimit_ContinuesDefaultLogic(t *testing.T)
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
@@ -480,7 +548,7 @@ func TestHandleSmartRetry_ExactlyAtThreshold_ReturnsSwitchError(t *testing.T) {
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
accountRepo: repo,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
@@ -541,7 +609,7 @@ func TestAntigravityRetryLoop_HandleSmartRetry_SwitchError_Propagates(t *testing
|
||||
httpUpstream: upstream,
|
||||
accountRepo: repo,
|
||||
isStickySession: true,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
@@ -556,19 +624,15 @@ func TestAntigravityRetryLoop_HandleSmartRetry_SwitchError_Propagates(t *testing
|
||||
require.True(t, switchErr.IsStickySession)
|
||||
}
|
||||
|
||||
// TestHandleSmartRetry_NetworkError_ContinuesRetry 测试网络错误时继续重试
|
||||
func TestHandleSmartRetry_NetworkError_ContinuesRetry(t *testing.T) {
|
||||
// 第一次网络错误,第二次成功
|
||||
successResp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
|
||||
}
|
||||
// TestHandleSmartRetry_NetworkError_ExhaustsRetry 测试网络错误时(maxAttempts=1)直接耗尽重试并切换账号
|
||||
func TestHandleSmartRetry_NetworkError_ExhaustsRetry(t *testing.T) {
|
||||
// 唯一一次重试遇到网络错误(nil response)
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: []*http.Response{nil, successResp}, // 第一次返回 nil(模拟网络错误)
|
||||
errors: []error{nil, nil}, // mock 不返回 error,靠 nil response 触发
|
||||
responses: []*http.Response{nil}, // 返回 nil(模拟网络错误)
|
||||
errors: []error{nil}, // mock 不返回 error,靠 nil response 触发
|
||||
}
|
||||
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
account := &Account{
|
||||
ID: 8,
|
||||
Name: "acc-8",
|
||||
@@ -600,7 +664,8 @@ func TestHandleSmartRetry_NetworkError_ContinuesRetry(t *testing.T) {
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
accountRepo: repo,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
@@ -612,10 +677,15 @@ func TestHandleSmartRetry_NetworkError_ContinuesRetry(t *testing.T) {
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
require.NotNil(t, result.resp, "should return successful response after network error recovery")
|
||||
require.Equal(t, http.StatusOK, result.resp.StatusCode)
|
||||
require.Nil(t, result.switchError, "should not return switchError on success")
|
||||
require.Len(t, upstream.calls, 2, "should have made two retry calls")
|
||||
require.Nil(t, result.resp, "should not return resp when switchError is set")
|
||||
require.NotNil(t, result.switchError, "should return switchError after network error exhausted retry")
|
||||
require.Equal(t, account.ID, result.switchError.OriginalAccountID)
|
||||
require.Equal(t, "claude-sonnet-4-5", result.switchError.RateLimitedModel)
|
||||
require.Len(t, upstream.calls, 1, "should have made one retry call")
|
||||
|
||||
// 验证模型限流已设置
|
||||
require.Len(t, repo.modelRateLimitCalls, 1)
|
||||
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
|
||||
}
|
||||
|
||||
// TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit 测试无 retryDelay 时使用默认 1 分钟限流
|
||||
@@ -653,7 +723,7 @@ func TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit(t *testing.T) {
|
||||
body: []byte(`{"input":"test"}`),
|
||||
accountRepo: repo,
|
||||
isStickySession: true,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
@@ -674,3 +744,617 @@ func TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit(t *testing.T) {
|
||||
require.Len(t, repo.modelRateLimitCalls, 1)
|
||||
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 以下测试覆盖本次改动:
|
||||
// 1. antigravitySmartRetryMaxAttempts = 1(仅重试 1 次)
|
||||
// 2. 智能重试失败后清除粘性会话绑定(DeleteSessionAccountID)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TestSmartRetryMaxAttempts_VerifyConstant 验证常量值为 1
|
||||
func TestSmartRetryMaxAttempts_VerifyConstant(t *testing.T) {
|
||||
require.Equal(t, 1, antigravitySmartRetryMaxAttempts,
|
||||
"antigravitySmartRetryMaxAttempts should be 1 to prevent repeated rate limiting")
|
||||
}
|
||||
|
||||
// TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_ClearsSession
|
||||
// 核心场景:粘性会话 + 短延迟重试失败 → 必须清除粘性绑定
|
||||
func TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_ClearsSession(t *testing.T) {
|
||||
failRespBody := `{
|
||||
"error": {
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||
]
|
||||
}
|
||||
}`
|
||||
failResp := &http.Response{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(failRespBody)),
|
||||
}
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: []*http.Response{failResp},
|
||||
errors: []error{nil},
|
||||
}
|
||||
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
cache := &stubSmartRetryCache{}
|
||||
account := &Account{
|
||||
ID: 10,
|
||||
Name: "acc-10",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
}
|
||||
|
||||
respBody := []byte(`{
|
||||
"error": {
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: context.Background(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
accountRepo: repo,
|
||||
isStickySession: true,
|
||||
groupID: 42,
|
||||
sessionHash: "sticky-hash-abc",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
availableURLs := []string{"https://ag-1.test"}
|
||||
|
||||
svc := &AntigravityGatewayService{cache: cache}
|
||||
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
|
||||
|
||||
// 验证返回 switchError
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
require.NotNil(t, result.switchError)
|
||||
require.True(t, result.switchError.IsStickySession, "switchError should carry IsStickySession=true")
|
||||
require.Equal(t, account.ID, result.switchError.OriginalAccountID)
|
||||
|
||||
// 核心断言:DeleteSessionAccountID 被调用,且参数正确
|
||||
require.Len(t, cache.deleteCalls, 1, "should call DeleteSessionAccountID exactly once")
|
||||
require.Equal(t, int64(42), cache.deleteCalls[0].groupID)
|
||||
require.Equal(t, "sticky-hash-abc", cache.deleteCalls[0].sessionHash)
|
||||
|
||||
// 验证仅重试 1 次
|
||||
require.Len(t, upstream.calls, 1, "should make exactly 1 retry call (maxAttempts=1)")
|
||||
|
||||
// 验证模型限流已设置
|
||||
require.Len(t, repo.modelRateLimitCalls, 1)
|
||||
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
|
||||
}
|
||||
|
||||
// TestHandleSmartRetry_ShortDelay_NonStickySession_FailedRetry_NoDeleteSession
|
||||
// 非粘性会话 + 短延迟重试失败 → 不应调用 DeleteSessionAccountID(sessionHash 为空)
|
||||
func TestHandleSmartRetry_ShortDelay_NonStickySession_FailedRetry_NoDeleteSession(t *testing.T) {
|
||||
failRespBody := `{
|
||||
"error": {
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||
]
|
||||
}
|
||||
}`
|
||||
failResp := &http.Response{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(failRespBody)),
|
||||
}
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: []*http.Response{failResp},
|
||||
errors: []error{nil},
|
||||
}
|
||||
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
cache := &stubSmartRetryCache{}
|
||||
account := &Account{
|
||||
ID: 11,
|
||||
Name: "acc-11",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
}
|
||||
|
||||
respBody := []byte(`{
|
||||
"error": {
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: context.Background(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
accountRepo: repo,
|
||||
isStickySession: false,
|
||||
groupID: 42,
|
||||
sessionHash: "", // 非粘性会话,sessionHash 为空
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
availableURLs := []string{"https://ag-1.test"}
|
||||
|
||||
svc := &AntigravityGatewayService{cache: cache}
|
||||
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
require.NotNil(t, result.switchError)
|
||||
require.False(t, result.switchError.IsStickySession)
|
||||
|
||||
// 核心断言:sessionHash 为空时不应调用 DeleteSessionAccountID
|
||||
require.Len(t, cache.deleteCalls, 0, "should NOT call DeleteSessionAccountID when sessionHash is empty")
|
||||
}
|
||||
|
||||
// TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_NilCache_NoPanic
|
||||
// 边界:cache 为 nil 时不应 panic
|
||||
func TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_NilCache_NoPanic(t *testing.T) {
|
||||
failRespBody := `{
|
||||
"error": {
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||
]
|
||||
}
|
||||
}`
|
||||
failResp := &http.Response{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(failRespBody)),
|
||||
}
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: []*http.Response{failResp},
|
||||
errors: []error{nil},
|
||||
}
|
||||
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
account := &Account{
|
||||
ID: 12,
|
||||
Name: "acc-12",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
}
|
||||
|
||||
respBody := []byte(`{
|
||||
"error": {
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: context.Background(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
accountRepo: repo,
|
||||
isStickySession: true,
|
||||
groupID: 42,
|
||||
sessionHash: "sticky-hash-nil-cache",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
availableURLs := []string{"https://ag-1.test"}
|
||||
|
||||
// cache 为 nil,不应 panic
|
||||
svc := &AntigravityGatewayService{cache: nil}
|
||||
require.NotPanics(t, func() {
|
||||
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
require.NotNil(t, result.switchError)
|
||||
require.True(t, result.switchError.IsStickySession)
|
||||
})
|
||||
}
|
||||
|
||||
// TestHandleSmartRetry_ShortDelay_StickySession_SuccessRetry_NoDeleteSession
|
||||
// 重试成功时不应清除粘性会话(只有失败才清除)
|
||||
func TestHandleSmartRetry_ShortDelay_StickySession_SuccessRetry_NoDeleteSession(t *testing.T) {
|
||||
successResp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
|
||||
}
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: []*http.Response{successResp},
|
||||
errors: []error{nil},
|
||||
}
|
||||
|
||||
cache := &stubSmartRetryCache{}
|
||||
account := &Account{
|
||||
ID: 13,
|
||||
Name: "acc-13",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
}
|
||||
|
||||
respBody := []byte(`{
|
||||
"error": {
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: context.Background(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
isStickySession: true,
|
||||
groupID: 42,
|
||||
sessionHash: "sticky-hash-success",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
availableURLs := []string{"https://ag-1.test"}
|
||||
|
||||
svc := &AntigravityGatewayService{cache: cache}
|
||||
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
require.NotNil(t, result.resp, "should return successful response")
|
||||
require.Equal(t, http.StatusOK, result.resp.StatusCode)
|
||||
require.Nil(t, result.switchError, "should not return switchError on success")
|
||||
|
||||
// 核心断言:重试成功时不应清除粘性会话
|
||||
require.Len(t, cache.deleteCalls, 0, "should NOT call DeleteSessionAccountID on successful retry")
|
||||
}
|
||||
|
||||
// TestHandleSmartRetry_LongDelay_StickySession_NoDeleteInHandleSmartRetry
|
||||
// 长延迟路径(情况1)在 handleSmartRetry 中不直接调用 DeleteSessionAccountID
|
||||
// (清除由 handler 层的 shouldClearStickySession 在下次请求时处理)
|
||||
func TestHandleSmartRetry_LongDelay_StickySession_NoDeleteInHandleSmartRetry(t *testing.T) {
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
cache := &stubSmartRetryCache{}
|
||||
account := &Account{
|
||||
ID: 14,
|
||||
Name: "acc-14",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
}
|
||||
|
||||
// 15s >= 7s 阈值 → 走长延迟路径
|
||||
respBody := []byte(`{
|
||||
"error": {
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "15s"}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: context.Background(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
accountRepo: repo,
|
||||
isStickySession: true,
|
||||
groupID: 42,
|
||||
sessionHash: "sticky-hash-long-delay",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
availableURLs := []string{"https://ag-1.test"}
|
||||
|
||||
svc := &AntigravityGatewayService{cache: cache}
|
||||
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
require.NotNil(t, result.switchError)
|
||||
require.True(t, result.switchError.IsStickySession)
|
||||
|
||||
// 长延迟路径不在 handleSmartRetry 中调用 DeleteSessionAccountID
|
||||
// (由上游 handler 的 shouldClearStickySession 处理)
|
||||
require.Len(t, cache.deleteCalls, 0,
|
||||
"long delay path should NOT call DeleteSessionAccountID in handleSmartRetry (handled by handler layer)")
|
||||
}
|
||||
|
||||
// TestHandleSmartRetry_ShortDelay_NetworkError_StickySession_ClearsSession
|
||||
// 网络错误耗尽重试 + 粘性会话 → 也应清除粘性绑定
|
||||
func TestHandleSmartRetry_ShortDelay_NetworkError_StickySession_ClearsSession(t *testing.T) {
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: []*http.Response{nil}, // 网络错误
|
||||
errors: []error{nil},
|
||||
}
|
||||
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
cache := &stubSmartRetryCache{}
|
||||
account := &Account{
|
||||
ID: 15,
|
||||
Name: "acc-15",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
}
|
||||
|
||||
respBody := []byte(`{
|
||||
"error": {
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: context.Background(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
accountRepo: repo,
|
||||
isStickySession: true,
|
||||
groupID: 99,
|
||||
sessionHash: "sticky-net-error",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
availableURLs := []string{"https://ag-1.test"}
|
||||
|
||||
svc := &AntigravityGatewayService{cache: cache}
|
||||
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.switchError)
|
||||
require.True(t, result.switchError.IsStickySession)
|
||||
|
||||
// 核心断言:网络错误耗尽重试后也应清除粘性绑定
|
||||
require.Len(t, cache.deleteCalls, 1, "should call DeleteSessionAccountID after network error exhausts retry")
|
||||
require.Equal(t, int64(99), cache.deleteCalls[0].groupID)
|
||||
require.Equal(t, "sticky-net-error", cache.deleteCalls[0].sessionHash)
|
||||
}
|
||||
|
||||
// TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession
|
||||
// 429 + 短延迟 + 粘性会话 + 重试失败 → 清除粘性绑定
|
||||
func TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession(t *testing.T) {
|
||||
failRespBody := `{
|
||||
"error": {
|
||||
"code": 429,
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
|
||||
]
|
||||
}
|
||||
}`
|
||||
failResp := &http.Response{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(failRespBody)),
|
||||
}
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: []*http.Response{failResp},
|
||||
errors: []error{nil},
|
||||
}
|
||||
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
cache := &stubSmartRetryCache{}
|
||||
account := &Account{
|
||||
ID: 16,
|
||||
Name: "acc-16",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
}
|
||||
|
||||
respBody := []byte(`{
|
||||
"error": {
|
||||
"code": 429,
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: context.Background(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
accountRepo: repo,
|
||||
isStickySession: true,
|
||||
groupID: 77,
|
||||
sessionHash: "sticky-503-short",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
availableURLs := []string{"https://ag-1.test"}
|
||||
|
||||
svc := &AntigravityGatewayService{cache: cache}
|
||||
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.switchError)
|
||||
require.True(t, result.switchError.IsStickySession)
|
||||
|
||||
// 验证粘性绑定被清除
|
||||
require.Len(t, cache.deleteCalls, 1)
|
||||
require.Equal(t, int64(77), cache.deleteCalls[0].groupID)
|
||||
require.Equal(t, "sticky-503-short", cache.deleteCalls[0].sessionHash)
|
||||
|
||||
// 验证模型限流已设置
|
||||
require.Len(t, repo.modelRateLimitCalls, 1)
|
||||
require.Equal(t, "gemini-3-pro", repo.modelRateLimitCalls[0].modelKey)
|
||||
}
|
||||
|
||||
// TestAntigravityRetryLoop_SmartRetryFailed_StickySession_SwitchErrorPropagates
|
||||
// 集成测试:antigravityRetryLoop → handleSmartRetry → switchError 传播
|
||||
// 验证 IsStickySession 正确传递到上层,且粘性绑定被清除
|
||||
func TestAntigravityRetryLoop_SmartRetryFailed_StickySession_SwitchErrorPropagates(t *testing.T) {
|
||||
// 初始 429 响应
|
||||
initialRespBody := []byte(`{
|
||||
"error": {
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4-6"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
initialResp := &http.Response{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader(initialRespBody)),
|
||||
}
|
||||
|
||||
// 智能重试也返回 429
|
||||
retryRespBody := `{
|
||||
"error": {
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4-6"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||
]
|
||||
}
|
||||
}`
|
||||
retryResp := &http.Response{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(retryRespBody)),
|
||||
}
|
||||
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: []*http.Response{initialResp, retryResp},
|
||||
errors: []error{nil, nil},
|
||||
}
|
||||
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
cache := &stubSmartRetryCache{}
|
||||
account := &Account{
|
||||
ID: 17,
|
||||
Name: "acc-17",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Schedulable: true,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{cache: cache}
|
||||
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||
ctx: context.Background(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
accountRepo: repo,
|
||||
isStickySession: true,
|
||||
groupID: 55,
|
||||
sessionHash: "sticky-loop-test",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
require.Nil(t, result, "should not return result when switchError")
|
||||
require.NotNil(t, err, "should return error")
|
||||
|
||||
var switchErr *AntigravityAccountSwitchError
|
||||
require.ErrorAs(t, err, &switchErr, "error should be AntigravityAccountSwitchError")
|
||||
require.Equal(t, account.ID, switchErr.OriginalAccountID)
|
||||
require.Equal(t, "claude-opus-4-6", switchErr.RateLimitedModel)
|
||||
require.True(t, switchErr.IsStickySession, "IsStickySession must propagate through retryLoop")
|
||||
|
||||
// 验证粘性绑定被清除
|
||||
require.Len(t, cache.deleteCalls, 1, "should clear sticky session in handleSmartRetry")
|
||||
require.Equal(t, int64(55), cache.deleteCalls[0].groupID)
|
||||
require.Equal(t, "sticky-loop-test", cache.deleteCalls[0].sessionHash)
|
||||
}
|
||||
|
||||
@@ -7,12 +7,14 @@ import (
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
antigravityTokenRefreshSkew = 3 * time.Minute
|
||||
antigravityTokenCacheSkew = 5 * time.Minute
|
||||
antigravityBackfillCooldown = 5 * time.Minute
|
||||
)
|
||||
|
||||
// AntigravityTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
|
||||
@@ -23,6 +25,7 @@ type AntigravityTokenProvider struct {
|
||||
accountRepo AccountRepository
|
||||
tokenCache AntigravityTokenCache
|
||||
antigravityOAuthService *AntigravityOAuthService
|
||||
backfillCooldown sync.Map // key: int64 (account.ID) → value: time.Time
|
||||
}
|
||||
|
||||
func NewAntigravityTokenProvider(
|
||||
@@ -93,13 +96,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
p.mergeCredentials(account, tokenInfo)
|
||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr)
|
||||
}
|
||||
@@ -113,6 +110,21 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// 如果账号还没有 project_id,尝试在线补齐,避免请求 daily/sandbox 时出现
|
||||
// "Invalid project resource name projects/"。
|
||||
// 仅调用 loadProjectIDWithRetry,不刷新 OAuth token;带冷却机制防止频繁重试。
|
||||
if strings.TrimSpace(account.GetCredential("project_id")) == "" && p.antigravityOAuthService != nil {
|
||||
if p.shouldAttemptBackfill(account.ID) {
|
||||
p.markBackfillAttempted(account.ID)
|
||||
if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" {
|
||||
account.Credentials["project_id"] = projectID
|
||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
log.Printf("[AntigravityTokenProvider] project_id 补齐持久化失败: %v", updateErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
|
||||
if p.tokenCache != nil {
|
||||
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||
@@ -144,6 +156,31 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
// mergeCredentials 将 tokenInfo 构建的凭证合并到 account 中,保留原有未覆盖的字段
|
||||
func (p *AntigravityTokenProvider) mergeCredentials(account *Account, tokenInfo *AntigravityTokenInfo) {
|
||||
newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
}
|
||||
|
||||
// shouldAttemptBackfill 检查是否应该尝试补齐 project_id(冷却期内不重复尝试)
|
||||
func (p *AntigravityTokenProvider) shouldAttemptBackfill(accountID int64) bool {
|
||||
if v, ok := p.backfillCooldown.Load(accountID); ok {
|
||||
if lastAttempt, ok := v.(time.Time); ok {
|
||||
return time.Since(lastAttempt) > antigravityBackfillCooldown
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *AntigravityTokenProvider) markBackfillAttempted(accountID int64) {
|
||||
p.backfillCooldown.Store(accountID, time.Now())
|
||||
}
|
||||
|
||||
func AntigravityTokenCacheKey(account *Account) string {
|
||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
if projectID != "" {
|
||||
|
||||
@@ -31,8 +31,8 @@ type ModelPricing struct {
|
||||
OutputPricePerToken float64 // 每token输出价格 (USD)
|
||||
CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD)
|
||||
CacheReadPricePerToken float64 // 缓存读取每token价格 (USD)
|
||||
CacheCreation5mPrice float64 // 5分钟缓存创建价格(每百万token)- 仅用于硬编码回退
|
||||
CacheCreation1hPrice float64 // 1小时缓存创建价格(每百万token)- 仅用于硬编码回退
|
||||
CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD)
|
||||
CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD)
|
||||
SupportsCacheBreakdown bool // 是否支持详细的缓存分类
|
||||
}
|
||||
|
||||
@@ -172,12 +172,20 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
|
||||
if s.pricingService != nil {
|
||||
litellmPricing := s.pricingService.GetModelPricing(model)
|
||||
if litellmPricing != nil {
|
||||
// 启用 5m/1h 分类计费的条件:
|
||||
// 1. 存在 1h 价格
|
||||
// 2. 1h 价格 > 5m 价格(防止 LiteLLM 数据错误导致少收费)
|
||||
price5m := litellmPricing.CacheCreationInputTokenCost
|
||||
price1h := litellmPricing.CacheCreationInputTokenCostAbove1hr
|
||||
enableBreakdown := price1h > 0 && price1h > price5m
|
||||
return &ModelPricing{
|
||||
InputPricePerToken: litellmPricing.InputCostPerToken,
|
||||
OutputPricePerToken: litellmPricing.OutputCostPerToken,
|
||||
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
|
||||
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
|
||||
SupportsCacheBreakdown: false,
|
||||
CacheCreation5mPrice: price5m,
|
||||
CacheCreation1hPrice: price1h,
|
||||
SupportsCacheBreakdown: enableBreakdown,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
@@ -209,9 +217,14 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul
|
||||
|
||||
// 计算缓存费用
|
||||
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
|
||||
// 支持详细缓存分类的模型(5分钟/1小时缓存)
|
||||
breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)/1_000_000*pricing.CacheCreation5mPrice +
|
||||
float64(tokens.CacheCreation1hTokens)/1_000_000*pricing.CacheCreation1hPrice
|
||||
// 支持详细缓存分类的模型(5分钟/1小时缓存,价格为 per-token)
|
||||
if tokens.CacheCreation5mTokens == 0 && tokens.CacheCreation1hTokens == 0 && tokens.CacheCreationTokens > 0 {
|
||||
// API 未返回 ephemeral 明细,回退到全部按 5m 单价计费
|
||||
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice
|
||||
} else {
|
||||
breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice +
|
||||
float64(tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice
|
||||
}
|
||||
} else {
|
||||
// 标准缓存创建价格(per-token)
|
||||
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
|
||||
@@ -280,10 +293,12 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage
|
||||
|
||||
// 范围内部分:正常计费
|
||||
inRangeTokens := UsageTokens{
|
||||
InputTokens: inRangeInputTokens,
|
||||
OutputTokens: tokens.OutputTokens, // 输出只算一次
|
||||
CacheCreationTokens: tokens.CacheCreationTokens,
|
||||
CacheReadTokens: inRangeCacheTokens,
|
||||
InputTokens: inRangeInputTokens,
|
||||
OutputTokens: tokens.OutputTokens, // 输出只算一次
|
||||
CacheCreationTokens: tokens.CacheCreationTokens,
|
||||
CacheReadTokens: inRangeCacheTokens,
|
||||
CacheCreation5mTokens: tokens.CacheCreation5mTokens,
|
||||
CacheCreation1hTokens: tokens.CacheCreation1hTokens,
|
||||
}
|
||||
inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier)
|
||||
if err != nil {
|
||||
|
||||
112
backend/internal/service/crs_sync_helpers_test.go
Normal file
112
backend/internal/service/crs_sync_helpers_test.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBuildSelectedSet(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ids []string
|
||||
wantNil bool
|
||||
wantSize int
|
||||
}{
|
||||
{
|
||||
name: "nil input returns nil (backward compatible: create all)",
|
||||
ids: nil,
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "empty slice returns empty map (create none)",
|
||||
ids: []string{},
|
||||
wantNil: false,
|
||||
wantSize: 0,
|
||||
},
|
||||
{
|
||||
name: "single ID",
|
||||
ids: []string{"abc-123"},
|
||||
wantNil: false,
|
||||
wantSize: 1,
|
||||
},
|
||||
{
|
||||
name: "multiple IDs",
|
||||
ids: []string{"a", "b", "c"},
|
||||
wantNil: false,
|
||||
wantSize: 3,
|
||||
},
|
||||
{
|
||||
name: "duplicate IDs are deduplicated",
|
||||
ids: []string{"a", "a", "b"},
|
||||
wantNil: false,
|
||||
wantSize: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := buildSelectedSet(tt.ids)
|
||||
if tt.wantNil {
|
||||
if got != nil {
|
||||
t.Errorf("buildSelectedSet(%v) = %v, want nil", tt.ids, got)
|
||||
}
|
||||
return
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatalf("buildSelectedSet(%v) = nil, want non-nil map", tt.ids)
|
||||
}
|
||||
if len(got) != tt.wantSize {
|
||||
t.Errorf("buildSelectedSet(%v) has %d entries, want %d", tt.ids, len(got), tt.wantSize)
|
||||
}
|
||||
// Verify all unique IDs are present
|
||||
for _, id := range tt.ids {
|
||||
if _, ok := got[id]; !ok {
|
||||
t.Errorf("buildSelectedSet(%v) missing key %q", tt.ids, id)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldCreateAccount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
crsID string
|
||||
selectedSet map[string]struct{}
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil set allows all (backward compatible)",
|
||||
crsID: "any-id",
|
||||
selectedSet: nil,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "empty set blocks all",
|
||||
crsID: "any-id",
|
||||
selectedSet: map[string]struct{}{},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "ID in set is allowed",
|
||||
crsID: "abc-123",
|
||||
selectedSet: map[string]struct{}{"abc-123": {}, "def-456": {}},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "ID not in set is blocked",
|
||||
crsID: "xyz-789",
|
||||
selectedSet: map[string]struct{}{"abc-123": {}, "def-456": {}},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := shouldCreateAccount(tt.crsID, tt.selectedSet)
|
||||
if got != tt.want {
|
||||
t.Errorf("shouldCreateAccount(%q, %v) = %v, want %v",
|
||||
tt.crsID, tt.selectedSet, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -45,10 +45,11 @@ func NewCRSSyncService(
|
||||
}
|
||||
|
||||
type SyncFromCRSInput struct {
|
||||
BaseURL string
|
||||
Username string
|
||||
Password string
|
||||
SyncProxies bool
|
||||
BaseURL string
|
||||
Username string
|
||||
Password string
|
||||
SyncProxies bool
|
||||
SelectedAccountIDs []string // if non-empty, only create new accounts with these CRS IDs
|
||||
}
|
||||
|
||||
type SyncFromCRSItemResult struct {
|
||||
@@ -190,25 +191,27 @@ type crsGeminiAPIKeyAccount struct {
|
||||
Extra map[string]any `json:"extra"`
|
||||
}
|
||||
|
||||
func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) {
|
||||
// fetchCRSExport validates the connection parameters, authenticates with CRS,
|
||||
// and returns the exported accounts. Shared by SyncFromCRS and PreviewFromCRS.
|
||||
func (s *CRSSyncService) fetchCRSExport(ctx context.Context, baseURL, username, password string) (*crsExportResponse, error) {
|
||||
if s.cfg == nil {
|
||||
return nil, errors.New("config is not available")
|
||||
}
|
||||
baseURL := strings.TrimSpace(input.BaseURL)
|
||||
normalizedURL := strings.TrimSpace(baseURL)
|
||||
if s.cfg.Security.URLAllowlist.Enabled {
|
||||
normalized, err := normalizeBaseURL(baseURL, s.cfg.Security.URLAllowlist.CRSHosts, s.cfg.Security.URLAllowlist.AllowPrivateHosts)
|
||||
normalized, err := normalizeBaseURL(normalizedURL, s.cfg.Security.URLAllowlist.CRSHosts, s.cfg.Security.URLAllowlist.AllowPrivateHosts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
baseURL = normalized
|
||||
normalizedURL = normalized
|
||||
} else {
|
||||
normalized, err := urlvalidator.ValidateURLFormat(baseURL, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
|
||||
normalized, err := urlvalidator.ValidateURLFormat(normalizedURL, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid base_url: %w", err)
|
||||
}
|
||||
baseURL = normalized
|
||||
normalizedURL = normalized
|
||||
}
|
||||
if strings.TrimSpace(input.Username) == "" || strings.TrimSpace(input.Password) == "" {
|
||||
if strings.TrimSpace(username) == "" || strings.TrimSpace(password) == "" {
|
||||
return nil, errors.New("username and password are required")
|
||||
}
|
||||
|
||||
@@ -221,12 +224,16 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
client = &http.Client{Timeout: 20 * time.Second}
|
||||
}
|
||||
|
||||
adminToken, err := crsLogin(ctx, client, baseURL, input.Username, input.Password)
|
||||
adminToken, err := crsLogin(ctx, client, normalizedURL, username, password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
exported, err := crsExportAccounts(ctx, client, baseURL, adminToken)
|
||||
return crsExportAccounts(ctx, client, normalizedURL, adminToken)
|
||||
}
|
||||
|
||||
func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) {
|
||||
exported, err := s.fetchCRSExport(ctx, input.BaseURL, input.Username, input.Password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -241,6 +248,8 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
),
|
||||
}
|
||||
|
||||
selectedSet := buildSelectedSet(input.SelectedAccountIDs)
|
||||
|
||||
var proxies []Proxy
|
||||
if input.SyncProxies {
|
||||
proxies, _ = s.proxyRepo.ListActive(ctx)
|
||||
@@ -329,6 +338,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
}
|
||||
|
||||
if existing == nil {
|
||||
if !shouldCreateAccount(src.ID, selectedSet) {
|
||||
item.Action = "skipped"
|
||||
item.Error = "not selected"
|
||||
result.Skipped++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
account := &Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: PlatformAnthropic,
|
||||
@@ -446,6 +462,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
}
|
||||
|
||||
if existing == nil {
|
||||
if !shouldCreateAccount(src.ID, selectedSet) {
|
||||
item.Action = "skipped"
|
||||
item.Error = "not selected"
|
||||
result.Skipped++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
account := &Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: PlatformAnthropic,
|
||||
@@ -569,6 +592,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
}
|
||||
|
||||
if existing == nil {
|
||||
if !shouldCreateAccount(src.ID, selectedSet) {
|
||||
item.Action = "skipped"
|
||||
item.Error = "not selected"
|
||||
result.Skipped++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
account := &Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: PlatformOpenAI,
|
||||
@@ -690,6 +720,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
}
|
||||
|
||||
if existing == nil {
|
||||
if !shouldCreateAccount(src.ID, selectedSet) {
|
||||
item.Action = "skipped"
|
||||
item.Error = "not selected"
|
||||
result.Skipped++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
account := &Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: PlatformOpenAI,
|
||||
@@ -798,6 +835,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
}
|
||||
|
||||
if existing == nil {
|
||||
if !shouldCreateAccount(src.ID, selectedSet) {
|
||||
item.Action = "skipped"
|
||||
item.Error = "not selected"
|
||||
result.Skipped++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
account := &Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: PlatformGemini,
|
||||
@@ -909,6 +953,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
}
|
||||
|
||||
if existing == nil {
|
||||
if !shouldCreateAccount(src.ID, selectedSet) {
|
||||
item.Action = "skipped"
|
||||
item.Error = "not selected"
|
||||
result.Skipped++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
account := &Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: PlatformGemini,
|
||||
@@ -1253,3 +1304,102 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *Account
|
||||
|
||||
return newCredentials
|
||||
}
|
||||
|
||||
// buildSelectedSet converts a slice of selected CRS account IDs to a set for O(1) lookup.
|
||||
// Returns nil if ids is nil (field not sent → backward compatible: create all).
|
||||
// Returns an empty map if ids is non-nil but empty (user selected none → create none).
|
||||
func buildSelectedSet(ids []string) map[string]struct{} {
|
||||
if ids == nil {
|
||||
return nil
|
||||
}
|
||||
set := make(map[string]struct{}, len(ids))
|
||||
for _, id := range ids {
|
||||
set[id] = struct{}{}
|
||||
}
|
||||
return set
|
||||
}
|
||||
|
||||
// shouldCreateAccount checks if a new CRS account should be created based on user selection.
|
||||
// Returns true if selectedSet is nil (backward compatible: create all) or if crsID is in the set.
|
||||
func shouldCreateAccount(crsID string, selectedSet map[string]struct{}) bool {
|
||||
if selectedSet == nil {
|
||||
return true
|
||||
}
|
||||
_, ok := selectedSet[crsID]
|
||||
return ok
|
||||
}
|
||||
|
||||
// PreviewFromCRSResult contains the preview of accounts from CRS before sync.
|
||||
type PreviewFromCRSResult struct {
|
||||
NewAccounts []CRSPreviewAccount `json:"new_accounts"`
|
||||
ExistingAccounts []CRSPreviewAccount `json:"existing_accounts"`
|
||||
}
|
||||
|
||||
// CRSPreviewAccount represents a single account in the preview result.
|
||||
type CRSPreviewAccount struct {
|
||||
CRSAccountID string `json:"crs_account_id"`
|
||||
Kind string `json:"kind"`
|
||||
Name string `json:"name"`
|
||||
Platform string `json:"platform"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// PreviewFromCRS connects to CRS, fetches all accounts, and classifies them
|
||||
// as new or existing by batch-querying local crs_account_id mappings.
|
||||
func (s *CRSSyncService) PreviewFromCRS(ctx context.Context, input SyncFromCRSInput) (*PreviewFromCRSResult, error) {
|
||||
exported, err := s.fetchCRSExport(ctx, input.BaseURL, input.Username, input.Password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Batch query all existing CRS account IDs
|
||||
existingCRSIDs, err := s.accountRepo.ListCRSAccountIDs(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list existing CRS accounts: %w", err)
|
||||
}
|
||||
|
||||
result := &PreviewFromCRSResult{
|
||||
NewAccounts: make([]CRSPreviewAccount, 0),
|
||||
ExistingAccounts: make([]CRSPreviewAccount, 0),
|
||||
}
|
||||
|
||||
classify := func(crsID, kind, name, platform, accountType string) {
|
||||
preview := CRSPreviewAccount{
|
||||
CRSAccountID: crsID,
|
||||
Kind: kind,
|
||||
Name: defaultName(name, crsID),
|
||||
Platform: platform,
|
||||
Type: accountType,
|
||||
}
|
||||
if _, exists := existingCRSIDs[crsID]; exists {
|
||||
result.ExistingAccounts = append(result.ExistingAccounts, preview)
|
||||
} else {
|
||||
result.NewAccounts = append(result.NewAccounts, preview)
|
||||
}
|
||||
}
|
||||
|
||||
for _, src := range exported.Data.ClaudeAccounts {
|
||||
authType := strings.TrimSpace(src.AuthType)
|
||||
if authType == "" {
|
||||
authType = AccountTypeOAuth
|
||||
}
|
||||
classify(src.ID, src.Kind, src.Name, PlatformAnthropic, authType)
|
||||
}
|
||||
for _, src := range exported.Data.ClaudeConsoleAccounts {
|
||||
classify(src.ID, src.Kind, src.Name, PlatformAnthropic, AccountTypeAPIKey)
|
||||
}
|
||||
for _, src := range exported.Data.OpenAIOAuthAccounts {
|
||||
classify(src.ID, src.Kind, src.Name, PlatformOpenAI, AccountTypeOAuth)
|
||||
}
|
||||
for _, src := range exported.Data.OpenAIResponsesAccounts {
|
||||
classify(src.ID, src.Kind, src.Name, PlatformOpenAI, AccountTypeAPIKey)
|
||||
}
|
||||
for _, src := range exported.Data.GeminiOAuthAccounts {
|
||||
classify(src.ID, src.Kind, src.Name, PlatformGemini, AccountTypeOAuth)
|
||||
}
|
||||
for _, src := range exported.Data.GeminiAPIKeyAccounts {
|
||||
classify(src.ID, src.Kind, src.Name, PlatformGemini, AccountTypeAPIKey)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
69
backend/internal/service/digest_session_store.go
Normal file
69
backend/internal/service/digest_session_store.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
)
|
||||
|
||||
// digestSessionTTL 摘要会话默认 TTL
|
||||
const digestSessionTTL = 5 * time.Minute
|
||||
|
||||
// sessionEntry flat cache 条目
|
||||
type sessionEntry struct {
|
||||
uuid string
|
||||
accountID int64
|
||||
}
|
||||
|
||||
// DigestSessionStore 内存摘要会话存储(flat cache 实现)
|
||||
// key: "{groupID}:{prefixHash}|{digestChain}" → *sessionEntry
|
||||
type DigestSessionStore struct {
|
||||
cache *gocache.Cache
|
||||
}
|
||||
|
||||
// NewDigestSessionStore 创建内存摘要会话存储
|
||||
func NewDigestSessionStore() *DigestSessionStore {
|
||||
return &DigestSessionStore{
|
||||
cache: gocache.New(digestSessionTTL, time.Minute),
|
||||
}
|
||||
}
|
||||
|
||||
// Save 保存摘要会话。oldDigestChain 为 Find 返回的 matchedChain,用于删旧 key。
|
||||
func (s *DigestSessionStore) Save(groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) {
|
||||
if digestChain == "" {
|
||||
return
|
||||
}
|
||||
ns := buildNS(groupID, prefixHash)
|
||||
s.cache.Set(ns+digestChain, &sessionEntry{uuid: uuid, accountID: accountID}, gocache.DefaultExpiration)
|
||||
if oldDigestChain != "" && oldDigestChain != digestChain {
|
||||
s.cache.Delete(ns + oldDigestChain)
|
||||
}
|
||||
}
|
||||
|
||||
// Find 查找摘要会话,从完整 chain 逐段截断,返回最长匹配及对应 matchedChain。
|
||||
func (s *DigestSessionStore) Find(groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) {
|
||||
if digestChain == "" {
|
||||
return "", 0, "", false
|
||||
}
|
||||
ns := buildNS(groupID, prefixHash)
|
||||
chain := digestChain
|
||||
for {
|
||||
if val, ok := s.cache.Get(ns + chain); ok {
|
||||
if e, ok := val.(*sessionEntry); ok {
|
||||
return e.uuid, e.accountID, chain, true
|
||||
}
|
||||
}
|
||||
i := strings.LastIndex(chain, "-")
|
||||
if i < 0 {
|
||||
return "", 0, "", false
|
||||
}
|
||||
chain = chain[:i]
|
||||
}
|
||||
}
|
||||
|
||||
// buildNS 构建 namespace 前缀
|
||||
func buildNS(groupID int64, prefixHash string) string {
|
||||
return strconv.FormatInt(groupID, 10) + ":" + prefixHash + "|"
|
||||
}
|
||||
312
backend/internal/service/digest_session_store_test.go
Normal file
312
backend/internal/service/digest_session_store_test.go
Normal file
@@ -0,0 +1,312 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDigestSessionStore_SaveAndFind(t *testing.T) {
|
||||
store := NewDigestSessionStore()
|
||||
|
||||
store.Save(1, "prefix", "s:a1-u:b2-m:c3", "uuid-1", 100, "")
|
||||
|
||||
uuid, accountID, _, found := store.Find(1, "prefix", "s:a1-u:b2-m:c3")
|
||||
require.True(t, found)
|
||||
assert.Equal(t, "uuid-1", uuid)
|
||||
assert.Equal(t, int64(100), accountID)
|
||||
}
|
||||
|
||||
func TestDigestSessionStore_PrefixMatch(t *testing.T) {
|
||||
store := NewDigestSessionStore()
|
||||
|
||||
// 保存短链
|
||||
store.Save(1, "prefix", "u:a-m:b", "uuid-short", 10, "")
|
||||
|
||||
// 用长链查找,应前缀匹配到短链
|
||||
uuid, accountID, matchedChain, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d")
|
||||
require.True(t, found)
|
||||
assert.Equal(t, "uuid-short", uuid)
|
||||
assert.Equal(t, int64(10), accountID)
|
||||
assert.Equal(t, "u:a-m:b", matchedChain)
|
||||
}
|
||||
|
||||
func TestDigestSessionStore_LongestPrefixMatch(t *testing.T) {
|
||||
store := NewDigestSessionStore()
|
||||
|
||||
store.Save(1, "prefix", "u:a", "uuid-1", 1, "")
|
||||
store.Save(1, "prefix", "u:a-m:b", "uuid-2", 2, "")
|
||||
store.Save(1, "prefix", "u:a-m:b-u:c", "uuid-3", 3, "")
|
||||
|
||||
// 应匹配最深的 "u:a-m:b-u:c"(从完整 chain 逐段截断,先命中最长的)
|
||||
uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d-u:e")
|
||||
require.True(t, found)
|
||||
assert.Equal(t, "uuid-3", uuid)
|
||||
assert.Equal(t, int64(3), accountID)
|
||||
|
||||
// 查找中等长度,应匹配到 "u:a-m:b"
|
||||
uuid, accountID, _, found = store.Find(1, "prefix", "u:a-m:b-u:x")
|
||||
require.True(t, found)
|
||||
assert.Equal(t, "uuid-2", uuid)
|
||||
assert.Equal(t, int64(2), accountID)
|
||||
}
|
||||
|
||||
func TestDigestSessionStore_SaveDeletesOldChain(t *testing.T) {
|
||||
store := NewDigestSessionStore()
|
||||
|
||||
// 第一轮:保存 "u:a-m:b"
|
||||
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
|
||||
|
||||
// 第二轮:同一 uuid 保存更长的链,传入旧 chain
|
||||
store.Save(1, "prefix", "u:a-m:b-u:c-m:d", "uuid-1", 100, "u:a-m:b")
|
||||
|
||||
// 旧链 "u:a-m:b" 应已被删除
|
||||
_, _, _, found := store.Find(1, "prefix", "u:a-m:b")
|
||||
assert.False(t, found, "old chain should be deleted")
|
||||
|
||||
// 新链应能找到
|
||||
uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d")
|
||||
require.True(t, found)
|
||||
assert.Equal(t, "uuid-1", uuid)
|
||||
assert.Equal(t, int64(100), accountID)
|
||||
}
|
||||
|
||||
func TestDigestSessionStore_DifferentSessionsNoInterference(t *testing.T) {
|
||||
store := NewDigestSessionStore()
|
||||
|
||||
// 相同系统提示词,不同用户提示词
|
||||
store.Save(1, "prefix", "s:sys-u:user1", "uuid-1", 100, "")
|
||||
store.Save(1, "prefix", "s:sys-u:user2", "uuid-2", 200, "")
|
||||
|
||||
uuid, accountID, _, found := store.Find(1, "prefix", "s:sys-u:user1-m:reply1")
|
||||
require.True(t, found)
|
||||
assert.Equal(t, "uuid-1", uuid)
|
||||
assert.Equal(t, int64(100), accountID)
|
||||
|
||||
uuid, accountID, _, found = store.Find(1, "prefix", "s:sys-u:user2-m:reply2")
|
||||
require.True(t, found)
|
||||
assert.Equal(t, "uuid-2", uuid)
|
||||
assert.Equal(t, int64(200), accountID)
|
||||
}
|
||||
|
||||
func TestDigestSessionStore_NoMatch(t *testing.T) {
|
||||
store := NewDigestSessionStore()
|
||||
|
||||
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
|
||||
|
||||
// 完全不同的 chain
|
||||
_, _, _, found := store.Find(1, "prefix", "u:x-m:y")
|
||||
assert.False(t, found)
|
||||
}
|
||||
|
||||
func TestDigestSessionStore_DifferentPrefixHash(t *testing.T) {
|
||||
store := NewDigestSessionStore()
|
||||
|
||||
store.Save(1, "prefix1", "u:a-m:b", "uuid-1", 100, "")
|
||||
|
||||
// 不同 prefixHash 应隔离
|
||||
_, _, _, found := store.Find(1, "prefix2", "u:a-m:b")
|
||||
assert.False(t, found)
|
||||
}
|
||||
|
||||
func TestDigestSessionStore_DifferentGroupID(t *testing.T) {
|
||||
store := NewDigestSessionStore()
|
||||
|
||||
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
|
||||
|
||||
// 不同 groupID 应隔离
|
||||
_, _, _, found := store.Find(2, "prefix", "u:a-m:b")
|
||||
assert.False(t, found)
|
||||
}
|
||||
|
||||
func TestDigestSessionStore_EmptyDigestChain(t *testing.T) {
|
||||
store := NewDigestSessionStore()
|
||||
|
||||
// 空链不应保存
|
||||
store.Save(1, "prefix", "", "uuid-1", 100, "")
|
||||
_, _, _, found := store.Find(1, "prefix", "")
|
||||
assert.False(t, found)
|
||||
}
|
||||
|
||||
func TestDigestSessionStore_TTLExpiration(t *testing.T) {
|
||||
store := &DigestSessionStore{
|
||||
cache: gocache.New(100*time.Millisecond, 50*time.Millisecond),
|
||||
}
|
||||
|
||||
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
|
||||
|
||||
// 立即应该能找到
|
||||
_, _, _, found := store.Find(1, "prefix", "u:a-m:b")
|
||||
require.True(t, found)
|
||||
|
||||
// 等待过期 + 清理周期
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
|
||||
// 过期后应找不到
|
||||
_, _, _, found = store.Find(1, "prefix", "u:a-m:b")
|
||||
assert.False(t, found)
|
||||
}
|
||||
|
||||
func TestDigestSessionStore_ConcurrentSafety(t *testing.T) {
|
||||
store := NewDigestSessionStore()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
const goroutines = 50
|
||||
const operations = 100
|
||||
|
||||
wg.Add(goroutines)
|
||||
for g := 0; g < goroutines; g++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
prefix := fmt.Sprintf("prefix-%d", id%5)
|
||||
for i := 0; i < operations; i++ {
|
||||
chain := fmt.Sprintf("u:%d-m:%d", id, i)
|
||||
uuid := fmt.Sprintf("uuid-%d-%d", id, i)
|
||||
store.Save(1, prefix, chain, uuid, int64(id), "")
|
||||
store.Find(1, prefix, chain)
|
||||
}
|
||||
}(g)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestDigestSessionStore_MultipleSessions(t *testing.T) {
|
||||
store := NewDigestSessionStore()
|
||||
|
||||
sessions := []struct {
|
||||
chain string
|
||||
uuid string
|
||||
accountID int64
|
||||
}{
|
||||
{"u:session1", "uuid-1", 1},
|
||||
{"u:session2-m:reply2", "uuid-2", 2},
|
||||
{"u:session3-m:reply3-u:msg3", "uuid-3", 3},
|
||||
}
|
||||
|
||||
for _, sess := range sessions {
|
||||
store.Save(1, "prefix", sess.chain, sess.uuid, sess.accountID, "")
|
||||
}
|
||||
|
||||
// 验证每个会话都能正确查找
|
||||
for _, sess := range sessions {
|
||||
uuid, accountID, _, found := store.Find(1, "prefix", sess.chain)
|
||||
require.True(t, found, "should find session: %s", sess.chain)
|
||||
assert.Equal(t, sess.uuid, uuid)
|
||||
assert.Equal(t, sess.accountID, accountID)
|
||||
}
|
||||
|
||||
// 验证继续对话的场景
|
||||
uuid, accountID, _, found := store.Find(1, "prefix", "u:session2-m:reply2-u:newmsg")
|
||||
require.True(t, found)
|
||||
assert.Equal(t, "uuid-2", uuid)
|
||||
assert.Equal(t, int64(2), accountID)
|
||||
}
|
||||
|
||||
func TestDigestSessionStore_Performance1000Sessions(t *testing.T) {
|
||||
store := NewDigestSessionStore()
|
||||
|
||||
// 插入 1000 个会话
|
||||
for i := 0; i < 1000; i++ {
|
||||
chain := fmt.Sprintf("s:sys-u:user%d-m:reply%d", i, i)
|
||||
store.Save(1, "prefix", chain, fmt.Sprintf("uuid-%d", i), int64(i), "")
|
||||
}
|
||||
|
||||
// 查找性能测试
|
||||
start := time.Now()
|
||||
const lookups = 10000
|
||||
for i := 0; i < lookups; i++ {
|
||||
idx := i % 1000
|
||||
chain := fmt.Sprintf("s:sys-u:user%d-m:reply%d-u:newmsg", idx, idx)
|
||||
_, _, _, found := store.Find(1, "prefix", chain)
|
||||
assert.True(t, found)
|
||||
}
|
||||
elapsed := time.Since(start)
|
||||
t.Logf("%d lookups in %v (%.0f ns/op)", lookups, elapsed, float64(elapsed.Nanoseconds())/lookups)
|
||||
}
|
||||
|
||||
func TestDigestSessionStore_FindReturnsMatchedChain(t *testing.T) {
|
||||
store := NewDigestSessionStore()
|
||||
|
||||
store.Save(1, "prefix", "u:a-m:b-u:c", "uuid-1", 100, "")
|
||||
|
||||
// 精确匹配
|
||||
_, _, matchedChain, found := store.Find(1, "prefix", "u:a-m:b-u:c")
|
||||
require.True(t, found)
|
||||
assert.Equal(t, "u:a-m:b-u:c", matchedChain)
|
||||
|
||||
// 前缀匹配(截断后命中)
|
||||
_, _, matchedChain, found = store.Find(1, "prefix", "u:a-m:b-u:c-m:d-u:e")
|
||||
require.True(t, found)
|
||||
assert.Equal(t, "u:a-m:b-u:c", matchedChain)
|
||||
}
|
||||
|
||||
func TestDigestSessionStore_CacheItemCountStable(t *testing.T) {
|
||||
store := NewDigestSessionStore()
|
||||
|
||||
// 模拟 100 个独立会话,每个进行 10 轮对话
|
||||
// 正确传递 oldDigestChain 时,每个会话始终只保留 1 个 key
|
||||
for conv := 0; conv < 100; conv++ {
|
||||
var prevMatchedChain string
|
||||
for round := 0; round < 10; round++ {
|
||||
chain := fmt.Sprintf("s:sys-u:user%d", conv)
|
||||
for r := 0; r < round; r++ {
|
||||
chain += fmt.Sprintf("-m:a%d-u:q%d", r, r+1)
|
||||
}
|
||||
uuid := fmt.Sprintf("uuid-conv%d", conv)
|
||||
|
||||
_, _, matched, _ := store.Find(1, "prefix", chain)
|
||||
store.Save(1, "prefix", chain, uuid, int64(conv), matched)
|
||||
prevMatchedChain = matched
|
||||
_ = prevMatchedChain
|
||||
}
|
||||
}
|
||||
|
||||
// 100 个会话 × 1 key/会话 = 应该 ≤ 100 个 key
|
||||
// 允许少量并发残留,但绝不能接近 100×10=1000
|
||||
itemCount := store.cache.ItemCount()
|
||||
assert.LessOrEqual(t, itemCount, 100, "cache should have at most 100 items (1 per conversation), got %d", itemCount)
|
||||
t.Logf("Cache item count after 100 conversations × 10 rounds: %d", itemCount)
|
||||
}
|
||||
|
||||
func TestDigestSessionStore_TTLPreventsUnboundedGrowth(t *testing.T) {
|
||||
// 使用极短 TTL 验证大量写入后 cache 能被清理
|
||||
store := &DigestSessionStore{
|
||||
cache: gocache.New(100*time.Millisecond, 50*time.Millisecond),
|
||||
}
|
||||
|
||||
// 插入 500 个不同的 key(无 oldDigestChain,模拟最坏场景:全是新会话首轮)
|
||||
for i := 0; i < 500; i++ {
|
||||
chain := fmt.Sprintf("u:user%d", i)
|
||||
store.Save(1, "prefix", chain, fmt.Sprintf("uuid-%d", i), int64(i), "")
|
||||
}
|
||||
|
||||
assert.Equal(t, 500, store.cache.ItemCount())
|
||||
|
||||
// 等待 TTL + 清理周期
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
|
||||
assert.Equal(t, 0, store.cache.ItemCount(), "all items should be expired and cleaned up")
|
||||
}
|
||||
|
||||
func TestDigestSessionStore_SaveSameChainNoDelete(t *testing.T) {
|
||||
store := NewDigestSessionStore()
|
||||
|
||||
// 保存 chain
|
||||
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
|
||||
|
||||
// 用户重发相同消息:oldDigestChain == digestChain,不应删掉刚设置的 key
|
||||
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "u:a-m:b")
|
||||
|
||||
// 仍然能找到
|
||||
uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b")
|
||||
require.True(t, found)
|
||||
assert.Equal(t, "uuid-1", uuid)
|
||||
assert.Equal(t, int64(100), accountID)
|
||||
}
|
||||
@@ -61,6 +61,11 @@ func applyErrorPassthroughRule(
|
||||
errMsg = *rule.CustomMessage
|
||||
}
|
||||
|
||||
// 命中 skip_monitoring 时在 context 中标记,供 ops_error_logger 跳过记录。
|
||||
if rule.SkipMonitoring {
|
||||
c.Set(OpsSkipPassthroughKey, true)
|
||||
}
|
||||
|
||||
// 与现有 failover 场景保持一致:命中规则时统一返回 upstream_error。
|
||||
errType = "upstream_error"
|
||||
return status, errType, errMsg, true
|
||||
|
||||
@@ -194,6 +194,63 @@ func TestGeminiWriteGeminiMappedError_AppliesRuleFor422(t *testing.T) {
|
||||
assert.Equal(t, "Gemini上游失败", errField["message"])
|
||||
}
|
||||
|
||||
func TestApplyErrorPassthroughRule_SkipMonitoringSetsContextKey(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
rule := newNonFailoverPassthroughRule(http.StatusBadRequest, "prompt is too long", http.StatusBadRequest, "上下文超限")
|
||||
rule.SkipMonitoring = true
|
||||
|
||||
ruleSvc := &ErrorPassthroughService{}
|
||||
ruleSvc.setLocalCache([]*model.ErrorPassthroughRule{rule})
|
||||
BindErrorPassthroughService(c, ruleSvc)
|
||||
|
||||
_, _, _, matched := applyErrorPassthroughRule(
|
||||
c,
|
||||
PlatformAnthropic,
|
||||
http.StatusBadRequest,
|
||||
[]byte(`{"error":{"message":"prompt is too long"}}`),
|
||||
http.StatusBadGateway,
|
||||
"upstream_error",
|
||||
"Upstream request failed",
|
||||
)
|
||||
|
||||
assert.True(t, matched)
|
||||
v, exists := c.Get(OpsSkipPassthroughKey)
|
||||
assert.True(t, exists, "OpsSkipPassthroughKey should be set when skip_monitoring=true")
|
||||
boolVal, ok := v.(bool)
|
||||
assert.True(t, ok, "value should be bool")
|
||||
assert.True(t, boolVal)
|
||||
}
|
||||
|
||||
func TestApplyErrorPassthroughRule_NoSkipMonitoringDoesNotSetContextKey(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
rule := newNonFailoverPassthroughRule(http.StatusBadRequest, "prompt is too long", http.StatusBadRequest, "上下文超限")
|
||||
rule.SkipMonitoring = false
|
||||
|
||||
ruleSvc := &ErrorPassthroughService{}
|
||||
ruleSvc.setLocalCache([]*model.ErrorPassthroughRule{rule})
|
||||
BindErrorPassthroughService(c, ruleSvc)
|
||||
|
||||
_, _, _, matched := applyErrorPassthroughRule(
|
||||
c,
|
||||
PlatformAnthropic,
|
||||
http.StatusBadRequest,
|
||||
[]byte(`{"error":{"message":"prompt is too long"}}`),
|
||||
http.StatusBadGateway,
|
||||
"upstream_error",
|
||||
"Upstream request failed",
|
||||
)
|
||||
|
||||
assert.True(t, matched)
|
||||
_, exists := c.Get(OpsSkipPassthroughKey)
|
||||
assert.False(t, exists, "OpsSkipPassthroughKey should NOT be set when skip_monitoring=false")
|
||||
}
|
||||
|
||||
func newNonFailoverPassthroughRule(statusCode int, keyword string, respCode int, customMessage string) *model.ErrorPassthroughRule {
|
||||
return &model.ErrorPassthroughRule{
|
||||
ID: 1,
|
||||
|
||||
@@ -45,10 +45,20 @@ type ErrorPassthroughService struct {
|
||||
cache ErrorPassthroughCache
|
||||
|
||||
// 本地内存缓存,用于快速匹配
|
||||
localCache []*model.ErrorPassthroughRule
|
||||
localCache []*cachedPassthroughRule
|
||||
localCacheMu sync.RWMutex
|
||||
}
|
||||
|
||||
// cachedPassthroughRule 预计算的规则缓存,避免运行时重复 ToLower
|
||||
type cachedPassthroughRule struct {
|
||||
*model.ErrorPassthroughRule
|
||||
lowerKeywords []string // 预计算的小写关键词
|
||||
lowerPlatforms []string // 预计算的小写平台
|
||||
errorCodeSet map[int]struct{} // 预计算的 error code set
|
||||
}
|
||||
|
||||
const maxBodyMatchLen = 8 << 10 // 8KB,错误信息不会在 8KB 之后才出现
|
||||
|
||||
// NewErrorPassthroughService 创建错误透传规则服务
|
||||
func NewErrorPassthroughService(
|
||||
repo ErrorPassthroughRepository,
|
||||
@@ -150,17 +160,19 @@ func (s *ErrorPassthroughService) MatchRule(platform string, statusCode int, bod
|
||||
return nil
|
||||
}
|
||||
|
||||
bodyStr := strings.ToLower(string(body))
|
||||
lowerPlatform := strings.ToLower(platform)
|
||||
var bodyLower string // 延迟初始化,只在需要关键词匹配时计算
|
||||
var bodyLowerDone bool
|
||||
|
||||
for _, rule := range rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
if !s.platformMatches(rule, platform) {
|
||||
if !s.platformMatchesCached(rule, lowerPlatform) {
|
||||
continue
|
||||
}
|
||||
if s.ruleMatches(rule, statusCode, bodyStr) {
|
||||
return rule
|
||||
if s.ruleMatchesOptimized(rule, statusCode, body, &bodyLower, &bodyLowerDone) {
|
||||
return rule.ErrorPassthroughRule
|
||||
}
|
||||
}
|
||||
|
||||
@@ -168,7 +180,7 @@ func (s *ErrorPassthroughService) MatchRule(platform string, statusCode int, bod
|
||||
}
|
||||
|
||||
// getCachedRules 获取缓存的规则列表(按优先级排序)
|
||||
func (s *ErrorPassthroughService) getCachedRules() []*model.ErrorPassthroughRule {
|
||||
func (s *ErrorPassthroughService) getCachedRules() []*cachedPassthroughRule {
|
||||
s.localCacheMu.RLock()
|
||||
rules := s.localCache
|
||||
s.localCacheMu.RUnlock()
|
||||
@@ -223,17 +235,39 @@ func (s *ErrorPassthroughService) reloadRulesFromDB(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setLocalCache 设置本地缓存
|
||||
// setLocalCache 设置本地缓存,预计算小写值和 set 以避免运行时重复计算
|
||||
func (s *ErrorPassthroughService) setLocalCache(rules []*model.ErrorPassthroughRule) {
|
||||
cached := make([]*cachedPassthroughRule, len(rules))
|
||||
for i, r := range rules {
|
||||
cr := &cachedPassthroughRule{ErrorPassthroughRule: r}
|
||||
if len(r.Keywords) > 0 {
|
||||
cr.lowerKeywords = make([]string, len(r.Keywords))
|
||||
for j, kw := range r.Keywords {
|
||||
cr.lowerKeywords[j] = strings.ToLower(kw)
|
||||
}
|
||||
}
|
||||
if len(r.Platforms) > 0 {
|
||||
cr.lowerPlatforms = make([]string, len(r.Platforms))
|
||||
for j, p := range r.Platforms {
|
||||
cr.lowerPlatforms[j] = strings.ToLower(p)
|
||||
}
|
||||
}
|
||||
if len(r.ErrorCodes) > 0 {
|
||||
cr.errorCodeSet = make(map[int]struct{}, len(r.ErrorCodes))
|
||||
for _, code := range r.ErrorCodes {
|
||||
cr.errorCodeSet[code] = struct{}{}
|
||||
}
|
||||
}
|
||||
cached[i] = cr
|
||||
}
|
||||
|
||||
// 按优先级排序
|
||||
sorted := make([]*model.ErrorPassthroughRule, len(rules))
|
||||
copy(sorted, rules)
|
||||
sort.Slice(sorted, func(i, j int) bool {
|
||||
return sorted[i].Priority < sorted[j].Priority
|
||||
sort.Slice(cached, func(i, j int) bool {
|
||||
return cached[i].Priority < cached[j].Priority
|
||||
})
|
||||
|
||||
s.localCacheMu.Lock()
|
||||
s.localCache = sorted
|
||||
s.localCache = cached
|
||||
s.localCacheMu.Unlock()
|
||||
}
|
||||
|
||||
@@ -273,62 +307,79 @@ func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// platformMatches 检查平台是否匹配
|
||||
func (s *ErrorPassthroughService) platformMatches(rule *model.ErrorPassthroughRule, platform string) bool {
|
||||
// 如果没有配置平台限制,则匹配所有平台
|
||||
if len(rule.Platforms) == 0 {
|
||||
// ensureBodyLower 延迟初始化 body 的小写版本,只做一次转换,限制 8KB
|
||||
func ensureBodyLower(body []byte, bodyLower *string, done *bool) string {
|
||||
if *done {
|
||||
return *bodyLower
|
||||
}
|
||||
b := body
|
||||
if len(b) > maxBodyMatchLen {
|
||||
b = b[:maxBodyMatchLen]
|
||||
}
|
||||
*bodyLower = strings.ToLower(string(b))
|
||||
*done = true
|
||||
return *bodyLower
|
||||
}
|
||||
|
||||
// platformMatchesCached 使用预计算的小写平台检查是否匹配
|
||||
func (s *ErrorPassthroughService) platformMatchesCached(rule *cachedPassthroughRule, lowerPlatform string) bool {
|
||||
if len(rule.lowerPlatforms) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
platform = strings.ToLower(platform)
|
||||
for _, p := range rule.Platforms {
|
||||
if strings.ToLower(p) == platform {
|
||||
for _, p := range rule.lowerPlatforms {
|
||||
if p == lowerPlatform {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ruleMatches 检查规则是否匹配
|
||||
func (s *ErrorPassthroughService) ruleMatches(rule *model.ErrorPassthroughRule, statusCode int, bodyLower string) bool {
|
||||
hasErrorCodes := len(rule.ErrorCodes) > 0
|
||||
hasKeywords := len(rule.Keywords) > 0
|
||||
// ruleMatchesOptimized 优化的规则匹配,支持短路和延迟 body 转换
|
||||
func (s *ErrorPassthroughService) ruleMatchesOptimized(rule *cachedPassthroughRule, statusCode int, body []byte, bodyLower *string, bodyLowerDone *bool) bool {
|
||||
hasErrorCodes := len(rule.errorCodeSet) > 0
|
||||
hasKeywords := len(rule.lowerKeywords) > 0
|
||||
|
||||
// 如果没有配置任何条件,不匹配
|
||||
if !hasErrorCodes && !hasKeywords {
|
||||
return false
|
||||
}
|
||||
|
||||
codeMatch := !hasErrorCodes || s.containsInt(rule.ErrorCodes, statusCode)
|
||||
keywordMatch := !hasKeywords || s.containsAnyKeyword(bodyLower, rule.Keywords)
|
||||
codeMatch := !hasErrorCodes || s.containsIntSet(rule.errorCodeSet, statusCode)
|
||||
|
||||
if rule.MatchMode == model.MatchModeAll {
|
||||
// "all" 模式:所有配置的条件都必须满足
|
||||
return codeMatch && keywordMatch
|
||||
// "all" 模式:所有配置的条件都必须满足,短路
|
||||
if hasErrorCodes && !codeMatch {
|
||||
return false
|
||||
}
|
||||
if hasKeywords {
|
||||
return s.containsAnyKeywordCached(ensureBodyLower(body, bodyLower, bodyLowerDone), rule.lowerKeywords)
|
||||
}
|
||||
return codeMatch
|
||||
}
|
||||
|
||||
// "any" 模式:任一条件满足即可
|
||||
// "any" 模式:任一条件满足即可,短路
|
||||
if hasErrorCodes && hasKeywords {
|
||||
return codeMatch || keywordMatch
|
||||
if codeMatch {
|
||||
return true
|
||||
}
|
||||
return s.containsAnyKeywordCached(ensureBodyLower(body, bodyLower, bodyLowerDone), rule.lowerKeywords)
|
||||
}
|
||||
return codeMatch && keywordMatch
|
||||
// 只配置了一种条件
|
||||
if hasKeywords {
|
||||
return s.containsAnyKeywordCached(ensureBodyLower(body, bodyLower, bodyLowerDone), rule.lowerKeywords)
|
||||
}
|
||||
return codeMatch
|
||||
}
|
||||
|
||||
// containsInt 检查切片是否包含指定整数
|
||||
func (s *ErrorPassthroughService) containsInt(slice []int, val int) bool {
|
||||
for _, v := range slice {
|
||||
if v == val {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// containsAnyKeyword 检查字符串是否包含任一关键词(不区分大小写)
|
||||
func (s *ErrorPassthroughService) containsAnyKeyword(bodyLower string, keywords []string) bool {
|
||||
for _, kw := range keywords {
|
||||
if strings.Contains(bodyLower, strings.ToLower(kw)) {
|
||||
// containsIntSet 使用 map 查找替代线性扫描
|
||||
func (s *ErrorPassthroughService) containsIntSet(set map[int]struct{}, val int) bool {
|
||||
_, ok := set[val]
|
||||
return ok
|
||||
}
|
||||
|
||||
// containsAnyKeywordCached 使用预计算的小写关键词检查匹配
|
||||
func (s *ErrorPassthroughService) containsAnyKeywordCached(bodyLower string, lowerKeywords []string) bool {
|
||||
for _, kw := range lowerKeywords {
|
||||
if strings.Contains(bodyLower, kw) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -145,32 +145,58 @@ func newTestService(rules []*model.ErrorPassthroughRule) *ErrorPassthroughServic
|
||||
return svc
|
||||
}
|
||||
|
||||
// newCachedRuleForTest 从 model.ErrorPassthroughRule 创建 cachedPassthroughRule(测试用)
|
||||
func newCachedRuleForTest(rule *model.ErrorPassthroughRule) *cachedPassthroughRule {
|
||||
cr := &cachedPassthroughRule{ErrorPassthroughRule: rule}
|
||||
if len(rule.Keywords) > 0 {
|
||||
cr.lowerKeywords = make([]string, len(rule.Keywords))
|
||||
for j, kw := range rule.Keywords {
|
||||
cr.lowerKeywords[j] = strings.ToLower(kw)
|
||||
}
|
||||
}
|
||||
if len(rule.Platforms) > 0 {
|
||||
cr.lowerPlatforms = make([]string, len(rule.Platforms))
|
||||
for j, p := range rule.Platforms {
|
||||
cr.lowerPlatforms[j] = strings.ToLower(p)
|
||||
}
|
||||
}
|
||||
if len(rule.ErrorCodes) > 0 {
|
||||
cr.errorCodeSet = make(map[int]struct{}, len(rule.ErrorCodes))
|
||||
for _, code := range rule.ErrorCodes {
|
||||
cr.errorCodeSet[code] = struct{}{}
|
||||
}
|
||||
}
|
||||
return cr
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// 测试 ruleMatches 核心匹配逻辑
|
||||
// 测试 ruleMatchesOptimized 核心匹配逻辑
|
||||
// =============================================================================
|
||||
|
||||
func TestRuleMatches_NoConditions(t *testing.T) {
|
||||
// 没有配置任何条件时,不应该匹配
|
||||
svc := newTestService(nil)
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
|
||||
Enabled: true,
|
||||
ErrorCodes: []int{},
|
||||
Keywords: []string{},
|
||||
MatchMode: model.MatchModeAny,
|
||||
}
|
||||
})
|
||||
|
||||
assert.False(t, svc.ruleMatches(rule, 422, "some error message"),
|
||||
var bodyLower string
|
||||
var bodyLowerDone bool
|
||||
assert.False(t, svc.ruleMatchesOptimized(rule, 422, []byte("some error message"), &bodyLower, &bodyLowerDone),
|
||||
"没有配置条件时不应该匹配")
|
||||
}
|
||||
|
||||
func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) {
|
||||
svc := newTestService(nil)
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
|
||||
Enabled: true,
|
||||
ErrorCodes: []int{422, 400},
|
||||
Keywords: []string{},
|
||||
MatchMode: model.MatchModeAny,
|
||||
}
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -186,7 +212,9 @@ func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
|
||||
var bodyLower string
|
||||
var bodyLowerDone bool
|
||||
result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
@@ -194,12 +222,12 @@ func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) {
|
||||
|
||||
func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) {
|
||||
svc := newTestService(nil)
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
|
||||
Enabled: true,
|
||||
ErrorCodes: []int{},
|
||||
Keywords: []string{"context limit", "model not supported"},
|
||||
MatchMode: model.MatchModeAny,
|
||||
}
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -210,16 +238,14 @@ func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) {
|
||||
{"关键词匹配 context limit", 500, "error: context limit reached", true},
|
||||
{"关键词匹配 model not supported", 400, "the model not supported here", true},
|
||||
{"关键词不匹配", 422, "some other error", false},
|
||||
// 注意:ruleMatches 接收的 body 参数应该是已经转换为小写的
|
||||
// 实际使用时,MatchRule 会先将 body 转换为小写再传给 ruleMatches
|
||||
{"关键词大小写 - 输入已小写", 500, "context limit exceeded", true},
|
||||
{"关键词大小写 - 自动转换", 500, "Context Limit exceeded", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 模拟 MatchRule 的行为:先转换为小写
|
||||
bodyLower := strings.ToLower(tt.body)
|
||||
result := svc.ruleMatches(rule, tt.statusCode, bodyLower)
|
||||
var bodyLower string
|
||||
var bodyLowerDone bool
|
||||
result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
@@ -228,12 +254,12 @@ func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) {
|
||||
func TestRuleMatches_BothConditions_AnyMode(t *testing.T) {
|
||||
// any 模式:错误码 OR 关键词
|
||||
svc := newTestService(nil)
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
|
||||
Enabled: true,
|
||||
ErrorCodes: []int{422, 400},
|
||||
Keywords: []string{"context limit"},
|
||||
MatchMode: model.MatchModeAny,
|
||||
}
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -274,7 +300,9 @@ func TestRuleMatches_BothConditions_AnyMode(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
|
||||
var bodyLower string
|
||||
var bodyLowerDone bool
|
||||
result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone)
|
||||
assert.Equal(t, tt.expected, result, tt.reason)
|
||||
})
|
||||
}
|
||||
@@ -283,12 +311,12 @@ func TestRuleMatches_BothConditions_AnyMode(t *testing.T) {
|
||||
func TestRuleMatches_BothConditions_AllMode(t *testing.T) {
|
||||
// all 模式:错误码 AND 关键词
|
||||
svc := newTestService(nil)
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
|
||||
Enabled: true,
|
||||
ErrorCodes: []int{422, 400},
|
||||
Keywords: []string{"context limit"},
|
||||
MatchMode: model.MatchModeAll,
|
||||
}
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -329,14 +357,16 @@ func TestRuleMatches_BothConditions_AllMode(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
|
||||
var bodyLower string
|
||||
var bodyLowerDone bool
|
||||
result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone)
|
||||
assert.Equal(t, tt.expected, result, tt.reason)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// 测试 platformMatches 平台匹配逻辑
|
||||
// 测试 platformMatchesCached 平台匹配逻辑
|
||||
// =============================================================================
|
||||
|
||||
func TestPlatformMatches(t *testing.T) {
|
||||
@@ -394,10 +424,10 @@ func TestPlatformMatches(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
|
||||
Platforms: tt.rulePlatforms,
|
||||
}
|
||||
result := svc.platformMatches(rule, tt.requestPlatform)
|
||||
})
|
||||
result := svc.platformMatchesCached(rule, strings.ToLower(tt.requestPlatform))
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
|
||||
472
backend/internal/service/error_policy_integration_test.go
Normal file
472
backend/internal/service/error_policy_integration_test.go
Normal file
@@ -0,0 +1,472 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Mocks (scoped to this file by naming convention)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// epFixedUpstream returns a fixed response for every request.
|
||||
type epFixedUpstream struct {
|
||||
statusCode int
|
||||
body string
|
||||
calls int
|
||||
}
|
||||
|
||||
func (u *epFixedUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
|
||||
u.calls++
|
||||
return &http.Response{
|
||||
StatusCode: u.statusCode,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(u.body)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (u *epFixedUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
|
||||
return u.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
}
|
||||
|
||||
// epAccountRepo records SetTempUnschedulable / SetError calls.
|
||||
type epAccountRepo struct {
|
||||
mockAccountRepoForGemini
|
||||
tempCalls int
|
||||
setErrCalls int
|
||||
}
|
||||
|
||||
func (r *epAccountRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error {
|
||||
r.tempCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *epAccountRepo) SetError(_ context.Context, _ int64, _ string) error {
|
||||
r.setErrCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func saveAndSetBaseURLs(t *testing.T) {
|
||||
t.Helper()
|
||||
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
|
||||
oldAvail := antigravity.DefaultURLAvailability
|
||||
antigravity.BaseURLs = []string{"https://ep-test.example"}
|
||||
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
|
||||
t.Cleanup(func() {
|
||||
antigravity.BaseURLs = oldBaseURLs
|
||||
antigravity.DefaultURLAvailability = oldAvail
|
||||
})
|
||||
}
|
||||
|
||||
func newRetryParams(account *Account, upstream HTTPUpstream, handleError func(context.Context, string, *Account, int, http.Header, []byte, string, int64, string, bool) *handleModelRateLimitResult) antigravityRetryLoopParams {
|
||||
return antigravityRetryLoopParams{
|
||||
ctx: context.Background(),
|
||||
prefix: "[ep-test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
handleError: handleError,
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestRetryLoop_ErrorPolicy_CustomErrorCodes
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestRetryLoop_ErrorPolicy_CustomErrorCodes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
upstreamStatus int
|
||||
upstreamBody string
|
||||
customCodes []any
|
||||
expectHandleError int
|
||||
expectUpstream int
|
||||
expectStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "429_in_custom_codes_matched",
|
||||
upstreamStatus: 429,
|
||||
upstreamBody: `{"error":"rate limited"}`,
|
||||
customCodes: []any{float64(429)},
|
||||
expectHandleError: 1,
|
||||
expectUpstream: 1,
|
||||
expectStatusCode: 429,
|
||||
},
|
||||
{
|
||||
name: "429_not_in_custom_codes_skipped",
|
||||
upstreamStatus: 429,
|
||||
upstreamBody: `{"error":"rate limited"}`,
|
||||
customCodes: []any{float64(500)},
|
||||
expectHandleError: 0,
|
||||
expectUpstream: 1,
|
||||
expectStatusCode: 500,
|
||||
},
|
||||
{
|
||||
name: "500_in_custom_codes_matched",
|
||||
upstreamStatus: 500,
|
||||
upstreamBody: `{"error":"internal"}`,
|
||||
customCodes: []any{float64(500)},
|
||||
expectHandleError: 1,
|
||||
expectUpstream: 1,
|
||||
expectStatusCode: 500,
|
||||
},
|
||||
{
|
||||
name: "500_not_in_custom_codes_skipped",
|
||||
upstreamStatus: 500,
|
||||
upstreamBody: `{"error":"internal"}`,
|
||||
customCodes: []any{float64(429)},
|
||||
expectHandleError: 0,
|
||||
expectUpstream: 1,
|
||||
expectStatusCode: 500,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
saveAndSetBaseURLs(t)
|
||||
|
||||
upstream := &epFixedUpstream{statusCode: tt.upstreamStatus, body: tt.upstreamBody}
|
||||
repo := &epAccountRepo{}
|
||||
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
|
||||
account := &Account{
|
||||
ID: 100,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformAntigravity,
|
||||
Schedulable: true,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": tt.customCodes,
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
|
||||
|
||||
var handleErrorCount int
|
||||
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
||||
handleErrorCount++
|
||||
return nil
|
||||
})
|
||||
|
||||
result, err := svc.antigravityRetryLoop(p)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.resp)
|
||||
defer func() { _ = result.resp.Body.Close() }()
|
||||
|
||||
require.Equal(t, tt.expectStatusCode, result.resp.StatusCode)
|
||||
require.Equal(t, tt.expectHandleError, handleErrorCount, "handleError call count")
|
||||
require.Equal(t, tt.expectUpstream, upstream.calls, "upstream call count")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestRetryLoop_ErrorPolicy_TempUnschedulable
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestRetryLoop_ErrorPolicy_TempUnschedulable(t *testing.T) {
|
||||
tempRulesAccount := func(rules []any) *Account {
|
||||
return &Account{
|
||||
ID: 200,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Schedulable: true,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": rules,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
overloadedRule := map[string]any{
|
||||
"error_code": float64(503),
|
||||
"keywords": []any{"overloaded"},
|
||||
"duration_minutes": float64(10),
|
||||
}
|
||||
|
||||
rateLimitRule := map[string]any{
|
||||
"error_code": float64(429),
|
||||
"keywords": []any{"rate limited keyword"},
|
||||
"duration_minutes": float64(5),
|
||||
}
|
||||
|
||||
t.Run("503_overloaded_matches_rule", func(t *testing.T) {
|
||||
saveAndSetBaseURLs(t)
|
||||
|
||||
upstream := &epFixedUpstream{statusCode: 503, body: `overloaded`}
|
||||
repo := &epAccountRepo{}
|
||||
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
|
||||
|
||||
account := tempRulesAccount([]any{overloadedRule})
|
||||
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
||||
t.Error("handleError should not be called for temp unschedulable")
|
||||
return nil
|
||||
})
|
||||
|
||||
result, err := svc.antigravityRetryLoop(p)
|
||||
|
||||
require.Nil(t, result)
|
||||
var switchErr *AntigravityAccountSwitchError
|
||||
require.ErrorAs(t, err, &switchErr)
|
||||
require.Equal(t, account.ID, switchErr.OriginalAccountID)
|
||||
require.Equal(t, 1, upstream.calls, "should not retry")
|
||||
})
|
||||
|
||||
t.Run("429_rate_limited_keyword_matches_rule", func(t *testing.T) {
|
||||
saveAndSetBaseURLs(t)
|
||||
|
||||
upstream := &epFixedUpstream{statusCode: 429, body: `rate limited keyword`}
|
||||
repo := &epAccountRepo{}
|
||||
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
|
||||
|
||||
account := tempRulesAccount([]any{rateLimitRule})
|
||||
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
||||
t.Error("handleError should not be called for temp unschedulable")
|
||||
return nil
|
||||
})
|
||||
|
||||
result, err := svc.antigravityRetryLoop(p)
|
||||
|
||||
require.Nil(t, result)
|
||||
var switchErr *AntigravityAccountSwitchError
|
||||
require.ErrorAs(t, err, &switchErr)
|
||||
require.Equal(t, account.ID, switchErr.OriginalAccountID)
|
||||
require.Equal(t, 1, upstream.calls, "should not retry")
|
||||
})
|
||||
|
||||
t.Run("503_body_no_match_continues_default_retry", func(t *testing.T) {
|
||||
saveAndSetBaseURLs(t)
|
||||
|
||||
upstream := &epFixedUpstream{statusCode: 503, body: `random`}
|
||||
repo := &epAccountRepo{}
|
||||
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
|
||||
|
||||
account := tempRulesAccount([]any{overloadedRule})
|
||||
|
||||
// Use a short-lived context: the backoff sleep (~1s) will be
|
||||
// interrupted, proving the code entered the default retry path
|
||||
// instead of breaking early via error policy.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
})
|
||||
p.ctx = ctx
|
||||
|
||||
result, err := svc.antigravityRetryLoop(p)
|
||||
|
||||
// Context cancellation during backoff proves default retry was entered
|
||||
require.Nil(t, result)
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
require.GreaterOrEqual(t, upstream.calls, 1, "should have called upstream at least once")
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestRetryLoop_ErrorPolicy_NilRateLimitService
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestRetryLoop_ErrorPolicy_NilRateLimitService(t *testing.T) {
|
||||
saveAndSetBaseURLs(t)
|
||||
|
||||
upstream := &epFixedUpstream{statusCode: 429, body: `{"error":"rate limited"}`}
|
||||
// rateLimitService is nil — must not panic
|
||||
svc := &AntigravityGatewayService{rateLimitService: nil}
|
||||
|
||||
account := &Account{
|
||||
ID: 300,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Schedulable: true,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
})
|
||||
p.ctx = ctx
|
||||
|
||||
// Should not panic; enters the default retry path (eventually times out)
|
||||
result, err := svc.antigravityRetryLoop(p)
|
||||
|
||||
require.Nil(t, result)
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
require.GreaterOrEqual(t, upstream.calls, 1)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior(t *testing.T) {
|
||||
saveAndSetBaseURLs(t)
|
||||
|
||||
upstream := &epFixedUpstream{statusCode: 429, body: `{"error":"rate limited"}`}
|
||||
repo := &epAccountRepo{}
|
||||
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
|
||||
|
||||
// Plain OAuth account with no error policy configured
|
||||
account := &Account{
|
||||
ID: 400,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Schedulable: true,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
}
|
||||
|
||||
var handleErrorCount int
|
||||
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
||||
handleErrorCount++
|
||||
return nil
|
||||
})
|
||||
|
||||
result, err := svc.antigravityRetryLoop(p)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.resp)
|
||||
defer func() { _ = result.resp.Body.Close() }()
|
||||
|
||||
require.Equal(t, http.StatusTooManyRequests, result.resp.StatusCode)
|
||||
require.Equal(t, antigravityMaxRetries, upstream.calls, "should exhaust all retries")
|
||||
require.Equal(t, 1, handleErrorCount, "handleError should be called once after retries exhausted")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// epTrackingRepo — records SetRateLimited / SetError calls for verification.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type epTrackingRepo struct {
|
||||
mockAccountRepoForGemini
|
||||
rateLimitedCalls int
|
||||
rateLimitedID int64
|
||||
setErrCalls int
|
||||
setErrID int64
|
||||
tempCalls int
|
||||
}
|
||||
|
||||
func (r *epTrackingRepo) SetRateLimited(_ context.Context, id int64, _ time.Time) error {
|
||||
r.rateLimitedCalls++
|
||||
r.rateLimitedID = id
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *epTrackingRepo) SetError(_ context.Context, id int64, _ string) error {
|
||||
r.setErrCalls++
|
||||
r.setErrID = id
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *epTrackingRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error {
|
||||
r.tempCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestCustomErrorCode599_SkippedErrors_Return500_NoRateLimit
|
||||
//
|
||||
// 核心场景:自定义错误码设为 [599](一个不会真正出现的错误码),
|
||||
// 当上游返回 429/500/503/401 时:
|
||||
// - 返回给客户端的状态码必须是 500(而不是透传原始状态码)
|
||||
// - 不调用 SetRateLimited(不进入限流状态)
|
||||
// - 不调用 SetError(不停止调度)
|
||||
// - 不调用 handleError
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestCustomErrorCode599_SkippedErrors_Return500_NoRateLimit(t *testing.T) {
|
||||
errorCodes := []int{429, 500, 503, 401, 403}
|
||||
|
||||
for _, upstreamStatus := range errorCodes {
|
||||
t.Run(http.StatusText(upstreamStatus), func(t *testing.T) {
|
||||
saveAndSetBaseURLs(t)
|
||||
|
||||
upstream := &epFixedUpstream{
|
||||
statusCode: upstreamStatus,
|
||||
body: `{"error":"some upstream error"}`,
|
||||
}
|
||||
repo := &epTrackingRepo{}
|
||||
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
|
||||
|
||||
account := &Account{
|
||||
ID: 500,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformAntigravity,
|
||||
Schedulable: true,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(599)},
|
||||
},
|
||||
}
|
||||
|
||||
var handleErrorCount int
|
||||
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
||||
handleErrorCount++
|
||||
return nil
|
||||
})
|
||||
|
||||
result, err := svc.antigravityRetryLoop(p)
|
||||
|
||||
// 不应返回 error(Skipped 不触发账号切换)
|
||||
require.NoError(t, err, "should not return error")
|
||||
require.NotNil(t, result, "result should not be nil")
|
||||
require.NotNil(t, result.resp, "response should not be nil")
|
||||
defer func() { _ = result.resp.Body.Close() }()
|
||||
|
||||
// 状态码必须是 500(不透传原始状态码)
|
||||
require.Equal(t, http.StatusInternalServerError, result.resp.StatusCode,
|
||||
"skipped error should return 500, not %d", upstreamStatus)
|
||||
|
||||
// 不调用 handleError
|
||||
require.Equal(t, 0, handleErrorCount,
|
||||
"handleError should NOT be called for skipped errors")
|
||||
|
||||
// 不标记限流
|
||||
require.Equal(t, 0, repo.rateLimitedCalls,
|
||||
"SetRateLimited should NOT be called for skipped errors")
|
||||
|
||||
// 不停止调度
|
||||
require.Equal(t, 0, repo.setErrCalls,
|
||||
"SetError should NOT be called for skipped errors")
|
||||
|
||||
// 只调用一次上游(不重试)
|
||||
require.Equal(t, 1, upstream.calls,
|
||||
"should call upstream exactly once (no retry)")
|
||||
})
|
||||
}
|
||||
}
|
||||
295
backend/internal/service/error_policy_test.go
Normal file
295
backend/internal/service/error_policy_test.go
Normal file
@@ -0,0 +1,295 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestCheckErrorPolicy — 6 table-driven cases for the pure logic function
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestCheckErrorPolicy(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
statusCode int
|
||||
body []byte
|
||||
expected ErrorPolicyResult
|
||||
}{
|
||||
{
|
||||
name: "no_policy_oauth_returns_none",
|
||||
account: &Account{
|
||||
ID: 1,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
// no custom error codes, no temp rules
|
||||
},
|
||||
statusCode: 500,
|
||||
body: []byte(`"error"`),
|
||||
expected: ErrorPolicyNone,
|
||||
},
|
||||
{
|
||||
name: "custom_error_codes_hit_returns_matched",
|
||||
account: &Account{
|
||||
ID: 2,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(429), float64(500)},
|
||||
},
|
||||
},
|
||||
statusCode: 500,
|
||||
body: []byte(`"error"`),
|
||||
expected: ErrorPolicyMatched,
|
||||
},
|
||||
{
|
||||
name: "custom_error_codes_miss_returns_skipped",
|
||||
account: &Account{
|
||||
ID: 3,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(429), float64(500)},
|
||||
},
|
||||
},
|
||||
statusCode: 503,
|
||||
body: []byte(`"error"`),
|
||||
expected: ErrorPolicySkipped,
|
||||
},
|
||||
{
|
||||
name: "temp_unschedulable_hit_returns_temp_unscheduled",
|
||||
account: &Account{
|
||||
ID: 4,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(503),
|
||||
"keywords": []any{"overloaded"},
|
||||
"duration_minutes": float64(10),
|
||||
"description": "overloaded rule",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
statusCode: 503,
|
||||
body: []byte(`overloaded service`),
|
||||
expected: ErrorPolicyTempUnscheduled,
|
||||
},
|
||||
{
|
||||
name: "temp_unschedulable_body_miss_returns_none",
|
||||
account: &Account{
|
||||
ID: 5,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(503),
|
||||
"keywords": []any{"overloaded"},
|
||||
"duration_minutes": float64(10),
|
||||
"description": "overloaded rule",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
statusCode: 503,
|
||||
body: []byte(`random msg`),
|
||||
expected: ErrorPolicyNone,
|
||||
},
|
||||
{
|
||||
name: "custom_error_codes_override_temp_unschedulable",
|
||||
account: &Account{
|
||||
ID: 6,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(503)},
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(503),
|
||||
"keywords": []any{"overloaded"},
|
||||
"duration_minutes": float64(10),
|
||||
"description": "overloaded rule",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
statusCode: 503,
|
||||
body: []byte(`overloaded`),
|
||||
expected: ErrorPolicyMatched, // custom codes take precedence
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := &errorPolicyRepoStub{}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
|
||||
result := svc.CheckErrorPolicy(context.Background(), tt.account, tt.statusCode, tt.body)
|
||||
require.Equal(t, tt.expected, result, "unexpected ErrorPolicyResult")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestApplyErrorPolicy — 4 table-driven cases for the wrapper method
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestApplyErrorPolicy(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
statusCode int
|
||||
body []byte
|
||||
expectedHandled bool
|
||||
expectedStatus int // expected outStatus
|
||||
expectedSwitchErr bool // expect *AntigravityAccountSwitchError
|
||||
handleErrorCalls int
|
||||
}{
|
||||
{
|
||||
name: "none_not_handled",
|
||||
account: &Account{
|
||||
ID: 10,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
},
|
||||
statusCode: 500,
|
||||
body: []byte(`"error"`),
|
||||
expectedHandled: false,
|
||||
expectedStatus: 500, // passthrough
|
||||
handleErrorCalls: 0,
|
||||
},
|
||||
{
|
||||
name: "skipped_handled_no_handleError",
|
||||
account: &Account{
|
||||
ID: 11,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(429)},
|
||||
},
|
||||
},
|
||||
statusCode: 500, // not in custom codes
|
||||
body: []byte(`"error"`),
|
||||
expectedHandled: true,
|
||||
expectedStatus: http.StatusInternalServerError, // skipped → 500
|
||||
handleErrorCalls: 0,
|
||||
},
|
||||
{
|
||||
name: "matched_handled_calls_handleError",
|
||||
account: &Account{
|
||||
ID: 12,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(500)},
|
||||
},
|
||||
},
|
||||
statusCode: 500,
|
||||
body: []byte(`"error"`),
|
||||
expectedHandled: true,
|
||||
expectedStatus: 500, // matched → original status
|
||||
handleErrorCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "temp_unscheduled_returns_switch_error",
|
||||
account: &Account{
|
||||
ID: 13,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(503),
|
||||
"keywords": []any{"overloaded"},
|
||||
"duration_minutes": float64(10),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
statusCode: 503,
|
||||
body: []byte(`overloaded`),
|
||||
expectedHandled: true,
|
||||
expectedStatus: 503, // temp_unscheduled → original status
|
||||
expectedSwitchErr: true,
|
||||
handleErrorCalls: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := &errorPolicyRepoStub{}
|
||||
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
svc := &AntigravityGatewayService{
|
||||
rateLimitService: rlSvc,
|
||||
}
|
||||
|
||||
var handleErrorCount int
|
||||
p := antigravityRetryLoopParams{
|
||||
ctx: context.Background(),
|
||||
prefix: "[test]",
|
||||
account: tt.account,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleErrorCount++
|
||||
return nil
|
||||
},
|
||||
isStickySession: true,
|
||||
}
|
||||
|
||||
handled, outStatus, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body)
|
||||
|
||||
require.Equal(t, tt.expectedHandled, handled, "handled mismatch")
|
||||
require.Equal(t, tt.expectedStatus, outStatus, "outStatus mismatch")
|
||||
require.Equal(t, tt.handleErrorCalls, handleErrorCount, "handleError call count mismatch")
|
||||
|
||||
if tt.expectedSwitchErr {
|
||||
var switchErr *AntigravityAccountSwitchError
|
||||
require.ErrorAs(t, retErr, &switchErr)
|
||||
require.Equal(t, tt.account.ID, switchErr.OriginalAccountID)
|
||||
} else {
|
||||
require.NoError(t, retErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// errorPolicyRepoStub — minimal AccountRepository stub for error policy tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type errorPolicyRepoStub struct {
|
||||
mockAccountRepoForGemini
|
||||
tempCalls int
|
||||
setErrCalls int
|
||||
lastErrorMsg string
|
||||
}
|
||||
|
||||
func (r *errorPolicyRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||
r.tempCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *errorPolicyRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||
r.setErrCalls++
|
||||
r.lastErrorMsg = errorMsg
|
||||
return nil
|
||||
}
|
||||
@@ -77,6 +77,9 @@ func (m *mockAccountRepoForPlatform) Create(ctx context.Context, account *Accoun
|
||||
func (m *mockAccountRepoForPlatform) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) Update(ctx context.Context, account *Account) error {
|
||||
return nil
|
||||
}
|
||||
@@ -84,7 +87,7 @@ func (m *mockAccountRepoForPlatform) Delete(ctx context.Context, id int64) error
|
||||
func (m *mockAccountRepoForPlatform) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
|
||||
func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
|
||||
@@ -142,9 +145,6 @@ func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx co
|
||||
func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
@@ -216,22 +216,6 @@ func (m *mockGatewayCacheForPlatform) DeleteSessionAccountID(ctx context.Context
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForPlatform) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForPlatform) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForPlatform) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForPlatform) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockGroupRepoForGateway struct {
|
||||
groups map[int64]*Group
|
||||
getByIDCalls int
|
||||
@@ -290,6 +274,10 @@ func (m *mockGroupRepoForGateway) GetAccountIDsByGroupIDs(ctx context.Context, g
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGateway) UpdateSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func ptr[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
|
||||
@@ -6,9 +6,19 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
)
|
||||
|
||||
// SessionContext 粘性会话上下文,用于区分不同来源的请求。
|
||||
// 仅在 GenerateSessionHash 第 3 级 fallback(消息内容 hash)时混入,
|
||||
// 避免不同用户发送相同消息产生相同 hash 导致账号集中。
|
||||
type SessionContext struct {
|
||||
ClientIP string
|
||||
UserAgent string
|
||||
APIKeyID int64
|
||||
}
|
||||
|
||||
// ParsedRequest 保存网关请求的预解析结果
|
||||
//
|
||||
// 性能优化说明:
|
||||
@@ -22,20 +32,22 @@ import (
|
||||
// 2. 将解析结果 ParsedRequest 传递给 Service 层
|
||||
// 3. 避免重复 json.Unmarshal,减少 CPU 和内存开销
|
||||
type ParsedRequest struct {
|
||||
Body []byte // 原始请求体(保留用于转发)
|
||||
Model string // 请求的模型名称
|
||||
Stream bool // 是否为流式请求
|
||||
MetadataUserID string // metadata.user_id(用于会话亲和)
|
||||
System any // system 字段内容
|
||||
Messages []any // messages 数组
|
||||
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
|
||||
ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名)
|
||||
MaxTokens int // max_tokens 值(用于探测请求拦截)
|
||||
Body []byte // 原始请求体(保留用于转发)
|
||||
Model string // 请求的模型名称
|
||||
Stream bool // 是否为流式请求
|
||||
MetadataUserID string // metadata.user_id(用于会话亲和)
|
||||
System any // system 字段内容
|
||||
Messages []any // messages 数组
|
||||
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
|
||||
ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名)
|
||||
MaxTokens int // max_tokens 值(用于探测请求拦截)
|
||||
SessionContext *SessionContext // 可选:请求上下文区分因子(nil 时行为不变)
|
||||
}
|
||||
|
||||
// ParseGatewayRequest 解析网关请求体并返回结构化结果
|
||||
// 性能优化:一次解析提取所有需要的字段,避免重复 Unmarshal
|
||||
func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
|
||||
// ParseGatewayRequest 解析网关请求体并返回结构化结果。
|
||||
// protocol 指定请求协议格式(domain.PlatformAnthropic / domain.PlatformGemini),
|
||||
// 不同协议使用不同的 system/messages 字段名。
|
||||
func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return nil, err
|
||||
@@ -64,19 +76,34 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
|
||||
parsed.MetadataUserID = userID
|
||||
}
|
||||
}
|
||||
// system 字段只要存在就视为显式提供(即使为 null),
|
||||
// 以避免客户端传 null 时被默认 system 误注入。
|
||||
if system, ok := req["system"]; ok {
|
||||
parsed.HasSystem = true
|
||||
parsed.System = system
|
||||
}
|
||||
if messages, ok := req["messages"].([]any); ok {
|
||||
parsed.Messages = messages
|
||||
|
||||
switch protocol {
|
||||
case domain.PlatformGemini:
|
||||
// Gemini 原生格式: systemInstruction.parts / contents
|
||||
if sysInst, ok := req["systemInstruction"].(map[string]any); ok {
|
||||
if parts, ok := sysInst["parts"].([]any); ok {
|
||||
parsed.System = parts
|
||||
}
|
||||
}
|
||||
if contents, ok := req["contents"].([]any); ok {
|
||||
parsed.Messages = contents
|
||||
}
|
||||
default:
|
||||
// Anthropic / OpenAI 格式: system / messages
|
||||
// system 字段只要存在就视为显式提供(即使为 null),
|
||||
// 以避免客户端传 null 时被默认 system 误注入。
|
||||
if system, ok := req["system"]; ok {
|
||||
parsed.HasSystem = true
|
||||
parsed.System = system
|
||||
}
|
||||
if messages, ok := req["messages"].([]any); ok {
|
||||
parsed.Messages = messages
|
||||
}
|
||||
}
|
||||
|
||||
// thinking: {type: "enabled"}
|
||||
// thinking: {type: "enabled" | "adaptive"}
|
||||
if rawThinking, ok := req["thinking"].(map[string]any); ok {
|
||||
if t, ok := rawThinking["type"].(string); ok && t == "enabled" {
|
||||
if t, ok := rawThinking["type"].(string); ok && (t == "enabled" || t == "adaptive") {
|
||||
parsed.ThinkingEnabled = true
|
||||
}
|
||||
}
|
||||
@@ -134,9 +161,9 @@ func parseIntegralNumber(raw any) (int, bool) {
|
||||
// Returns filtered body or original body if filtering fails (fail-safe)
|
||||
// This prevents 400 errors from invalid thinking block signatures
|
||||
//
|
||||
// Strategy:
|
||||
// - When thinking.type != "enabled": Remove all thinking blocks
|
||||
// - When thinking.type == "enabled": Only remove thinking blocks without valid signatures
|
||||
// 策略:
|
||||
// - 当 thinking.type 不是 "enabled"/"adaptive":移除所有 thinking 相关块
|
||||
// - 当 thinking.type 是 "enabled"/"adaptive":仅移除缺失/无效 signature 的 thinking 块(避免 400)
|
||||
// (blocks with missing/empty/dummy signatures that would cause 400 errors)
|
||||
func FilterThinkingBlocks(body []byte) []byte {
|
||||
return filterThinkingBlocksInternal(body, false)
|
||||
@@ -462,9 +489,9 @@ func FilterSignatureSensitiveBlocksForRetry(body []byte) []byte {
|
||||
}
|
||||
|
||||
// filterThinkingBlocksInternal removes invalid thinking blocks from request
|
||||
// Strategy:
|
||||
// - When thinking.type != "enabled": Remove all thinking blocks
|
||||
// - When thinking.type == "enabled": Only remove thinking blocks without valid signatures
|
||||
// 策略:
|
||||
// - 当 thinking.type 不是 "enabled"/"adaptive":移除所有 thinking 相关块
|
||||
// - 当 thinking.type 是 "enabled"/"adaptive":仅移除缺失/无效 signature 的 thinking 块
|
||||
func filterThinkingBlocksInternal(body []byte, _ bool) []byte {
|
||||
// Fast path: if body doesn't contain "thinking", skip parsing
|
||||
if !bytes.Contains(body, []byte(`"type":"thinking"`)) &&
|
||||
@@ -484,7 +511,7 @@ func filterThinkingBlocksInternal(body []byte, _ bool) []byte {
|
||||
// Check if thinking is enabled
|
||||
thinkingEnabled := false
|
||||
if thinking, ok := req["thinking"].(map[string]any); ok {
|
||||
if thinkType, ok := thinking["type"].(string); ok && thinkType == "enabled" {
|
||||
if thinkType, ok := thinking["type"].(string); ok && (thinkType == "enabled" || thinkType == "adaptive") {
|
||||
thinkingEnabled = true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,12 +4,13 @@ import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseGatewayRequest(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-3-7-sonnet","stream":true,"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"system":[{"type":"text","text":"hello","cache_control":{"type":"ephemeral"}}],"messages":[{"content":"hi"}]}`)
|
||||
parsed, err := ParseGatewayRequest(body)
|
||||
parsed, err := ParseGatewayRequest(body, "")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "claude-3-7-sonnet", parsed.Model)
|
||||
require.True(t, parsed.Stream)
|
||||
@@ -22,7 +23,15 @@ func TestParseGatewayRequest(t *testing.T) {
|
||||
|
||||
func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-sonnet-4-5","thinking":{"type":"enabled"},"messages":[{"content":"hi"}]}`)
|
||||
parsed, err := ParseGatewayRequest(body)
|
||||
parsed, err := ParseGatewayRequest(body, "")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "claude-sonnet-4-5", parsed.Model)
|
||||
require.True(t, parsed.ThinkingEnabled)
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_ThinkingAdaptiveEnabled(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-sonnet-4-5","thinking":{"type":"adaptive"},"messages":[{"content":"hi"}]}`)
|
||||
parsed, err := ParseGatewayRequest(body, "")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "claude-sonnet-4-5", parsed.Model)
|
||||
require.True(t, parsed.ThinkingEnabled)
|
||||
@@ -30,21 +39,21 @@ func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) {
|
||||
|
||||
func TestParseGatewayRequest_MaxTokens(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1}`)
|
||||
parsed, err := ParseGatewayRequest(body)
|
||||
parsed, err := ParseGatewayRequest(body, "")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, parsed.MaxTokens)
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_MaxTokensNonIntegralIgnored(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1.5}`)
|
||||
parsed, err := ParseGatewayRequest(body)
|
||||
parsed, err := ParseGatewayRequest(body, "")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, parsed.MaxTokens)
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_SystemNull(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-3","system":null}`)
|
||||
parsed, err := ParseGatewayRequest(body)
|
||||
parsed, err := ParseGatewayRequest(body, "")
|
||||
require.NoError(t, err)
|
||||
// 显式传入 system:null 也应视为“字段已存在”,避免默认 system 被注入。
|
||||
require.True(t, parsed.HasSystem)
|
||||
@@ -53,16 +62,112 @@ func TestParseGatewayRequest_SystemNull(t *testing.T) {
|
||||
|
||||
func TestParseGatewayRequest_InvalidModelType(t *testing.T) {
|
||||
body := []byte(`{"model":123}`)
|
||||
_, err := ParseGatewayRequest(body)
|
||||
_, err := ParseGatewayRequest(body, "")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_InvalidStreamType(t *testing.T) {
|
||||
body := []byte(`{"stream":"true"}`)
|
||||
_, err := ParseGatewayRequest(body)
|
||||
_, err := ParseGatewayRequest(body, "")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// ============ Gemini 原生格式解析测试 ============
|
||||
|
||||
func TestParseGatewayRequest_GeminiContents(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"contents": [
|
||||
{"role": "user", "parts": [{"text": "Hello"}]},
|
||||
{"role": "model", "parts": [{"text": "Hi there"}]},
|
||||
{"role": "user", "parts": [{"text": "How are you?"}]}
|
||||
]
|
||||
}`)
|
||||
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, parsed.Messages, 3, "should parse contents as Messages")
|
||||
require.False(t, parsed.HasSystem, "Gemini format should not set HasSystem")
|
||||
require.Nil(t, parsed.System, "no systemInstruction means nil System")
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_GeminiSystemInstruction(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"systemInstruction": {
|
||||
"parts": [{"text": "You are a helpful assistant."}]
|
||||
},
|
||||
"contents": [
|
||||
{"role": "user", "parts": [{"text": "Hello"}]}
|
||||
]
|
||||
}`)
|
||||
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, parsed.System, "should parse systemInstruction.parts as System")
|
||||
parts, ok := parsed.System.([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, parts, 1)
|
||||
partMap, ok := parts[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "You are a helpful assistant.", partMap["text"])
|
||||
require.Len(t, parsed.Messages, 1)
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_GeminiWithModel(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"model": "gemini-2.5-pro",
|
||||
"contents": [{"role": "user", "parts": [{"text": "test"}]}]
|
||||
}`)
|
||||
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "gemini-2.5-pro", parsed.Model)
|
||||
require.Len(t, parsed.Messages, 1)
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_GeminiIgnoresAnthropicFields(t *testing.T) {
|
||||
// Gemini 格式下 system/messages 字段应被忽略
|
||||
body := []byte(`{
|
||||
"system": "should be ignored",
|
||||
"messages": [{"role": "user", "content": "ignored"}],
|
||||
"contents": [{"role": "user", "parts": [{"text": "real content"}]}]
|
||||
}`)
|
||||
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
|
||||
require.NoError(t, err)
|
||||
require.False(t, parsed.HasSystem, "Gemini protocol should not parse Anthropic system field")
|
||||
require.Nil(t, parsed.System, "no systemInstruction = nil System")
|
||||
require.Len(t, parsed.Messages, 1, "should use contents, not messages")
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_GeminiEmptyContents(t *testing.T) {
|
||||
body := []byte(`{"contents": []}`)
|
||||
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, parsed.Messages)
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_GeminiNoContents(t *testing.T) {
|
||||
body := []byte(`{"model": "gemini-2.5-flash"}`)
|
||||
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, parsed.Messages)
|
||||
require.Equal(t, "gemini-2.5-flash", parsed.Model)
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_AnthropicIgnoresGeminiFields(t *testing.T) {
|
||||
// Anthropic 格式下 contents/systemInstruction 字段应被忽略
|
||||
body := []byte(`{
|
||||
"system": "real system",
|
||||
"messages": [{"role": "user", "content": "real content"}],
|
||||
"contents": [{"role": "user", "parts": [{"text": "ignored"}]}],
|
||||
"systemInstruction": {"parts": [{"text": "ignored"}]}
|
||||
}`)
|
||||
parsed, err := ParseGatewayRequest(body, domain.PlatformAnthropic)
|
||||
require.NoError(t, err)
|
||||
require.True(t, parsed.HasSystem)
|
||||
require.Equal(t, "real system", parsed.System)
|
||||
require.Len(t, parsed.Messages, 1)
|
||||
msg, ok := parsed.Messages[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "real content", msg["content"])
|
||||
}
|
||||
|
||||
func TestFilterThinkingBlocks(t *testing.T) {
|
||||
containsThinkingBlock := func(body []byte) bool {
|
||||
var req map[string]any
|
||||
@@ -112,6 +217,16 @@ func TestFilterThinkingBlocks(t *testing.T) {
|
||||
input: `{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":[{"type":"text","text":"Hello"},{"type":"thinking","thinking":"internal","signature":"invalid"},{"type":"text","text":"World"}]}]}`,
|
||||
shouldFilter: true,
|
||||
},
|
||||
{
|
||||
name: "does not filter signed thinking blocks when thinking adaptive",
|
||||
input: `{"thinking":{"type":"adaptive"},"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"ok","signature":"sig_real_123"},{"type":"text","text":"B"}]}]}`,
|
||||
shouldFilter: false,
|
||||
},
|
||||
{
|
||||
name: "filters unsigned thinking blocks when thinking adaptive",
|
||||
input: `{"thinking":{"type":"adaptive"},"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"internal","signature":""},{"type":"text","text":"B"}]}]}`,
|
||||
shouldFilter: true,
|
||||
},
|
||||
{
|
||||
name: "handles no thinking blocks",
|
||||
input: `{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}]}`,
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -17,6 +16,7 @@ import (
|
||||
"os"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
"github.com/cespare/xxhash/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -242,12 +243,15 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
// systemBlockFilterPrefixes 需要从 system 中过滤的文本前缀列表
|
||||
// OAuth/SetupToken 账号转发时,匹配这些前缀的 system 元素会被移除
|
||||
var systemBlockFilterPrefixes = []string{
|
||||
"x-anthropic-billing-header",
|
||||
}
|
||||
|
||||
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
|
||||
var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients")
|
||||
|
||||
// ErrModelScopeNotSupported 表示请求的模型系列不在分组支持的范围内
|
||||
var ErrModelScopeNotSupported = errors.New("model scope not supported by this group")
|
||||
|
||||
// allowedHeaders 白名单headers(参考CRS项目)
|
||||
var allowedHeaders = map[string]bool{
|
||||
"accept": true,
|
||||
@@ -273,13 +277,6 @@ var allowedHeaders = map[string]bool{
|
||||
// GatewayCache 定义网关服务的缓存操作接口。
|
||||
// 提供粘性会话(Sticky Session)的存储、查询、刷新和删除功能。
|
||||
//
|
||||
// ModelLoadInfo 模型负载信息(用于 Antigravity 调度)
|
||||
// Model load info for Antigravity scheduling
|
||||
type ModelLoadInfo struct {
|
||||
CallCount int64 // 当前分钟调用次数 / Call count in current minute
|
||||
LastUsedAt time.Time // 最后调度时间(零值表示未调度过)/ Last scheduling time (zero means never scheduled)
|
||||
}
|
||||
|
||||
// GatewayCache defines cache operations for gateway service.
|
||||
// Provides sticky session storage, retrieval, refresh and deletion capabilities.
|
||||
type GatewayCache interface {
|
||||
@@ -295,24 +292,6 @@ type GatewayCache interface {
|
||||
// DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理
|
||||
// Delete sticky session binding, used to proactively clean up when account becomes unavailable
|
||||
DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error
|
||||
|
||||
// IncrModelCallCount 增加模型调用次数并更新最后调度时间(Antigravity 专用)
|
||||
// Increment model call count and update last scheduling time (Antigravity only)
|
||||
// 返回更新后的调用次数
|
||||
IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error)
|
||||
|
||||
// GetModelLoadBatch 批量获取账号的模型负载信息(Antigravity 专用)
|
||||
// Batch get model load info for accounts (Antigravity only)
|
||||
GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error)
|
||||
|
||||
// FindGeminiSession 查找 Gemini 会话(MGET 倒序匹配)
|
||||
// Find Gemini session using MGET reverse order matching
|
||||
// 返回最长匹配的会话信息(uuid, accountID)
|
||||
FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool)
|
||||
|
||||
// SaveGeminiSession 保存 Gemini 会话
|
||||
// Save Gemini session binding
|
||||
SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error
|
||||
}
|
||||
|
||||
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
|
||||
@@ -323,21 +302,15 @@ func derefGroupID(groupID *int64) int64 {
|
||||
return *groupID
|
||||
}
|
||||
|
||||
// stickySessionRateLimitThreshold 定义清除粘性会话的限流时间阈值。
|
||||
// 当账号限流剩余时间超过此阈值时,清除粘性会话以便切换到其他账号。
|
||||
// 低于此阈值时保持粘性会话,等待短暂限流结束。
|
||||
const stickySessionRateLimitThreshold = 10 * time.Second
|
||||
|
||||
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
|
||||
// 当账号状态为错误、禁用、不可调度、处于临时不可调度期间,
|
||||
// 或模型限流剩余时间超过 stickySessionRateLimitThreshold 时,返回 true。
|
||||
// 或请求的模型处于限流状态时,返回 true。
|
||||
// 这确保后续请求不会继续使用不可用的账号。
|
||||
//
|
||||
// shouldClearStickySession checks if an account is in an unschedulable state
|
||||
// and the sticky session binding should be cleared.
|
||||
// Returns true when account status is error/disabled, schedulable is false,
|
||||
// within temporary unschedulable period, or model rate limit remaining time
|
||||
// exceeds stickySessionRateLimitThreshold.
|
||||
// within temporary unschedulable period, or the requested model is rate-limited.
|
||||
// This ensures subsequent requests won't continue using unavailable accounts.
|
||||
func shouldClearStickySession(account *Account, requestedModel string) bool {
|
||||
if account == nil {
|
||||
@@ -349,8 +322,8 @@ func shouldClearStickySession(account *Account, requestedModel string) bool {
|
||||
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
|
||||
return true
|
||||
}
|
||||
// 检查模型限流和 scope 限流,只在超过阈值时清除粘性会话
|
||||
if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > stickySessionRateLimitThreshold {
|
||||
// 检查模型限流和 scope 限流,有限流即清除粘性会话
|
||||
if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > 0 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
@@ -376,6 +349,8 @@ type ClaudeUsage struct {
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
||||
CacheCreation5mTokens int // 5分钟缓存创建token(来自嵌套 cache_creation 对象)
|
||||
CacheCreation1hTokens int // 1小时缓存创建token(来自嵌套 cache_creation 对象)
|
||||
}
|
||||
|
||||
// ForwardResult 转发结果
|
||||
@@ -395,15 +370,31 @@ type ForwardResult struct {
|
||||
|
||||
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
|
||||
type UpstreamFailoverError struct {
|
||||
StatusCode int
|
||||
ResponseBody []byte // 上游响应体,用于错误透传规则匹配
|
||||
ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true
|
||||
StatusCode int
|
||||
ResponseBody []byte // 上游响应体,用于错误透传规则匹配
|
||||
ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true
|
||||
RetryableOnSameAccount bool // 临时性错误(如 Google 间歇性 400、空响应),应在同一账号上重试 N 次再切换
|
||||
}
|
||||
|
||||
func (e *UpstreamFailoverError) Error() string {
|
||||
return fmt.Sprintf("upstream error: %d (failover)", e.StatusCode)
|
||||
}
|
||||
|
||||
// TempUnscheduleRetryableError 对 RetryableOnSameAccount 类型的 failover 错误触发临时封禁。
|
||||
// 由 handler 层在同账号重试全部用尽、切换账号时调用。
|
||||
func (s *GatewayService) TempUnscheduleRetryableError(ctx context.Context, accountID int64, failoverErr *UpstreamFailoverError) {
|
||||
if failoverErr == nil || !failoverErr.RetryableOnSameAccount {
|
||||
return
|
||||
}
|
||||
// 根据状态码选择封禁策略
|
||||
switch failoverErr.StatusCode {
|
||||
case http.StatusBadRequest:
|
||||
tempUnscheduleGoogleConfigError(ctx, s.accountRepo, accountID, "[handler]")
|
||||
case http.StatusBadGateway:
|
||||
tempUnscheduleEmptyResponse(ctx, s.accountRepo, accountID, "[handler]")
|
||||
}
|
||||
}
|
||||
|
||||
// GatewayService handles API gateway operations
|
||||
type GatewayService struct {
|
||||
accountRepo AccountRepository
|
||||
@@ -413,6 +404,7 @@ type GatewayService struct {
|
||||
userSubRepo UserSubscriptionRepository
|
||||
userGroupRateRepo UserGroupRateRepository
|
||||
cache GatewayCache
|
||||
digestStore *DigestSessionStore
|
||||
cfg *config.Config
|
||||
schedulerSnapshot *SchedulerSnapshotService
|
||||
billingService *BillingService
|
||||
@@ -446,6 +438,7 @@ func NewGatewayService(
|
||||
deferredService *DeferredService,
|
||||
claudeTokenProvider *ClaudeTokenProvider,
|
||||
sessionLimitCache SessionLimitCache,
|
||||
digestStore *DigestSessionStore,
|
||||
) *GatewayService {
|
||||
return &GatewayService{
|
||||
accountRepo: accountRepo,
|
||||
@@ -455,6 +448,7 @@ func NewGatewayService(
|
||||
userSubRepo: userSubRepo,
|
||||
userGroupRateRepo: userGroupRateRepo,
|
||||
cache: cache,
|
||||
digestStore: digestStore,
|
||||
cfg: cfg,
|
||||
schedulerSnapshot: schedulerSnapshot,
|
||||
concurrencyService: concurrencyService,
|
||||
@@ -488,23 +482,45 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
|
||||
return s.hashContent(cacheableContent)
|
||||
}
|
||||
|
||||
// 3. Fallback: 使用 system 内容
|
||||
// 3. 最后 fallback: 使用 session上下文 + system + 所有消息的完整摘要串
|
||||
var combined strings.Builder
|
||||
// 混入请求上下文区分因子,避免不同用户相同消息产生相同 hash
|
||||
if parsed.SessionContext != nil {
|
||||
_, _ = combined.WriteString(parsed.SessionContext.ClientIP)
|
||||
_, _ = combined.WriteString(":")
|
||||
_, _ = combined.WriteString(parsed.SessionContext.UserAgent)
|
||||
_, _ = combined.WriteString(":")
|
||||
_, _ = combined.WriteString(strconv.FormatInt(parsed.SessionContext.APIKeyID, 10))
|
||||
_, _ = combined.WriteString("|")
|
||||
}
|
||||
if parsed.System != nil {
|
||||
systemText := s.extractTextFromSystem(parsed.System)
|
||||
if systemText != "" {
|
||||
return s.hashContent(systemText)
|
||||
_, _ = combined.WriteString(systemText)
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 最后 fallback: 使用第一条消息
|
||||
if len(parsed.Messages) > 0 {
|
||||
if firstMsg, ok := parsed.Messages[0].(map[string]any); ok {
|
||||
msgText := s.extractTextFromContent(firstMsg["content"])
|
||||
if msgText != "" {
|
||||
return s.hashContent(msgText)
|
||||
for _, msg := range parsed.Messages {
|
||||
if m, ok := msg.(map[string]any); ok {
|
||||
if content, exists := m["content"]; exists {
|
||||
// Anthropic: messages[].content
|
||||
if msgText := s.extractTextFromContent(content); msgText != "" {
|
||||
_, _ = combined.WriteString(msgText)
|
||||
}
|
||||
} else if parts, ok := m["parts"].([]any); ok {
|
||||
// Gemini: contents[].parts[].text
|
||||
for _, part := range parts {
|
||||
if partMap, ok := part.(map[string]any); ok {
|
||||
if text, ok := partMap["text"].(string); ok {
|
||||
_, _ = combined.WriteString(text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if combined.Len() > 0 {
|
||||
return s.hashContent(combined.String())
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
@@ -532,19 +548,37 @@ func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID
|
||||
|
||||
// FindGeminiSession 查找 Gemini 会话(基于内容摘要链的 Fallback 匹配)
|
||||
// 返回最长匹配的会话信息(uuid, accountID)
|
||||
func (s *GatewayService) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
||||
if digestChain == "" || s.cache == nil {
|
||||
return "", 0, false
|
||||
func (s *GatewayService) FindGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) {
|
||||
if digestChain == "" || s.digestStore == nil {
|
||||
return "", 0, "", false
|
||||
}
|
||||
return s.cache.FindGeminiSession(ctx, groupID, prefixHash, digestChain)
|
||||
return s.digestStore.Find(groupID, prefixHash, digestChain)
|
||||
}
|
||||
|
||||
// SaveGeminiSession 保存 Gemini 会话
|
||||
func (s *GatewayService) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
||||
if digestChain == "" || s.cache == nil {
|
||||
// SaveGeminiSession 保存 Gemini 会话。oldDigestChain 为 Find 返回的 matchedChain,用于删旧 key。
|
||||
func (s *GatewayService) SaveGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error {
|
||||
if digestChain == "" || s.digestStore == nil {
|
||||
return nil
|
||||
}
|
||||
return s.cache.SaveGeminiSession(ctx, groupID, prefixHash, digestChain, uuid, accountID)
|
||||
s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain)
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindAnthropicSession 查找 Anthropic 会话(基于内容摘要链的 Fallback 匹配)
|
||||
func (s *GatewayService) FindAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) {
|
||||
if digestChain == "" || s.digestStore == nil {
|
||||
return "", 0, "", false
|
||||
}
|
||||
return s.digestStore.Find(groupID, prefixHash, digestChain)
|
||||
}
|
||||
|
||||
// SaveAnthropicSession 保存 Anthropic 会话
|
||||
func (s *GatewayService) SaveAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error {
|
||||
if digestChain == "" || s.digestStore == nil {
|
||||
return nil
|
||||
}
|
||||
s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
|
||||
@@ -629,8 +663,8 @@ func (s *GatewayService) extractTextFromContent(content any) string {
|
||||
}
|
||||
|
||||
func (s *GatewayService) hashContent(content string) string {
|
||||
hash := sha256.Sum256([]byte(content))
|
||||
return hex.EncodeToString(hash[:16]) // 32字符
|
||||
h := xxhash.Sum64String(content)
|
||||
return strconv.FormatUint(h, 36)
|
||||
}
|
||||
|
||||
// replaceModelInBody 替换请求体中的model字段
|
||||
@@ -989,13 +1023,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
log.Printf("[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform)
|
||||
}
|
||||
|
||||
// Antigravity 模型系列检查(在账号选择前检查,确保所有代码路径都经过此检查)
|
||||
if platform == PlatformAntigravity && groupID != nil && requestedModel != "" {
|
||||
if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1110,7 +1137,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
result.ReleaseFunc() // 释放槽位
|
||||
// 继续到负载感知选择
|
||||
} else {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID)
|
||||
}
|
||||
@@ -1190,6 +1216,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
|
||||
}
|
||||
})
|
||||
shuffleWithinSortGroups(routingAvailable)
|
||||
|
||||
// 4. 尝试获取槽位
|
||||
for _, item := range routingAvailable {
|
||||
@@ -1264,7 +1291,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
|
||||
} else {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
@@ -1344,10 +1370,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
return result, nil
|
||||
}
|
||||
} else {
|
||||
// Antigravity 平台:获取模型负载信息
|
||||
var modelLoadMap map[int64]*ModelLoadInfo
|
||||
isAntigravity := platform == PlatformAntigravity
|
||||
|
||||
var available []accountWithLoad
|
||||
for _, acc := range candidates {
|
||||
loadInfo := loadMap[acc.ID]
|
||||
@@ -1362,109 +1384,44 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
}
|
||||
|
||||
// Antigravity 平台:按账号实际映射后的模型名获取模型负载(与 Forward 的统计保持一致)
|
||||
if isAntigravity && requestedModel != "" && s.cache != nil && len(available) > 0 {
|
||||
modelLoadMap = make(map[int64]*ModelLoadInfo, len(available))
|
||||
modelToAccountIDs := make(map[string][]int64)
|
||||
for _, item := range available {
|
||||
mappedModel := mapAntigravityModel(item.account, requestedModel)
|
||||
if mappedModel == "" {
|
||||
continue
|
||||
}
|
||||
modelToAccountIDs[mappedModel] = append(modelToAccountIDs[mappedModel], item.account.ID)
|
||||
// 分层过滤选择:优先级 → 负载率 → LRU
|
||||
for len(available) > 0 {
|
||||
// 1. 取优先级最小的集合
|
||||
candidates := filterByMinPriority(available)
|
||||
// 2. 取负载率最低的集合
|
||||
candidates = filterByMinLoadRate(candidates)
|
||||
// 3. LRU 选择最久未用的账号
|
||||
selected := selectByLRU(candidates, preferOAuth)
|
||||
if selected == nil {
|
||||
break
|
||||
}
|
||||
for model, ids := range modelToAccountIDs {
|
||||
batch, err := s.cache.GetModelLoadBatch(ctx, ids, model)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for id, info := range batch {
|
||||
modelLoadMap[id] = info
|
||||
}
|
||||
}
|
||||
if len(modelLoadMap) == 0 {
|
||||
modelLoadMap = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Antigravity 平台:优先级硬过滤 →(同优先级内)按调用次数选择(最少优先,新账号用平均值)
|
||||
// 其他平台:分层过滤选择:优先级 → 负载率 → LRU
|
||||
if isAntigravity {
|
||||
for len(available) > 0 {
|
||||
// 1. 取优先级最小的集合(硬过滤)
|
||||
candidates := filterByMinPriority(available)
|
||||
// 2. 同优先级内按调用次数选择(调用次数最少优先,新账号使用平均值)
|
||||
selected := selectByCallCount(candidates, modelLoadMap, preferOAuth)
|
||||
if selected == nil {
|
||||
break
|
||||
}
|
||||
|
||||
result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
|
||||
} else {
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: selected.account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
|
||||
} else {
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: selected.account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 移除已尝试的账号,重新选择
|
||||
selectedID := selected.account.ID
|
||||
newAvailable := make([]accountWithLoad, 0, len(available)-1)
|
||||
for _, acc := range available {
|
||||
if acc.account.ID != selectedID {
|
||||
newAvailable = append(newAvailable, acc)
|
||||
}
|
||||
}
|
||||
available = newAvailable
|
||||
}
|
||||
} else {
|
||||
for len(available) > 0 {
|
||||
// 1. 取优先级最小的集合
|
||||
candidates := filterByMinPriority(available)
|
||||
// 2. 取负载率最低的集合
|
||||
candidates = filterByMinLoadRate(candidates)
|
||||
// 3. LRU 选择最久未用的账号
|
||||
selected := selectByLRU(candidates, preferOAuth)
|
||||
if selected == nil {
|
||||
break
|
||||
}
|
||||
|
||||
result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
|
||||
} else {
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: selected.account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
// 移除已尝试的账号,重新进行分层过滤
|
||||
selectedID := selected.account.ID
|
||||
newAvailable := make([]accountWithLoad, 0, len(available)-1)
|
||||
for _, acc := range available {
|
||||
if acc.account.ID != selectedID {
|
||||
newAvailable = append(newAvailable, acc)
|
||||
}
|
||||
|
||||
// 移除已尝试的账号,重新进行分层过滤
|
||||
selectedID := selected.account.ID
|
||||
newAvailable := make([]accountWithLoad, 0, len(available)-1)
|
||||
for _, acc := range available {
|
||||
if acc.account.ID != selectedID {
|
||||
newAvailable = append(newAvailable, acc)
|
||||
}
|
||||
}
|
||||
available = newAvailable
|
||||
}
|
||||
available = newAvailable
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1750,6 +1707,17 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
|
||||
return accounts, useMixed, nil
|
||||
}
|
||||
|
||||
// IsSingleAntigravityAccountGroup 检查指定分组是否只有一个 antigravity 平台的可调度账号。
|
||||
// 用于 Handler 层在首次请求时提前设置 SingleAccountRetry context,
|
||||
// 避免单账号分组收到 503 时错误地设置模型限流标记导致后续请求连续快速失败。
|
||||
func (s *GatewayService) IsSingleAntigravityAccountGroup(ctx context.Context, groupID *int64) bool {
|
||||
accounts, _, err := s.listSchedulableAccounts(ctx, groupID, PlatformAntigravity, true)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return len(accounts) == 1
|
||||
}
|
||||
|
||||
func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform string, useMixed bool) bool {
|
||||
if account == nil {
|
||||
return false
|
||||
@@ -2000,87 +1968,79 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
|
||||
return a.LastUsedAt.Before(*b.LastUsedAt)
|
||||
}
|
||||
})
|
||||
shuffleWithinPriorityAndLastUsed(accounts)
|
||||
}
|
||||
|
||||
// selectByCallCount 从候选账号中选择调用次数最少的账号(Antigravity 专用)
|
||||
// 新账号(CallCount=0)使用平均调用次数作为虚拟值,避免冷启动被猛调
|
||||
// 如果有多个账号具有相同的最小调用次数,则随机选择一个
|
||||
func selectByCallCount(accounts []accountWithLoad, modelLoadMap map[int64]*ModelLoadInfo, preferOAuth bool) *accountWithLoad {
|
||||
if len(accounts) == 0 {
|
||||
return nil
|
||||
// shuffleWithinSortGroups 对排序后的 accountWithLoad 切片,按 (Priority, LoadRate, LastUsedAt) 分组后组内随机打乱。
|
||||
// 防止并发请求读取同一快照时,确定性排序导致所有请求命中相同账号。
|
||||
func shuffleWithinSortGroups(accounts []accountWithLoad) {
|
||||
if len(accounts) <= 1 {
|
||||
return
|
||||
}
|
||||
if len(accounts) == 1 {
|
||||
return &accounts[0]
|
||||
}
|
||||
|
||||
// 如果没有负载信息,回退到 LRU
|
||||
if modelLoadMap == nil {
|
||||
return selectByLRU(accounts, preferOAuth)
|
||||
}
|
||||
|
||||
// 1. 计算平均调用次数(用于新账号冷启动)
|
||||
var totalCallCount int64
|
||||
var countWithCalls int
|
||||
for _, acc := range accounts {
|
||||
if info := modelLoadMap[acc.account.ID]; info != nil && info.CallCount > 0 {
|
||||
totalCallCount += info.CallCount
|
||||
countWithCalls++
|
||||
i := 0
|
||||
for i < len(accounts) {
|
||||
j := i + 1
|
||||
for j < len(accounts) && sameAccountWithLoadGroup(accounts[i], accounts[j]) {
|
||||
j++
|
||||
}
|
||||
}
|
||||
|
||||
var avgCallCount int64
|
||||
if countWithCalls > 0 {
|
||||
avgCallCount = totalCallCount / int64(countWithCalls)
|
||||
}
|
||||
|
||||
// 2. 获取每个账号的有效调用次数
|
||||
getEffectiveCallCount := func(acc accountWithLoad) int64 {
|
||||
if acc.account == nil {
|
||||
return 0
|
||||
if j-i > 1 {
|
||||
mathrand.Shuffle(j-i, func(a, b int) {
|
||||
accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a]
|
||||
})
|
||||
}
|
||||
info := modelLoadMap[acc.account.ID]
|
||||
if info == nil || info.CallCount == 0 {
|
||||
return avgCallCount // 新账号使用平均值
|
||||
}
|
||||
return info.CallCount
|
||||
i = j
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 找到最小调用次数
|
||||
minCount := getEffectiveCallCount(accounts[0])
|
||||
for _, acc := range accounts[1:] {
|
||||
if c := getEffectiveCallCount(acc); c < minCount {
|
||||
minCount = c
|
||||
}
|
||||
// sameAccountWithLoadGroup 判断两个 accountWithLoad 是否属于同一排序组
|
||||
func sameAccountWithLoadGroup(a, b accountWithLoad) bool {
|
||||
if a.account.Priority != b.account.Priority {
|
||||
return false
|
||||
}
|
||||
|
||||
// 4. 收集所有具有最小调用次数的账号
|
||||
var candidateIdxs []int
|
||||
for i, acc := range accounts {
|
||||
if getEffectiveCallCount(acc) == minCount {
|
||||
candidateIdxs = append(candidateIdxs, i)
|
||||
}
|
||||
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
|
||||
return false
|
||||
}
|
||||
return sameLastUsedAt(a.account.LastUsedAt, b.account.LastUsedAt)
|
||||
}
|
||||
|
||||
// 5. 如果只有一个候选,直接返回
|
||||
if len(candidateIdxs) == 1 {
|
||||
return &accounts[candidateIdxs[0]]
|
||||
// shuffleWithinPriorityAndLastUsed 对排序后的 []*Account 切片,按 (Priority, LastUsedAt) 分组后组内随机打乱。
|
||||
func shuffleWithinPriorityAndLastUsed(accounts []*Account) {
|
||||
if len(accounts) <= 1 {
|
||||
return
|
||||
}
|
||||
|
||||
// 6. preferOAuth 处理
|
||||
if preferOAuth {
|
||||
var oauthIdxs []int
|
||||
for _, idx := range candidateIdxs {
|
||||
if accounts[idx].account.Type == AccountTypeOAuth {
|
||||
oauthIdxs = append(oauthIdxs, idx)
|
||||
}
|
||||
i := 0
|
||||
for i < len(accounts) {
|
||||
j := i + 1
|
||||
for j < len(accounts) && sameAccountGroup(accounts[i], accounts[j]) {
|
||||
j++
|
||||
}
|
||||
if len(oauthIdxs) > 0 {
|
||||
candidateIdxs = oauthIdxs
|
||||
if j-i > 1 {
|
||||
mathrand.Shuffle(j-i, func(a, b int) {
|
||||
accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a]
|
||||
})
|
||||
}
|
||||
i = j
|
||||
}
|
||||
}
|
||||
|
||||
// 7. 随机选择
|
||||
return &accounts[candidateIdxs[mathrand.Intn(len(candidateIdxs))]]
|
||||
// sameAccountGroup 判断两个 Account 是否属于同一排序组(Priority + LastUsedAt)
|
||||
func sameAccountGroup(a, b *Account) bool {
|
||||
if a.Priority != b.Priority {
|
||||
return false
|
||||
}
|
||||
return sameLastUsedAt(a.LastUsedAt, b.LastUsedAt)
|
||||
}
|
||||
|
||||
// sameLastUsedAt 判断两个 LastUsedAt 是否相同(精度到秒)
|
||||
func sameLastUsedAt(a, b *time.Time) bool {
|
||||
switch {
|
||||
case a == nil && b == nil:
|
||||
return true
|
||||
case a == nil || b == nil:
|
||||
return false
|
||||
default:
|
||||
return a.Unix() == b.Unix()
|
||||
}
|
||||
}
|
||||
|
||||
// sortCandidatesForFallback 根据配置选择排序策略
|
||||
@@ -2135,13 +2095,6 @@ func shuffleWithinPriority(accounts []*Account) {
|
||||
|
||||
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
|
||||
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
|
||||
// 对 Antigravity 平台,检查请求的模型系列是否在分组支持范围内
|
||||
if platform == PlatformAntigravity && groupID != nil && requestedModel != "" {
|
||||
if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
preferOAuth := platform == PlatformGemini
|
||||
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
|
||||
|
||||
@@ -2169,9 +2122,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||
}
|
||||
@@ -2272,9 +2222,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
@@ -2383,9 +2330,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||
}
|
||||
@@ -2488,9 +2432,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
@@ -2767,6 +2708,60 @@ func hasClaudeCodePrefix(text string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// matchesFilterPrefix 检查文本是否匹配任一过滤前缀
|
||||
func matchesFilterPrefix(text string) bool {
|
||||
for _, prefix := range systemBlockFilterPrefixes {
|
||||
if strings.HasPrefix(text, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// filterSystemBlocksByPrefix 从 body 的 system 中移除文本匹配 systemBlockFilterPrefixes 前缀的元素
|
||||
// 直接从 body 解析 system,不依赖外部传入的 parsed.System(因为前置步骤可能已修改 body 中的 system)
|
||||
func filterSystemBlocksByPrefix(body []byte) []byte {
|
||||
sys := gjson.GetBytes(body, "system")
|
||||
if !sys.Exists() {
|
||||
return body
|
||||
}
|
||||
|
||||
switch {
|
||||
case sys.Type == gjson.String:
|
||||
if matchesFilterPrefix(sys.Str) {
|
||||
result, err := sjson.DeleteBytes(body, "system")
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return result
|
||||
}
|
||||
case sys.IsArray():
|
||||
var parsed []any
|
||||
if err := json.Unmarshal([]byte(sys.Raw), &parsed); err != nil {
|
||||
return body
|
||||
}
|
||||
filtered := make([]any, 0, len(parsed))
|
||||
changed := false
|
||||
for _, item := range parsed {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
if text, ok := m["text"].(string); ok && matchesFilterPrefix(text) {
|
||||
changed = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
filtered = append(filtered, item)
|
||||
}
|
||||
if changed {
|
||||
result, err := sjson.SetBytes(body, "system", filtered)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return result
|
||||
}
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
// injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词
|
||||
// 处理 null、字符串、数组三种格式
|
||||
func injectClaudeCodePrompt(body []byte, system any) []byte {
|
||||
@@ -3046,6 +3041,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
|
||||
}
|
||||
|
||||
// OAuth/SetupToken 账号:移除黑名单前缀匹配的 system 元素(如客户端注入的计费元数据)
|
||||
// 放在 inject/normalize 之后,确保不会被覆盖
|
||||
if account.IsOAuth() {
|
||||
body = filterSystemBlocksByPrefix(body)
|
||||
}
|
||||
|
||||
// 强制执行 cache_control 块数量限制(最多 4 个)
|
||||
body = enforceCacheControlLimit(body)
|
||||
|
||||
@@ -3632,7 +3633,8 @@ func requestNeedsBetaFeatures(body []byte) bool {
|
||||
if tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 {
|
||||
return true
|
||||
}
|
||||
if strings.EqualFold(gjson.GetBytes(body, "thinking.type").String(), "enabled") {
|
||||
thinkingType := gjson.GetBytes(body, "thinking.type").String()
|
||||
if strings.EqualFold(thinkingType, "enabled") || strings.EqualFold(thinkingType, "adaptive") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
@@ -4401,6 +4403,14 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
|
||||
usage.InputTokens = msgStart.Message.Usage.InputTokens
|
||||
usage.CacheCreationInputTokens = msgStart.Message.Usage.CacheCreationInputTokens
|
||||
usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens
|
||||
|
||||
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
|
||||
cc5m := gjson.Get(data, "message.usage.cache_creation.ephemeral_5m_input_tokens")
|
||||
cc1h := gjson.Get(data, "message.usage.cache_creation.ephemeral_1h_input_tokens")
|
||||
if cc5m.Exists() || cc1h.Exists() {
|
||||
usage.CacheCreation5mTokens = int(cc5m.Int())
|
||||
usage.CacheCreation1hTokens = int(cc1h.Int())
|
||||
}
|
||||
}
|
||||
|
||||
// 解析message_delta获取tokens(兼容GLM等把所有usage放在delta中的API)
|
||||
@@ -4429,6 +4439,14 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
|
||||
if msgDelta.Usage.CacheReadInputTokens > 0 {
|
||||
usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens
|
||||
}
|
||||
|
||||
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
|
||||
cc5m := gjson.Get(data, "usage.cache_creation.ephemeral_5m_input_tokens")
|
||||
cc1h := gjson.Get(data, "usage.cache_creation.ephemeral_1h_input_tokens")
|
||||
if cc5m.Exists() || cc1h.Exists() {
|
||||
usage.CacheCreation5mTokens = int(cc5m.Int())
|
||||
usage.CacheCreation1hTokens = int(cc1h.Int())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4449,6 +4467,14 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
||||
return nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
|
||||
cc5m := gjson.GetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens")
|
||||
cc1h := gjson.GetBytes(body, "usage.cache_creation.ephemeral_1h_input_tokens")
|
||||
if cc5m.Exists() || cc1h.Exists() {
|
||||
response.Usage.CacheCreation5mTokens = int(cc5m.Int())
|
||||
response.Usage.CacheCreation1hTokens = int(cc1h.Int())
|
||||
}
|
||||
|
||||
// 兼容 Kimi cached_tokens → cache_read_input_tokens
|
||||
if response.Usage.CacheReadInputTokens == 0 {
|
||||
cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int()
|
||||
@@ -4566,10 +4592,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
} else {
|
||||
// Token 计费
|
||||
tokens := UsageTokens{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
|
||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||
}
|
||||
var err error
|
||||
cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier)
|
||||
@@ -4603,6 +4631,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
|
||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||
InputCost: cost.InputCost,
|
||||
OutputCost: cost.OutputCost,
|
||||
CacheCreationCost: cost.CacheCreationCost,
|
||||
@@ -4747,10 +4777,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
} else {
|
||||
// Token 计费(使用长上下文计费方法)
|
||||
tokens := UsageTokens{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
|
||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||
}
|
||||
var err error
|
||||
cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
|
||||
@@ -4784,6 +4816,8 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
|
||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||
InputCost: cost.InputCost,
|
||||
OutputCost: cost.OutputCost,
|
||||
CacheCreationCost: cost.CacheCreationCost,
|
||||
@@ -5165,27 +5199,6 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
// checkAntigravityModelScope 检查 Antigravity 平台的模型系列是否在分组支持范围内
|
||||
func (s *GatewayService) checkAntigravityModelScope(ctx context.Context, groupID int64, requestedModel string) error {
|
||||
scope, ok := ResolveAntigravityQuotaScope(requestedModel)
|
||||
if !ok {
|
||||
return nil // 无法解析 scope,跳过检查
|
||||
}
|
||||
|
||||
group, err := s.resolveGroupByID(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil // 查询失败时放行
|
||||
}
|
||||
if group == nil {
|
||||
return nil // 分组不存在时放行
|
||||
}
|
||||
|
||||
if !IsScopeSupported(group.SupportedModelScopes, scope) {
|
||||
return ErrModelScopeNotSupported
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAvailableModels returns the list of models available for a group
|
||||
// It aggregates model_mapping keys from all schedulable accounts in the group
|
||||
func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string {
|
||||
|
||||
@@ -14,7 +14,7 @@ func BenchmarkGenerateSessionHash_Metadata(b *testing.B) {
|
||||
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
parsed, err := ParseGatewayRequest(body)
|
||||
parsed, err := ParseGatewayRequest(body, "")
|
||||
if err != nil {
|
||||
b.Fatalf("解析请求失败: %v", err)
|
||||
}
|
||||
|
||||
384
backend/internal/service/gemini_error_policy_test.go
Normal file
384
backend/internal/service/gemini_error_policy_test.go
Normal file
@@ -0,0 +1,384 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestShouldFailoverGeminiUpstreamError — verifies the failover decision
|
||||
// for the ErrorPolicyNone path (original logic preserved).
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestShouldFailoverGeminiUpstreamError(t *testing.T) {
|
||||
svc := &GeminiMessagesCompatService{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
expected bool
|
||||
}{
|
||||
{"401_failover", 401, true},
|
||||
{"403_failover", 403, true},
|
||||
{"429_failover", 429, true},
|
||||
{"529_failover", 529, true},
|
||||
{"500_failover", 500, true},
|
||||
{"502_failover", 502, true},
|
||||
{"503_failover", 503, true},
|
||||
{"400_no_failover", 400, false},
|
||||
{"404_no_failover", 404, false},
|
||||
{"422_no_failover", 422, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := svc.shouldFailoverGeminiUpstreamError(tt.statusCode)
|
||||
require.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestCheckErrorPolicy_GeminiAccounts — verifies CheckErrorPolicy works
|
||||
// correctly for Gemini platform accounts (API Key type).
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestCheckErrorPolicy_GeminiAccounts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
statusCode int
|
||||
body []byte
|
||||
expected ErrorPolicyResult
|
||||
}{
|
||||
{
|
||||
name: "gemini_apikey_custom_codes_hit",
|
||||
account: &Account{
|
||||
ID: 100,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformGemini,
|
||||
Credentials: map[string]any{
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(429), float64(500)},
|
||||
},
|
||||
},
|
||||
statusCode: 429,
|
||||
body: []byte(`{"error":"rate limited"}`),
|
||||
expected: ErrorPolicyMatched,
|
||||
},
|
||||
{
|
||||
name: "gemini_apikey_custom_codes_miss",
|
||||
account: &Account{
|
||||
ID: 101,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformGemini,
|
||||
Credentials: map[string]any{
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(429)},
|
||||
},
|
||||
},
|
||||
statusCode: 500,
|
||||
body: []byte(`{"error":"internal"}`),
|
||||
expected: ErrorPolicySkipped,
|
||||
},
|
||||
{
|
||||
name: "gemini_apikey_no_custom_codes_returns_none",
|
||||
account: &Account{
|
||||
ID: 102,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformGemini,
|
||||
},
|
||||
statusCode: 500,
|
||||
body: []byte(`{"error":"internal"}`),
|
||||
expected: ErrorPolicyNone,
|
||||
},
|
||||
{
|
||||
name: "gemini_apikey_temp_unschedulable_hit",
|
||||
account: &Account{
|
||||
ID: 103,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformGemini,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(503),
|
||||
"keywords": []any{"overloaded"},
|
||||
"duration_minutes": float64(10),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
statusCode: 503,
|
||||
body: []byte(`overloaded service`),
|
||||
expected: ErrorPolicyTempUnscheduled,
|
||||
},
|
||||
{
|
||||
name: "gemini_custom_codes_override_temp_unschedulable",
|
||||
account: &Account{
|
||||
ID: 104,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformGemini,
|
||||
Credentials: map[string]any{
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(503)},
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(503),
|
||||
"keywords": []any{"overloaded"},
|
||||
"duration_minutes": float64(10),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
statusCode: 503,
|
||||
body: []byte(`overloaded`),
|
||||
expected: ErrorPolicyMatched, // custom codes take precedence
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := &errorPolicyRepoStub{}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
|
||||
result := svc.CheckErrorPolicy(context.Background(), tt.account, tt.statusCode, tt.body)
|
||||
require.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestGeminiErrorPolicyIntegration — verifies the Gemini error handling
|
||||
// paths produce the correct behavior for each ErrorPolicyResult.
|
||||
//
|
||||
// These tests simulate the inline error policy switch in handleClaudeCompat
|
||||
// and forwardNativeGemini by calling the same methods in the same order.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGeminiErrorPolicyIntegration(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
statusCode int
|
||||
respBody []byte
|
||||
expectFailover bool // expect UpstreamFailoverError
|
||||
expectHandleError bool // expect handleGeminiUpstreamError to be called
|
||||
expectShouldFailover bool // for None path, whether shouldFailover triggers
|
||||
}{
|
||||
{
|
||||
name: "custom_codes_matched_429_failover",
|
||||
account: &Account{
|
||||
ID: 200,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformGemini,
|
||||
Credentials: map[string]any{
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(429)},
|
||||
},
|
||||
},
|
||||
statusCode: 429,
|
||||
respBody: []byte(`{"error":"rate limited"}`),
|
||||
expectFailover: true,
|
||||
expectHandleError: true,
|
||||
},
|
||||
{
|
||||
name: "custom_codes_skipped_500_no_failover",
|
||||
account: &Account{
|
||||
ID: 201,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformGemini,
|
||||
Credentials: map[string]any{
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(429)},
|
||||
},
|
||||
},
|
||||
statusCode: 500,
|
||||
respBody: []byte(`{"error":"internal"}`),
|
||||
expectFailover: false,
|
||||
expectHandleError: false,
|
||||
},
|
||||
{
|
||||
name: "temp_unschedulable_matched_failover",
|
||||
account: &Account{
|
||||
ID: 202,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformGemini,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(503),
|
||||
"keywords": []any{"overloaded"},
|
||||
"duration_minutes": float64(10),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
statusCode: 503,
|
||||
respBody: []byte(`overloaded`),
|
||||
expectFailover: true,
|
||||
expectHandleError: true,
|
||||
},
|
||||
{
|
||||
name: "no_policy_429_failover_via_shouldFailover",
|
||||
account: &Account{
|
||||
ID: 203,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformGemini,
|
||||
},
|
||||
statusCode: 429,
|
||||
respBody: []byte(`{"error":"rate limited"}`),
|
||||
expectFailover: true,
|
||||
expectHandleError: true,
|
||||
expectShouldFailover: true,
|
||||
},
|
||||
{
|
||||
name: "no_policy_400_no_failover",
|
||||
account: &Account{
|
||||
ID: 204,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformGemini,
|
||||
},
|
||||
statusCode: 400,
|
||||
respBody: []byte(`{"error":"bad request"}`),
|
||||
expectFailover: false,
|
||||
expectHandleError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := &geminiErrorPolicyRepo{}
|
||||
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
rateLimitService: rlSvc,
|
||||
}
|
||||
|
||||
writer := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(writer)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
// Simulate the Claude compat error handling path (same logic as native).
|
||||
// This mirrors the inline switch in handleClaudeCompat.
|
||||
var handleErrorCalled bool
|
||||
var gotFailover bool
|
||||
|
||||
ctx := context.Background()
|
||||
statusCode := tt.statusCode
|
||||
respBody := tt.respBody
|
||||
account := tt.account
|
||||
headers := http.Header{}
|
||||
|
||||
if svc.rateLimitService != nil {
|
||||
switch svc.rateLimitService.CheckErrorPolicy(ctx, account, statusCode, respBody) {
|
||||
case ErrorPolicySkipped:
|
||||
// Skipped → return error directly (no handleGeminiUpstreamError, no failover)
|
||||
gotFailover = false
|
||||
handleErrorCalled = false
|
||||
goto verify
|
||||
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
|
||||
svc.handleGeminiUpstreamError(ctx, account, statusCode, headers, respBody)
|
||||
handleErrorCalled = true
|
||||
gotFailover = true
|
||||
goto verify
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorPolicyNone → original logic
|
||||
svc.handleGeminiUpstreamError(ctx, account, statusCode, headers, respBody)
|
||||
handleErrorCalled = true
|
||||
if svc.shouldFailoverGeminiUpstreamError(statusCode) {
|
||||
gotFailover = true
|
||||
}
|
||||
|
||||
verify:
|
||||
require.Equal(t, tt.expectFailover, gotFailover, "failover mismatch")
|
||||
require.Equal(t, tt.expectHandleError, handleErrorCalled, "handleGeminiUpstreamError call mismatch")
|
||||
|
||||
if tt.expectShouldFailover {
|
||||
require.True(t, svc.shouldFailoverGeminiUpstreamError(statusCode),
|
||||
"shouldFailoverGeminiUpstreamError should return true for status %d", statusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestGeminiErrorPolicy_NilRateLimitService — verifies nil safety
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGeminiErrorPolicy_NilRateLimitService(t *testing.T) {
|
||||
svc := &GeminiMessagesCompatService{
|
||||
rateLimitService: nil,
|
||||
}
|
||||
|
||||
// When rateLimitService is nil, error policy is skipped → falls through to
|
||||
// shouldFailoverGeminiUpstreamError (original logic).
|
||||
// Verify this doesn't panic and follows expected behavior.
|
||||
|
||||
ctx := context.Background()
|
||||
account := &Account{
|
||||
ID: 300,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformGemini,
|
||||
Credentials: map[string]any{
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(429)},
|
||||
},
|
||||
}
|
||||
|
||||
// The nil check should prevent CheckErrorPolicy from being called
|
||||
if svc.rateLimitService != nil {
|
||||
t.Fatal("rateLimitService should be nil for this test")
|
||||
}
|
||||
|
||||
// shouldFailoverGeminiUpstreamError still works
|
||||
require.True(t, svc.shouldFailoverGeminiUpstreamError(429))
|
||||
require.False(t, svc.shouldFailoverGeminiUpstreamError(400))
|
||||
|
||||
// handleGeminiUpstreamError should not panic with nil rateLimitService
|
||||
require.NotPanics(t, func() {
|
||||
svc.handleGeminiUpstreamError(ctx, account, 500, http.Header{}, []byte(`error`))
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// geminiErrorPolicyRepo — minimal AccountRepository stub for Gemini error
|
||||
// policy tests. Embeds mockAccountRepoForGemini and adds tracking.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type geminiErrorPolicyRepo struct {
|
||||
mockAccountRepoForGemini
|
||||
setErrorCalls int
|
||||
setRateLimitedCalls int
|
||||
setTempCalls int
|
||||
}
|
||||
|
||||
func (r *geminiErrorPolicyRepo) SetError(_ context.Context, _ int64, _ string) error {
|
||||
r.setErrorCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *geminiErrorPolicyRepo) SetRateLimited(_ context.Context, _ int64, _ time.Time) error {
|
||||
r.setRateLimitedCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *geminiErrorPolicyRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error {
|
||||
r.setTempCalls++
|
||||
return nil
|
||||
}
|
||||
@@ -560,10 +560,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
return nil, "", errors.New("gemini api_key not configured")
|
||||
}
|
||||
|
||||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||||
if baseURL == "" {
|
||||
baseURL = geminicli.AIStudioBaseURL
|
||||
}
|
||||
baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL)
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
@@ -640,10 +637,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
return upstreamReq, "x-request-id", nil
|
||||
} else {
|
||||
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
|
||||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||||
if baseURL == "" {
|
||||
baseURL = geminicli.AIStudioBaseURL
|
||||
}
|
||||
baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL)
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
@@ -776,6 +770,14 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
break
|
||||
}
|
||||
|
||||
// 错误策略优先:匹配则跳过重试直接处理。
|
||||
if matched, rebuilt := s.checkErrorPolicyInLoop(ctx, account, resp); matched {
|
||||
resp = rebuilt
|
||||
break
|
||||
} else {
|
||||
resp = rebuilt
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
@@ -837,37 +839,77 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
tempMatched := false
|
||||
// 统一错误策略:自定义错误码 + 临时不可调度
|
||||
if s.rateLimitService != nil {
|
||||
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
|
||||
}
|
||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
if tempMatched {
|
||||
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||
if upstreamReqID == "" {
|
||||
upstreamReqID = resp.Header.Get("x-goog-request-id")
|
||||
}
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
switch s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, respBody) {
|
||||
case ErrorPolicySkipped:
|
||||
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||
if upstreamReqID == "" {
|
||||
upstreamReqID = resp.Header.Get("x-goog-request-id")
|
||||
}
|
||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||
return nil, s.writeGeminiMappedError(c, account, http.StatusInternalServerError, upstreamReqID, respBody)
|
||||
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
|
||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||
if upstreamReqID == "" {
|
||||
upstreamReqID = resp.Header.Get("x-goog-request-id")
|
||||
}
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: upstreamReqID,
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorPolicyNone → 原有逻辑
|
||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
// 精确匹配服务端配置类 400 错误,触发 failover + 临时封禁
|
||||
if resp.StatusCode == http.StatusBadRequest {
|
||||
msg400 := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
|
||||
if isGoogleProjectConfigError(msg400) {
|
||||
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||
if upstreamReqID == "" {
|
||||
upstreamReqID = resp.Header.Get("x-goog-request-id")
|
||||
}
|
||||
upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||
}
|
||||
log.Printf("[Gemini] status=400 google_config_error failover=true upstream_message=%q account=%d", upstreamMsg, account.ID)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: upstreamReqID,
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody, RetryableOnSameAccount: true}
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: upstreamReqID,
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||||
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||
@@ -1026,10 +1068,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
return nil, "", errors.New("gemini api_key not configured")
|
||||
}
|
||||
|
||||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||||
if baseURL == "" {
|
||||
baseURL = geminicli.AIStudioBaseURL
|
||||
}
|
||||
baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL)
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
@@ -1097,10 +1136,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
return upstreamReq, "x-request-id", nil
|
||||
} else {
|
||||
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
|
||||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||||
if baseURL == "" {
|
||||
baseURL = geminicli.AIStudioBaseURL
|
||||
}
|
||||
baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL)
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
@@ -1179,6 +1215,14 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries: "+safeErr)
|
||||
}
|
||||
|
||||
// 错误策略优先:匹配则跳过重试直接处理。
|
||||
if matched, rebuilt := s.checkErrorPolicyInLoop(ctx, account, resp); matched {
|
||||
resp = rebuilt
|
||||
break
|
||||
} else {
|
||||
resp = rebuilt
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
@@ -1261,14 +1305,9 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
tempMatched := false
|
||||
if s.rateLimitService != nil {
|
||||
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
|
||||
}
|
||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
|
||||
// Best-effort fallback for OAuth tokens missing AI Studio scopes when calling countTokens.
|
||||
// This avoids Gemini SDKs failing hard during preflight token counting.
|
||||
// Checked before error policy so it always works regardless of custom error codes.
|
||||
if action == "countTokens" && isOAuth && isGeminiInsufficientScope(resp.Header, respBody) {
|
||||
estimated := estimateGeminiCountTokens(body)
|
||||
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
|
||||
@@ -1282,29 +1321,73 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
}, nil
|
||||
}
|
||||
|
||||
if tempMatched {
|
||||
evBody := unwrapIfNeeded(isOAuth, respBody)
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
// 统一错误策略:自定义错误码 + 临时不可调度
|
||||
if s.rateLimitService != nil {
|
||||
switch s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, respBody) {
|
||||
case ErrorPolicySkipped:
|
||||
respBody = unwrapIfNeeded(isOAuth, respBody)
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
}
|
||||
upstreamDetail = truncateString(string(evBody), maxBytes)
|
||||
c.Data(http.StatusInternalServerError, contentType, respBody)
|
||||
return nil, fmt.Errorf("gemini upstream error: %d (skipped by error policy)", resp.StatusCode)
|
||||
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
|
||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
evBody := unwrapIfNeeded(isOAuth, respBody)
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(evBody), maxBytes)
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: requestID,
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorPolicyNone → 原有逻辑
|
||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
// 精确匹配服务端配置类 400 错误,触发 failover + 临时封禁
|
||||
if resp.StatusCode == http.StatusBadRequest {
|
||||
msg400 := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
|
||||
if isGoogleProjectConfigError(msg400) {
|
||||
evBody := unwrapIfNeeded(isOAuth, respBody)
|
||||
upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(evBody)))
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(evBody), maxBytes)
|
||||
}
|
||||
log.Printf("[Gemini] status=400 google_config_error failover=true upstream_message=%q account=%d", upstreamMsg, account.ID)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: requestID,
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: evBody, RetryableOnSameAccount: true}
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: requestID,
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||||
evBody := unwrapIfNeeded(isOAuth, respBody)
|
||||
@@ -1417,6 +1500,26 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
}, nil
|
||||
}
|
||||
|
||||
// checkErrorPolicyInLoop 在重试循环内预检查错误策略。
|
||||
// 返回 true 表示策略已匹配(调用者应 break),resp 已重建可直接使用。
|
||||
// 返回 false 表示 ErrorPolicyNone,resp 已重建,调用者继续走重试逻辑。
|
||||
func (s *GeminiMessagesCompatService) checkErrorPolicyInLoop(
|
||||
ctx context.Context, account *Account, resp *http.Response,
|
||||
) (matched bool, rebuilt *http.Response) {
|
||||
if resp.StatusCode < 400 || s.rateLimitService == nil {
|
||||
return false, resp
|
||||
}
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
rebuilt = &http.Response{
|
||||
StatusCode: resp.StatusCode,
|
||||
Header: resp.Header.Clone(),
|
||||
Body: io.NopCloser(bytes.NewReader(body)),
|
||||
}
|
||||
policy := s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, body)
|
||||
return policy != ErrorPolicyNone, rebuilt
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *Account, statusCode int) bool {
|
||||
switch statusCode {
|
||||
case 429, 500, 502, 503, 504, 529:
|
||||
@@ -2420,10 +2523,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
|
||||
return nil, errors.New("invalid path")
|
||||
}
|
||||
|
||||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||||
if baseURL == "" {
|
||||
baseURL = geminicli.AIStudioBaseURL
|
||||
}
|
||||
baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL)
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2563,11 +2663,12 @@ func extractGeminiUsage(geminiResp map[string]any) *ClaudeUsage {
|
||||
prompt, _ := asInt(usageMeta["promptTokenCount"])
|
||||
cand, _ := asInt(usageMeta["candidatesTokenCount"])
|
||||
cached, _ := asInt(usageMeta["cachedContentTokenCount"])
|
||||
thoughts, _ := asInt(usageMeta["thoughtsTokenCount"])
|
||||
// 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount,
|
||||
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去
|
||||
return &ClaudeUsage{
|
||||
InputTokens: prompt - cached,
|
||||
OutputTokens: cand,
|
||||
OutputTokens: cand + thoughts,
|
||||
CacheReadInputTokens: cached,
|
||||
}
|
||||
}
|
||||
@@ -2592,6 +2693,10 @@ func asInt(v any) (int, bool) {
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) {
|
||||
// 遵守自定义错误码策略:未命中则跳过所有限流处理
|
||||
if !account.ShouldHandleErrorCode(statusCode) {
|
||||
return
|
||||
}
|
||||
if s.rateLimitService != nil && (statusCode == 401 || statusCode == 403 || statusCode == 529) {
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body)
|
||||
return
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换
|
||||
@@ -203,3 +205,70 @@ func TestEnsureGeminiFunctionCallThoughtSignatures_InsertsWhenMissing(t *testing
|
||||
t.Fatalf("expected injected thoughtSignature %q, got: %s", geminiDummyThoughtSignature, s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractGeminiUsage_ThoughtsTokenCount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
resp map[string]any
|
||||
wantInput int
|
||||
wantOutput int
|
||||
wantCacheRead int
|
||||
wantNil bool
|
||||
}{
|
||||
{
|
||||
name: "with thoughtsTokenCount",
|
||||
resp: map[string]any{
|
||||
"usageMetadata": map[string]any{
|
||||
"promptTokenCount": float64(100),
|
||||
"candidatesTokenCount": float64(20),
|
||||
"thoughtsTokenCount": float64(50),
|
||||
},
|
||||
},
|
||||
wantInput: 100,
|
||||
wantOutput: 70,
|
||||
},
|
||||
{
|
||||
name: "with thoughtsTokenCount and cache",
|
||||
resp: map[string]any{
|
||||
"usageMetadata": map[string]any{
|
||||
"promptTokenCount": float64(100),
|
||||
"candidatesTokenCount": float64(20),
|
||||
"cachedContentTokenCount": float64(30),
|
||||
"thoughtsTokenCount": float64(50),
|
||||
},
|
||||
},
|
||||
wantInput: 70,
|
||||
wantOutput: 70,
|
||||
wantCacheRead: 30,
|
||||
},
|
||||
{
|
||||
name: "without thoughtsTokenCount (old model)",
|
||||
resp: map[string]any{
|
||||
"usageMetadata": map[string]any{
|
||||
"promptTokenCount": float64(100),
|
||||
"candidatesTokenCount": float64(20),
|
||||
},
|
||||
},
|
||||
wantInput: 100,
|
||||
wantOutput: 20,
|
||||
},
|
||||
{
|
||||
name: "no usageMetadata",
|
||||
resp: map[string]any{},
|
||||
wantNil: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
usage := extractGeminiUsage(tt.resp)
|
||||
if tt.wantNil {
|
||||
require.Nil(t, usage)
|
||||
return
|
||||
}
|
||||
require.NotNil(t, usage)
|
||||
require.Equal(t, tt.wantInput, usage.InputTokens)
|
||||
require.Equal(t, tt.wantOutput, usage.OutputTokens)
|
||||
require.Equal(t, tt.wantCacheRead, usage.CacheReadInputTokens)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,12 +66,15 @@ func (m *mockAccountRepoForGemini) Create(ctx context.Context, account *Account)
|
||||
func (m *mockAccountRepoForGemini) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) Update(ctx context.Context, account *Account) error { return nil }
|
||||
func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error { return nil }
|
||||
func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
|
||||
func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
|
||||
@@ -133,9 +136,6 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx cont
|
||||
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
@@ -226,6 +226,10 @@ func (m *mockGroupRepoForGemini) GetAccountIDsByGroupIDs(ctx context.Context, gr
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGemini) UpdateSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ GroupRepository = (*mockGroupRepoForGemini)(nil)
|
||||
|
||||
// mockGatewayCacheForGemini Gemini 测试用的 cache mock
|
||||
@@ -265,22 +269,6 @@ func (m *mockGatewayCacheForGemini) DeleteSessionAccountID(ctx context.Context,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForGemini) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForGemini) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForGemini) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForGemini) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user