mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-04 23:42:13 +08:00
Compare commits
46 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
76d242e024 | ||
|
|
260c152166 | ||
|
|
9f4c1ef9f9 | ||
|
|
bd7fdb5e6c | ||
|
|
a381910e86 | ||
|
|
d182ef0391 | ||
|
|
7319122e92 | ||
|
|
4809fa4f19 | ||
|
|
ee01f80dc1 | ||
|
|
98671a73f4 | ||
|
|
f33a950103 | ||
|
|
132bf34b69 | ||
|
|
01b08e1e43 | ||
|
|
c6a456c7c7 | ||
|
|
cc2329d4fd | ||
|
|
84d0433cc3 | ||
|
|
a113dd4def | ||
|
|
98f793155f | ||
|
|
a38bd413ab | ||
|
|
9e1535e203 | ||
|
|
037a409919 | ||
|
|
571d1479a4 | ||
|
|
ae1934f7db | ||
|
|
39e05a2dad | ||
|
|
7b46bbb628 | ||
|
|
d2527e36eb | ||
|
|
029994a83b | ||
|
|
37047919ab | ||
|
|
0b45d48e85 | ||
|
|
0c660f8335 | ||
|
|
ce9a247a9d | ||
|
|
b4bd46d067 | ||
|
|
1d8b686446 | ||
|
|
2b192f7dca | ||
|
|
979114db45 | ||
|
|
6d0152c8e2 | ||
|
|
dabed96af4 | ||
|
|
36becd972a | ||
|
|
7498035d24 | ||
|
|
39a0359dd5 | ||
|
|
49a3c43741 | ||
|
|
fa3ea5ee4d | ||
|
|
05af95dade | ||
|
|
ae680d79ed | ||
|
|
97a5c1ac1d | ||
|
|
fecfaae8dc |
4
.github/workflows/backend-ci.yml
vendored
4
.github/workflows/backend-ci.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
cache: true
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.25.6'
|
||||
go version | grep -q 'go1.25.7'
|
||||
- name: Unit tests
|
||||
working-directory: backend
|
||||
run: make test-unit
|
||||
@@ -38,7 +38,7 @@ jobs:
|
||||
cache: true
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.25.6'
|
||||
go version | grep -q 'go1.25.7'
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v9
|
||||
with:
|
||||
|
||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -115,7 +115,7 @@ jobs:
|
||||
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.25.6'
|
||||
go version | grep -q 'go1.25.7'
|
||||
|
||||
# Docker setup for GoReleaser
|
||||
- name: Set up QEMU
|
||||
|
||||
2
.github/workflows/security-scan.yml
vendored
2
.github/workflows/security-scan.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
||||
cache-dependency-path: backend/go.sum
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.25.6'
|
||||
go version | grep -q 'go1.25.7'
|
||||
- name: Run govulncheck
|
||||
working-directory: backend
|
||||
run: |
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
# =============================================================================
|
||||
|
||||
ARG NODE_IMAGE=node:24-alpine
|
||||
ARG GOLANG_IMAGE=golang:1.25.6-alpine
|
||||
ARG GOLANG_IMAGE=golang:1.25.7-alpine
|
||||
ARG ALPINE_IMAGE=alpine:3.20
|
||||
ARG GOPROXY=https://goproxy.cn,direct
|
||||
ARG GOSUMDB=sum.golang.google.cn
|
||||
|
||||
@@ -1,368 +0,0 @@
|
||||
# Linux DO Connect
|
||||
|
||||
OAuth(Open Authorization)是一个开放的网络授权标准,目前最新版本为 OAuth 2.0。我们日常使用的第三方登录(如 Google 账号登录)就采用了该标准。OAuth 允许用户授权第三方应用访问存储在其他服务提供商(如 Google)上的信息,无需在不同平台上重复填写注册信息。用户授权后,平台可以直接访问用户的账户信息进行身份验证,而用户无需向第三方应用提供密码。
|
||||
|
||||
目前系统已实现完整的 OAuth2 授权码(code)方式鉴权,但界面等配套功能还在持续完善中。让我们一起打造一个更完善的共享方案。
|
||||
|
||||
## 基本介绍
|
||||
|
||||
这是一套标准的 OAuth2 鉴权系统,可以让开发者共享论坛的用户基本信息。
|
||||
|
||||
- 可获取字段:
|
||||
|
||||
| 参数 | 说明 |
|
||||
| ----------------- | ------------------------------- |
|
||||
| `id` | 用户唯一标识(不可变) |
|
||||
| `username` | 论坛用户名 |
|
||||
| `name` | 论坛用户昵称(可变) |
|
||||
| `avatar_template` | 用户头像模板URL(支持多种尺寸) |
|
||||
| `active` | 账号活跃状态 |
|
||||
| `trust_level` | 信任等级(0-4) |
|
||||
| `silenced` | 禁言状态 |
|
||||
| `external_ids` | 外部ID关联信息 |
|
||||
| `api_key` | API访问密钥 |
|
||||
|
||||
通过这些信息,公益网站/接口可以实现:
|
||||
|
||||
1. 基于 `id` 的服务频率限制
|
||||
2. 基于 `trust_level` 的服务额度分配
|
||||
3. 基于用户信息的滥用举报机制
|
||||
|
||||
## 相关端点
|
||||
|
||||
- Authorize 端点: `https://connect.linux.do/oauth2/authorize`
|
||||
- Token 端点:`https://connect.linux.do/oauth2/token`
|
||||
- 用户信息 端点:`https://connect.linux.do/api/user`
|
||||
|
||||
## 申请使用
|
||||
|
||||
- 访问 [Connect.Linux.Do](https://connect.linux.do/) 申请接入你的应用。
|
||||
|
||||

|
||||
|
||||
- 点击 **`我的应用接入`** - **`申请新接入`**,填写相关信息。其中 **`回调地址`** 是你的应用接收用户信息的地址。
|
||||
|
||||

|
||||
|
||||
- 申请成功后,你将获得 **`Client Id`** 和 **`Client Secret`**,这是你应用的唯一身份凭证。
|
||||
|
||||

|
||||
|
||||
## 接入 Linux Do
|
||||
|
||||
JavaScript
|
||||
```JavaScript
|
||||
// 安装第三方请求库(或使用原生的 Fetch API),本例中使用 axios
|
||||
// npm install axios
|
||||
|
||||
// 通过 OAuth2 获取 Linux Do 用户信息的参考流程
|
||||
const axios = require('axios');
|
||||
const readline = require('readline');
|
||||
|
||||
// 配置信息(建议通过环境变量配置,避免使用硬编码)
|
||||
const CLIENT_ID = '你的 Client ID';
|
||||
const CLIENT_SECRET = '你的 Client Secret';
|
||||
const REDIRECT_URI = '你的回调地址';
|
||||
const AUTH_URL = 'https://connect.linux.do/oauth2/authorize';
|
||||
const TOKEN_URL = 'https://connect.linux.do/oauth2/token';
|
||||
const USER_INFO_URL = 'https://connect.linux.do/api/user';
|
||||
|
||||
// 第一步:生成授权 URL
|
||||
function getAuthUrl() {
|
||||
const params = new URLSearchParams({
|
||||
client_id: CLIENT_ID,
|
||||
redirect_uri: REDIRECT_URI,
|
||||
response_type: 'code',
|
||||
scope: 'user'
|
||||
});
|
||||
|
||||
return `${AUTH_URL}?${params.toString()}`;
|
||||
}
|
||||
|
||||
// 第二步:获取 code 参数
|
||||
function getCode() {
|
||||
return new Promise((resolve) => {
|
||||
// 本例中使用终端输入来模拟流程,仅供本地测试
|
||||
// 请在实际应用中替换为真实的处理逻辑
|
||||
const rl = readline.createInterface({ input: process.stdin, output: process.stdout });
|
||||
rl.question('从回调 URL 中提取出 code,粘贴到此处并按回车:', (answer) => {
|
||||
rl.close();
|
||||
resolve(answer.trim());
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// 第三步:使用 code 参数获取访问令牌
|
||||
async function getAccessToken(code) {
|
||||
try {
|
||||
const form = new URLSearchParams({
|
||||
client_id: CLIENT_ID,
|
||||
client_secret: CLIENT_SECRET,
|
||||
code: code,
|
||||
redirect_uri: REDIRECT_URI,
|
||||
grant_type: 'authorization_code'
|
||||
}).toString();
|
||||
|
||||
const response = await axios.post(TOKEN_URL, form, {
|
||||
// 提醒:需正确配置请求头,否则无法正常获取访问令牌
|
||||
headers: {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
'Accept': 'application/json'
|
||||
}
|
||||
});
|
||||
|
||||
return response.data;
|
||||
} catch (error) {
|
||||
console.error(`获取访问令牌失败:${error.response ? JSON.stringify(error.response.data) : error.message}`);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
// 第四步:使用访问令牌获取用户信息
|
||||
async function getUserInfo(accessToken) {
|
||||
try {
|
||||
const response = await axios.get(USER_INFO_URL, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${accessToken}`
|
||||
}
|
||||
});
|
||||
|
||||
return response.data;
|
||||
} catch (error) {
|
||||
console.error(`获取用户信息失败:${error.response ? JSON.stringify(error.response.data) : error.message}`);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
// 主流程
|
||||
async function main() {
|
||||
// 1. 生成授权 URL,前端引导用户访问授权页
|
||||
const authUrl = getAuthUrl();
|
||||
console.log(`请访问此 URL 授权:${authUrl}
|
||||
`);
|
||||
|
||||
// 2. 用户授权后,从回调 URL 获取 code 参数
|
||||
const code = await getCode();
|
||||
|
||||
try {
|
||||
// 3. 使用 code 参数获取访问令牌
|
||||
const tokenData = await getAccessToken(code);
|
||||
const accessToken = tokenData.access_token;
|
||||
|
||||
// 4. 使用访问令牌获取用户信息
|
||||
if (accessToken) {
|
||||
const userInfo = await getUserInfo(accessToken);
|
||||
console.log(`
|
||||
获取用户信息成功:${JSON.stringify(userInfo, null, 2)}`);
|
||||
} else {
|
||||
console.log(`
|
||||
获取访问令牌失败:${JSON.stringify(tokenData)}`);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('发生错误:', error);
|
||||
}
|
||||
}
|
||||
```
|
||||
Python
|
||||
```python
|
||||
# 安装第三方请求库,本例中使用 requests
|
||||
# pip install requests
|
||||
|
||||
# 通过 OAuth2 获取 Linux Do 用户信息的参考流程
|
||||
import requests
|
||||
import json
|
||||
|
||||
# 配置信息(建议通过环境变量配置,避免使用硬编码)
|
||||
CLIENT_ID = '你的 Client ID'
|
||||
CLIENT_SECRET = '你的 Client Secret'
|
||||
REDIRECT_URI = '你的回调地址'
|
||||
AUTH_URL = 'https://connect.linux.do/oauth2/authorize'
|
||||
TOKEN_URL = 'https://connect.linux.do/oauth2/token'
|
||||
USER_INFO_URL = 'https://connect.linux.do/api/user'
|
||||
|
||||
# 第一步:生成授权 URL
|
||||
def get_auth_url():
|
||||
params = {
|
||||
'client_id': CLIENT_ID,
|
||||
'redirect_uri': REDIRECT_URI,
|
||||
'response_type': 'code',
|
||||
'scope': 'user'
|
||||
}
|
||||
auth_url = f"{AUTH_URL}?{'&'.join(f'{k}={v}' for k, v in params.items())}"
|
||||
return auth_url
|
||||
|
||||
# 第二步:获取 code 参数
|
||||
def get_code():
|
||||
# 本例中使用终端输入来模拟流程,仅供本地测试
|
||||
# 请在实际应用中替换为真实的处理逻辑
|
||||
return input('从回调 URL 中提取出 code,粘贴到此处并按回车:').strip()
|
||||
|
||||
# 第三步:使用 code 参数获取访问令牌
|
||||
def get_access_token(code):
|
||||
try:
|
||||
data = {
|
||||
'client_id': CLIENT_ID,
|
||||
'client_secret': CLIENT_SECRET,
|
||||
'code': code,
|
||||
'redirect_uri': REDIRECT_URI,
|
||||
'grant_type': 'authorization_code'
|
||||
}
|
||||
# 提醒:需正确配置请求头,否则无法正常获取访问令牌
|
||||
headers = {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
'Accept': 'application/json'
|
||||
}
|
||||
response = requests.post(TOKEN_URL, data=data, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except requests.exceptions.RequestException as e:
|
||||
print(f"获取访问令牌失败:{e}")
|
||||
return None
|
||||
|
||||
# 第四步:使用访问令牌获取用户信息
|
||||
def get_user_info(access_token):
|
||||
try:
|
||||
headers = {
|
||||
'Authorization': f'Bearer {access_token}'
|
||||
}
|
||||
response = requests.get(USER_INFO_URL, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except requests.exceptions.RequestException as e:
|
||||
print(f"获取用户信息失败:{e}")
|
||||
return None
|
||||
|
||||
# 主流程
|
||||
if __name__ == '__main__':
|
||||
# 1. 生成授权 URL,前端引导用户访问授权页
|
||||
auth_url = get_auth_url()
|
||||
print(f'请访问此 URL 授权:{auth_url}
|
||||
')
|
||||
|
||||
# 2. 用户授权后,从回调 URL 获取 code 参数
|
||||
code = get_code()
|
||||
|
||||
# 3. 使用 code 参数获取访问令牌
|
||||
token_data = get_access_token(code)
|
||||
if token_data:
|
||||
access_token = token_data.get('access_token')
|
||||
|
||||
# 4. 使用访问令牌获取用户信息
|
||||
if access_token:
|
||||
user_info = get_user_info(access_token)
|
||||
if user_info:
|
||||
print(f"
|
||||
获取用户信息成功:{json.dumps(user_info, indent=2)}")
|
||||
else:
|
||||
print("
|
||||
获取用户信息失败")
|
||||
else:
|
||||
print(f"
|
||||
获取访问令牌失败:{json.dumps(token_data, indent=2)}")
|
||||
else:
|
||||
print("
|
||||
获取访问令牌失败")
|
||||
```
|
||||
PHP
|
||||
```php
|
||||
// 通过 OAuth2 获取 Linux Do 用户信息的参考流程
|
||||
|
||||
// 配置信息
|
||||
$CLIENT_ID = '你的 Client ID';
|
||||
$CLIENT_SECRET = '你的 Client Secret';
|
||||
$REDIRECT_URI = '你的回调地址';
|
||||
$AUTH_URL = 'https://connect.linux.do/oauth2/authorize';
|
||||
$TOKEN_URL = 'https://connect.linux.do/oauth2/token';
|
||||
$USER_INFO_URL = 'https://connect.linux.do/api/user';
|
||||
|
||||
// 生成授权 URL
|
||||
function getAuthUrl($clientId, $redirectUri) {
|
||||
global $AUTH_URL;
|
||||
return $AUTH_URL . '?' . http_build_query([
|
||||
'client_id' => $clientId,
|
||||
'redirect_uri' => $redirectUri,
|
||||
'response_type' => 'code',
|
||||
'scope' => 'user'
|
||||
]);
|
||||
}
|
||||
|
||||
// 使用 code 参数获取用户信息(合并获取令牌和获取用户信息的步骤)
|
||||
function getUserInfoWithCode($code, $clientId, $clientSecret, $redirectUri) {
|
||||
global $TOKEN_URL, $USER_INFO_URL;
|
||||
|
||||
// 1. 获取访问令牌
|
||||
$ch = curl_init($TOKEN_URL);
|
||||
curl_setopt($ch, CURLOPT_RETURNTRANSFER, true);
|
||||
curl_setopt($ch, CURLOPT_POST, true);
|
||||
curl_setopt($ch, CURLOPT_POSTFIELDS, http_build_query([
|
||||
'client_id' => $clientId,
|
||||
'client_secret' => $clientSecret,
|
||||
'code' => $code,
|
||||
'redirect_uri' => $redirectUri,
|
||||
'grant_type' => 'authorization_code'
|
||||
]));
|
||||
curl_setopt($ch, CURLOPT_HTTPHEADER, [
|
||||
'Content-Type: application/x-www-form-urlencoded',
|
||||
'Accept: application/json'
|
||||
]);
|
||||
|
||||
$tokenResponse = curl_exec($ch);
|
||||
curl_close($ch);
|
||||
|
||||
$tokenData = json_decode($tokenResponse, true);
|
||||
if (!isset($tokenData['access_token'])) {
|
||||
return ['error' => '获取访问令牌失败', 'details' => $tokenData];
|
||||
}
|
||||
|
||||
// 2. 获取用户信息
|
||||
$ch = curl_init($USER_INFO_URL);
|
||||
curl_setopt($ch, CURLOPT_RETURNTRANSFER, true);
|
||||
curl_setopt($ch, CURLOPT_HTTPHEADER, [
|
||||
'Authorization: Bearer ' . $tokenData['access_token']
|
||||
]);
|
||||
|
||||
$userResponse = curl_exec($ch);
|
||||
curl_close($ch);
|
||||
|
||||
return json_decode($userResponse, true);
|
||||
}
|
||||
|
||||
// 主流程
|
||||
// 1. 生成授权 URL
|
||||
$authUrl = getAuthUrl($CLIENT_ID, $REDIRECT_URI);
|
||||
echo "<a href='$authUrl'>使用 Linux Do 登录</a>";
|
||||
|
||||
// 2. 处理回调并获取用户信息
|
||||
if (isset($_GET['code'])) {
|
||||
$userInfo = getUserInfoWithCode(
|
||||
$_GET['code'],
|
||||
$CLIENT_ID,
|
||||
$CLIENT_SECRET,
|
||||
$REDIRECT_URI
|
||||
);
|
||||
|
||||
if (isset($userInfo['error'])) {
|
||||
echo '错误: ' . $userInfo['error'];
|
||||
} else {
|
||||
echo '欢迎, ' . $userInfo['name'] . '!';
|
||||
// 处理用户登录逻辑...
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 使用说明
|
||||
|
||||
### 授权流程
|
||||
|
||||
1. 用户点击应用中的’使用 Linux Do 登录’按钮
|
||||
2. 系统将用户重定向至 Linux Do 的授权页面
|
||||
3. 用户完成授权后,系统自动重定向回应用并携带授权码
|
||||
4. 应用使用授权码获取访问令牌
|
||||
5. 使用访问令牌获取用户信息
|
||||
|
||||
### 安全建议
|
||||
|
||||
- 切勿在前端代码中暴露 Client Secret
|
||||
- 对所有用户输入数据进行严格验证
|
||||
- 确保使用 HTTPS 协议传输数据
|
||||
- 定期更新并妥善保管 Client Secret
|
||||
@@ -1,164 +0,0 @@
|
||||
## 概述
|
||||
|
||||
全面增强运维监控系统(Ops)的错误日志管理和告警静默功能,优化前端 UI 组件代码质量和用户体验。本次更新重构了核心服务层和数据访问层,提升系统可维护性和运维效率。
|
||||
|
||||
## 主要改动
|
||||
|
||||
### 1. 错误日志查询优化
|
||||
|
||||
**功能特性:**
|
||||
- 新增 GetErrorLogByID 接口,支持按 ID 精确查询错误详情
|
||||
- 优化错误日志过滤逻辑,支持多维度筛选(平台、阶段、来源、所有者等)
|
||||
- 改进查询参数处理,简化代码结构
|
||||
- 增强错误分类和标准化处理
|
||||
- 支持错误解决状态追踪(resolved 字段)
|
||||
|
||||
**技术实现:**
|
||||
- `ops_handler.go` - 新增单条错误日志查询接口
|
||||
- `ops_repo.go` - 优化数据查询和过滤条件构建
|
||||
- `ops_models.go` - 扩展错误日志数据模型
|
||||
- 前端 API 接口同步更新
|
||||
|
||||
### 2. 告警静默功能
|
||||
|
||||
**功能特性:**
|
||||
- 支持按规则、平台、分组、区域等维度静默告警
|
||||
- 可设置静默时长和原因说明
|
||||
- 静默记录可追溯,记录创建人和创建时间
|
||||
- 自动过期机制,避免永久静默
|
||||
|
||||
**技术实现:**
|
||||
- `037_ops_alert_silences.sql` - 新增告警静默表
|
||||
- `ops_alerts.go` - 告警静默逻辑实现
|
||||
- `ops_alerts_handler.go` - 告警静默 API 接口
|
||||
- `OpsAlertEventsCard.vue` - 前端告警静默操作界面
|
||||
|
||||
**数据库结构:**
|
||||
|
||||
| 字段 | 类型 | 说明 |
|
||||
|------|------|------|
|
||||
| rule_id | BIGINT | 告警规则 ID |
|
||||
| platform | VARCHAR(64) | 平台标识 |
|
||||
| group_id | BIGINT | 分组 ID(可选) |
|
||||
| region | VARCHAR(64) | 区域(可选) |
|
||||
| until | TIMESTAMPTZ | 静默截止时间 |
|
||||
| reason | TEXT | 静默原因 |
|
||||
| created_by | BIGINT | 创建人 ID |
|
||||
|
||||
### 3. 错误分类标准化
|
||||
|
||||
**功能特性:**
|
||||
- 统一错误阶段分类(request|auth|routing|upstream|network|internal)
|
||||
- 规范错误归属分类(client|provider|platform)
|
||||
- 标准化错误来源分类(client_request|upstream_http|gateway)
|
||||
- 自动迁移历史数据到新分类体系
|
||||
|
||||
**技术实现:**
|
||||
- `038_ops_errors_resolution_retry_results_and_standardize_classification.sql` - 分类标准化迁移
|
||||
- 自动映射历史遗留分类到新标准
|
||||
- 自动解决已恢复的上游错误(客户端状态码 < 400)
|
||||
|
||||
### 4. Gateway 服务集成
|
||||
|
||||
**功能特性:**
|
||||
- 完善各 Gateway 服务的 Ops 集成
|
||||
- 统一错误日志记录接口
|
||||
- 增强上游错误追踪能力
|
||||
|
||||
**涉及服务:**
|
||||
- `antigravity_gateway_service.go` - Antigravity 网关集成
|
||||
- `gateway_service.go` - 通用网关集成
|
||||
- `gemini_messages_compat_service.go` - Gemini 兼容层集成
|
||||
- `openai_gateway_service.go` - OpenAI 网关集成
|
||||
|
||||
### 5. 前端 UI 优化
|
||||
|
||||
**代码重构:**
|
||||
- 大幅简化错误详情模态框代码(从 828 行优化到 450 行)
|
||||
- 优化错误日志表格组件,提升可读性
|
||||
- 清理未使用的 i18n 翻译,减少冗余
|
||||
- 统一组件代码风格和格式
|
||||
- 优化骨架屏组件,更好匹配实际看板布局
|
||||
|
||||
**布局改进:**
|
||||
- 修复模态框内容溢出和滚动问题
|
||||
- 优化表格布局,使用 flex 布局确保正确显示
|
||||
- 改进看板头部布局和交互
|
||||
- 提升响应式体验
|
||||
- 骨架屏支持全屏模式适配
|
||||
|
||||
**交互优化:**
|
||||
- 优化告警事件卡片功能和展示
|
||||
- 改进错误详情展示逻辑
|
||||
- 增强请求详情模态框
|
||||
- 完善运行时设置卡片
|
||||
- 改进加载动画效果
|
||||
|
||||
### 6. 国际化完善
|
||||
|
||||
**文案补充:**
|
||||
- 补充错误日志相关的英文翻译
|
||||
- 添加告警静默功能的中英文文案
|
||||
- 完善提示文本和错误信息
|
||||
- 统一术语翻译标准
|
||||
|
||||
## 文件变更
|
||||
|
||||
**后端(26 个文件):**
|
||||
- `backend/internal/handler/admin/ops_alerts_handler.go` - 告警接口增强
|
||||
- `backend/internal/handler/admin/ops_handler.go` - 错误日志接口优化
|
||||
- `backend/internal/handler/ops_error_logger.go` - 错误记录器增强
|
||||
- `backend/internal/repository/ops_repo.go` - 数据访问层重构
|
||||
- `backend/internal/repository/ops_repo_alerts.go` - 告警数据访问增强
|
||||
- `backend/internal/service/ops_*.go` - 核心服务层重构(10 个文件)
|
||||
- `backend/internal/service/*_gateway_service.go` - Gateway 集成(4 个文件)
|
||||
- `backend/internal/server/routes/admin.go` - 路由配置更新
|
||||
- `backend/migrations/*.sql` - 数据库迁移(2 个文件)
|
||||
- 测试文件更新(5 个文件)
|
||||
|
||||
**前端(13 个文件):**
|
||||
- `frontend/src/views/admin/ops/OpsDashboard.vue` - 看板主页优化
|
||||
- `frontend/src/views/admin/ops/components/*.vue` - 组件重构(10 个文件)
|
||||
- `frontend/src/api/admin/ops.ts` - API 接口扩展
|
||||
- `frontend/src/i18n/locales/*.ts` - 国际化文本(2 个文件)
|
||||
|
||||
## 代码统计
|
||||
|
||||
- 44 个文件修改
|
||||
- 3733 行新增
|
||||
- 995 行删除
|
||||
- 净增加 2738 行
|
||||
|
||||
## 核心改进
|
||||
|
||||
**可维护性提升:**
|
||||
- 重构核心服务层,职责更清晰
|
||||
- 简化前端组件代码,降低复杂度
|
||||
- 统一代码风格和命名规范
|
||||
- 清理冗余代码和未使用的翻译
|
||||
- 标准化错误分类体系
|
||||
|
||||
**功能完善:**
|
||||
- 告警静默功能,减少告警噪音
|
||||
- 错误日志查询优化,提升运维效率
|
||||
- Gateway 服务集成完善,统一监控能力
|
||||
- 错误解决状态追踪,便于问题管理
|
||||
|
||||
**用户体验优化:**
|
||||
- 修复多个 UI 布局问题
|
||||
- 优化交互流程
|
||||
- 完善国际化支持
|
||||
- 提升响应式体验
|
||||
- 改进加载状态展示
|
||||
|
||||
## 测试验证
|
||||
|
||||
- ✅ 错误日志查询和过滤功能
|
||||
- ✅ 告警静默创建和自动过期
|
||||
- ✅ 错误分类标准化迁移
|
||||
- ✅ Gateway 服务错误日志记录
|
||||
- ✅ 前端组件布局和交互
|
||||
- ✅ 骨架屏全屏模式适配
|
||||
- ✅ 国际化文本完整性
|
||||
- ✅ API 接口功能正确性
|
||||
- ✅ 数据库迁移执行成功
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
<div align="center">
|
||||
|
||||
[](https://golang.org/)
|
||||
[](https://golang.org/)
|
||||
[](https://vuejs.org/)
|
||||
[](https://www.postgresql.org/)
|
||||
[](https://redis.io/)
|
||||
@@ -44,7 +44,7 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
|
||||
|
||||
| Component | Technology |
|
||||
|-----------|------------|
|
||||
| Backend | Go 1.25.5, Gin, Ent |
|
||||
| Backend | Go 1.25.7, Gin, Ent |
|
||||
| Frontend | Vue 3.4+, Vite 5+, TailwindCSS |
|
||||
| Database | PostgreSQL 15+ |
|
||||
| Cache/Queue | Redis 7+ |
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
<div align="center">
|
||||
|
||||
[](https://golang.org/)
|
||||
[](https://golang.org/)
|
||||
[](https://vuejs.org/)
|
||||
[](https://www.postgresql.org/)
|
||||
[](https://redis.io/)
|
||||
@@ -44,7 +44,7 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
|
||||
|
||||
| 组件 | 技术 |
|
||||
|------|------|
|
||||
| 后端 | Go 1.25.5, Gin, Ent |
|
||||
| 后端 | Go 1.25.7, Gin, Ent |
|
||||
| 前端 | Vue 3.4+, Vite 5+, TailwindCSS |
|
||||
| 数据库 | PostgreSQL 15+ |
|
||||
| 缓存/队列 | Redis 7+ |
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM golang:1.25.6-alpine
|
||||
FROM golang:1.25.7-alpine
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ func main() {
|
||||
}()
|
||||
|
||||
userRepo := repository.NewUserRepository(client, sqlDB)
|
||||
authService := service.NewAuthService(userRepo, nil, cfg, nil, nil, nil, nil, nil)
|
||||
authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -1 +1 @@
|
||||
0.1.61
|
||||
0.1.70
|
||||
|
||||
@@ -44,9 +44,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
}
|
||||
userRepository := repository.NewUserRepository(client, db)
|
||||
redeemCodeRepository := repository.NewRedeemCodeRepository(client)
|
||||
redisClient := repository.ProvideRedis(configConfig)
|
||||
refreshTokenCache := repository.NewRefreshTokenCache(redisClient)
|
||||
settingRepository := repository.NewSettingRepository(client)
|
||||
settingService := service.NewSettingService(settingRepository, configConfig)
|
||||
redisClient := repository.ProvideRedis(configConfig)
|
||||
emailCache := repository.NewEmailCache(redisClient)
|
||||
emailService := service.NewEmailService(settingRepository, emailCache)
|
||||
turnstileVerifier := repository.NewTurnstileVerifier()
|
||||
@@ -58,11 +59,12 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig)
|
||||
apiKeyRepository := repository.NewAPIKeyRepository(client)
|
||||
groupRepository := repository.NewGroupRepository(client, db)
|
||||
userGroupRateRepository := repository.NewUserGroupRateRepository(db)
|
||||
apiKeyCache := repository.NewAPIKeyCache(redisClient)
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig)
|
||||
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
|
||||
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||
authService := service.NewAuthService(userRepository, redeemCodeRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
|
||||
authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
|
||||
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator)
|
||||
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
|
||||
redeemCache := repository.NewRedeemCache(redisClient)
|
||||
@@ -99,7 +101,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
proxyRepository := repository.NewProxyRepository(client, db)
|
||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator)
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator)
|
||||
adminUserHandler := admin.NewUserHandler(adminService)
|
||||
groupHandler := admin.NewGroupHandler(adminService)
|
||||
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
||||
@@ -152,7 +154,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
identityService := service.NewIdentityService(identityCache)
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache)
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache)
|
||||
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
@@ -172,9 +174,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
|
||||
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
|
||||
userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler)
|
||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, configConfig)
|
||||
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, configConfig)
|
||||
errorPassthroughRepository := repository.NewErrorPassthroughRepository(client)
|
||||
errorPassthroughCache := repository.NewErrorPassthroughCache(redisClient)
|
||||
errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache)
|
||||
errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler)
|
||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, errorPassthroughService, configConfig)
|
||||
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, errorPassthroughService, configConfig)
|
||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
||||
totpHandler := handler.NewTotpHandler(totpService)
|
||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler)
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/announcement"
|
||||
"github.com/Wei-Shaw/sub2api/ent/announcementread"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
|
||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
|
||||
@@ -52,6 +53,8 @@ type Client struct {
|
||||
Announcement *AnnouncementClient
|
||||
// AnnouncementRead is the client for interacting with the AnnouncementRead builders.
|
||||
AnnouncementRead *AnnouncementReadClient
|
||||
// ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders.
|
||||
ErrorPassthroughRule *ErrorPassthroughRuleClient
|
||||
// Group is the client for interacting with the Group builders.
|
||||
Group *GroupClient
|
||||
// PromoCode is the client for interacting with the PromoCode builders.
|
||||
@@ -94,6 +97,7 @@ func (c *Client) init() {
|
||||
c.AccountGroup = NewAccountGroupClient(c.config)
|
||||
c.Announcement = NewAnnouncementClient(c.config)
|
||||
c.AnnouncementRead = NewAnnouncementReadClient(c.config)
|
||||
c.ErrorPassthroughRule = NewErrorPassthroughRuleClient(c.config)
|
||||
c.Group = NewGroupClient(c.config)
|
||||
c.PromoCode = NewPromoCodeClient(c.config)
|
||||
c.PromoCodeUsage = NewPromoCodeUsageClient(c.config)
|
||||
@@ -204,6 +208,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
|
||||
AccountGroup: NewAccountGroupClient(cfg),
|
||||
Announcement: NewAnnouncementClient(cfg),
|
||||
AnnouncementRead: NewAnnouncementReadClient(cfg),
|
||||
ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
|
||||
Group: NewGroupClient(cfg),
|
||||
PromoCode: NewPromoCodeClient(cfg),
|
||||
PromoCodeUsage: NewPromoCodeUsageClient(cfg),
|
||||
@@ -241,6 +246,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
|
||||
AccountGroup: NewAccountGroupClient(cfg),
|
||||
Announcement: NewAnnouncementClient(cfg),
|
||||
AnnouncementRead: NewAnnouncementReadClient(cfg),
|
||||
ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
|
||||
Group: NewGroupClient(cfg),
|
||||
PromoCode: NewPromoCodeClient(cfg),
|
||||
PromoCodeUsage: NewPromoCodeUsageClient(cfg),
|
||||
@@ -284,9 +290,10 @@ func (c *Client) Close() error {
|
||||
func (c *Client) Use(hooks ...Hook) {
|
||||
for _, n := range []interface{ Use(...Hook) }{
|
||||
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
|
||||
c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.Setting,
|
||||
c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup,
|
||||
c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription,
|
||||
c.ErrorPassthroughRule, c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy,
|
||||
c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User,
|
||||
c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
|
||||
c.UserSubscription,
|
||||
} {
|
||||
n.Use(hooks...)
|
||||
}
|
||||
@@ -297,9 +304,10 @@ func (c *Client) Use(hooks ...Hook) {
|
||||
func (c *Client) Intercept(interceptors ...Interceptor) {
|
||||
for _, n := range []interface{ Intercept(...Interceptor) }{
|
||||
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
|
||||
c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.Setting,
|
||||
c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup,
|
||||
c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription,
|
||||
c.ErrorPassthroughRule, c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy,
|
||||
c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User,
|
||||
c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
|
||||
c.UserSubscription,
|
||||
} {
|
||||
n.Intercept(interceptors...)
|
||||
}
|
||||
@@ -318,6 +326,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
|
||||
return c.Announcement.mutate(ctx, m)
|
||||
case *AnnouncementReadMutation:
|
||||
return c.AnnouncementRead.mutate(ctx, m)
|
||||
case *ErrorPassthroughRuleMutation:
|
||||
return c.ErrorPassthroughRule.mutate(ctx, m)
|
||||
case *GroupMutation:
|
||||
return c.Group.mutate(ctx, m)
|
||||
case *PromoCodeMutation:
|
||||
@@ -1161,6 +1171,139 @@ func (c *AnnouncementReadClient) mutate(ctx context.Context, m *AnnouncementRead
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorPassthroughRuleClient is a client for the ErrorPassthroughRule schema.
|
||||
type ErrorPassthroughRuleClient struct {
|
||||
config
|
||||
}
|
||||
|
||||
// NewErrorPassthroughRuleClient returns a client for the ErrorPassthroughRule from the given config.
|
||||
func NewErrorPassthroughRuleClient(c config) *ErrorPassthroughRuleClient {
|
||||
return &ErrorPassthroughRuleClient{config: c}
|
||||
}
|
||||
|
||||
// Use adds a list of mutation hooks to the hooks stack.
|
||||
// A call to `Use(f, g, h)` equals to `errorpassthroughrule.Hooks(f(g(h())))`.
|
||||
func (c *ErrorPassthroughRuleClient) Use(hooks ...Hook) {
|
||||
c.hooks.ErrorPassthroughRule = append(c.hooks.ErrorPassthroughRule, hooks...)
|
||||
}
|
||||
|
||||
// Intercept adds a list of query interceptors to the interceptors stack.
|
||||
// A call to `Intercept(f, g, h)` equals to `errorpassthroughrule.Intercept(f(g(h())))`.
|
||||
func (c *ErrorPassthroughRuleClient) Intercept(interceptors ...Interceptor) {
|
||||
c.inters.ErrorPassthroughRule = append(c.inters.ErrorPassthroughRule, interceptors...)
|
||||
}
|
||||
|
||||
// Create returns a builder for creating a ErrorPassthroughRule entity.
|
||||
func (c *ErrorPassthroughRuleClient) Create() *ErrorPassthroughRuleCreate {
|
||||
mutation := newErrorPassthroughRuleMutation(c.config, OpCreate)
|
||||
return &ErrorPassthroughRuleCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// CreateBulk returns a builder for creating a bulk of ErrorPassthroughRule entities.
|
||||
func (c *ErrorPassthroughRuleClient) CreateBulk(builders ...*ErrorPassthroughRuleCreate) *ErrorPassthroughRuleCreateBulk {
|
||||
return &ErrorPassthroughRuleCreateBulk{config: c.config, builders: builders}
|
||||
}
|
||||
|
||||
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
|
||||
// a builder and applies setFunc on it.
|
||||
func (c *ErrorPassthroughRuleClient) MapCreateBulk(slice any, setFunc func(*ErrorPassthroughRuleCreate, int)) *ErrorPassthroughRuleCreateBulk {
|
||||
rv := reflect.ValueOf(slice)
|
||||
if rv.Kind() != reflect.Slice {
|
||||
return &ErrorPassthroughRuleCreateBulk{err: fmt.Errorf("calling to ErrorPassthroughRuleClient.MapCreateBulk with wrong type %T, need slice", slice)}
|
||||
}
|
||||
builders := make([]*ErrorPassthroughRuleCreate, rv.Len())
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
builders[i] = c.Create()
|
||||
setFunc(builders[i], i)
|
||||
}
|
||||
return &ErrorPassthroughRuleCreateBulk{config: c.config, builders: builders}
|
||||
}
|
||||
|
||||
// Update returns an update builder for ErrorPassthroughRule.
|
||||
func (c *ErrorPassthroughRuleClient) Update() *ErrorPassthroughRuleUpdate {
|
||||
mutation := newErrorPassthroughRuleMutation(c.config, OpUpdate)
|
||||
return &ErrorPassthroughRuleUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// UpdateOne returns an update builder for the given entity.
|
||||
func (c *ErrorPassthroughRuleClient) UpdateOne(_m *ErrorPassthroughRule) *ErrorPassthroughRuleUpdateOne {
|
||||
mutation := newErrorPassthroughRuleMutation(c.config, OpUpdateOne, withErrorPassthroughRule(_m))
|
||||
return &ErrorPassthroughRuleUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// UpdateOneID returns an update builder for the given id.
|
||||
func (c *ErrorPassthroughRuleClient) UpdateOneID(id int64) *ErrorPassthroughRuleUpdateOne {
|
||||
mutation := newErrorPassthroughRuleMutation(c.config, OpUpdateOne, withErrorPassthroughRuleID(id))
|
||||
return &ErrorPassthroughRuleUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// Delete returns a delete builder for ErrorPassthroughRule.
|
||||
func (c *ErrorPassthroughRuleClient) Delete() *ErrorPassthroughRuleDelete {
|
||||
mutation := newErrorPassthroughRuleMutation(c.config, OpDelete)
|
||||
return &ErrorPassthroughRuleDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// DeleteOne returns a builder for deleting the given entity.
|
||||
func (c *ErrorPassthroughRuleClient) DeleteOne(_m *ErrorPassthroughRule) *ErrorPassthroughRuleDeleteOne {
|
||||
return c.DeleteOneID(_m.ID)
|
||||
}
|
||||
|
||||
// DeleteOneID returns a builder for deleting the given entity by its id.
|
||||
func (c *ErrorPassthroughRuleClient) DeleteOneID(id int64) *ErrorPassthroughRuleDeleteOne {
|
||||
builder := c.Delete().Where(errorpassthroughrule.ID(id))
|
||||
builder.mutation.id = &id
|
||||
builder.mutation.op = OpDeleteOne
|
||||
return &ErrorPassthroughRuleDeleteOne{builder}
|
||||
}
|
||||
|
||||
// Query returns a query builder for ErrorPassthroughRule.
|
||||
func (c *ErrorPassthroughRuleClient) Query() *ErrorPassthroughRuleQuery {
|
||||
return &ErrorPassthroughRuleQuery{
|
||||
config: c.config,
|
||||
ctx: &QueryContext{Type: TypeErrorPassthroughRule},
|
||||
inters: c.Interceptors(),
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns a ErrorPassthroughRule entity by its id.
|
||||
func (c *ErrorPassthroughRuleClient) Get(ctx context.Context, id int64) (*ErrorPassthroughRule, error) {
|
||||
return c.Query().Where(errorpassthroughrule.ID(id)).Only(ctx)
|
||||
}
|
||||
|
||||
// GetX is like Get, but panics if an error occurs.
|
||||
func (c *ErrorPassthroughRuleClient) GetX(ctx context.Context, id int64) *ErrorPassthroughRule {
|
||||
obj, err := c.Get(ctx, id)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return obj
|
||||
}
|
||||
|
||||
// Hooks returns the client hooks.
|
||||
func (c *ErrorPassthroughRuleClient) Hooks() []Hook {
|
||||
return c.hooks.ErrorPassthroughRule
|
||||
}
|
||||
|
||||
// Interceptors returns the client interceptors.
|
||||
func (c *ErrorPassthroughRuleClient) Interceptors() []Interceptor {
|
||||
return c.inters.ErrorPassthroughRule
|
||||
}
|
||||
|
||||
func (c *ErrorPassthroughRuleClient) mutate(ctx context.Context, m *ErrorPassthroughRuleMutation) (Value, error) {
|
||||
switch m.Op() {
|
||||
case OpCreate:
|
||||
return (&ErrorPassthroughRuleCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
|
||||
case OpUpdate:
|
||||
return (&ErrorPassthroughRuleUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
|
||||
case OpUpdateOne:
|
||||
return (&ErrorPassthroughRuleUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
|
||||
case OpDelete, OpDeleteOne:
|
||||
return (&ErrorPassthroughRuleDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
|
||||
default:
|
||||
return nil, fmt.Errorf("ent: unknown ErrorPassthroughRule mutation op: %q", m.Op())
|
||||
}
|
||||
}
|
||||
|
||||
// GroupClient is a client for the Group schema.
|
||||
type GroupClient struct {
|
||||
config
|
||||
@@ -3462,16 +3605,16 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription
|
||||
// hooks and interceptors per client, for fast access.
|
||||
type (
|
||||
hooks struct {
|
||||
APIKey, Account, AccountGroup, Announcement, AnnouncementRead, Group, PromoCode,
|
||||
PromoCodeUsage, Proxy, RedeemCode, Setting, UsageCleanupTask, UsageLog, User,
|
||||
UserAllowedGroup, UserAttributeDefinition, UserAttributeValue,
|
||||
UserSubscription []ent.Hook
|
||||
APIKey, Account, AccountGroup, Announcement, AnnouncementRead,
|
||||
ErrorPassthroughRule, Group, PromoCode, PromoCodeUsage, Proxy, RedeemCode,
|
||||
Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
|
||||
UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook
|
||||
}
|
||||
inters struct {
|
||||
APIKey, Account, AccountGroup, Announcement, AnnouncementRead, Group, PromoCode,
|
||||
PromoCodeUsage, Proxy, RedeemCode, Setting, UsageCleanupTask, UsageLog, User,
|
||||
UserAllowedGroup, UserAttributeDefinition, UserAttributeValue,
|
||||
UserSubscription []ent.Interceptor
|
||||
APIKey, Account, AccountGroup, Announcement, AnnouncementRead,
|
||||
ErrorPassthroughRule, Group, PromoCode, PromoCodeUsage, Proxy, RedeemCode,
|
||||
Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
|
||||
UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/announcement"
|
||||
"github.com/Wei-Shaw/sub2api/ent/announcementread"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
|
||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
|
||||
@@ -95,6 +96,7 @@ func checkColumn(t, c string) error {
|
||||
accountgroup.Table: accountgroup.ValidColumn,
|
||||
announcement.Table: announcement.ValidColumn,
|
||||
announcementread.Table: announcementread.ValidColumn,
|
||||
errorpassthroughrule.Table: errorpassthroughrule.ValidColumn,
|
||||
group.Table: group.ValidColumn,
|
||||
promocode.Table: promocode.ValidColumn,
|
||||
promocodeusage.Table: promocodeusage.ValidColumn,
|
||||
|
||||
269
backend/ent/errorpassthroughrule.go
Normal file
269
backend/ent/errorpassthroughrule.go
Normal file
@@ -0,0 +1,269 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package ent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
|
||||
)
|
||||
|
||||
// ErrorPassthroughRule is the model entity for the ErrorPassthroughRule schema.
|
||||
type ErrorPassthroughRule struct {
|
||||
config `json:"-"`
|
||||
// ID of the ent.
|
||||
ID int64 `json:"id,omitempty"`
|
||||
// CreatedAt holds the value of the "created_at" field.
|
||||
CreatedAt time.Time `json:"created_at,omitempty"`
|
||||
// UpdatedAt holds the value of the "updated_at" field.
|
||||
UpdatedAt time.Time `json:"updated_at,omitempty"`
|
||||
// Name holds the value of the "name" field.
|
||||
Name string `json:"name,omitempty"`
|
||||
// Enabled holds the value of the "enabled" field.
|
||||
Enabled bool `json:"enabled,omitempty"`
|
||||
// Priority holds the value of the "priority" field.
|
||||
Priority int `json:"priority,omitempty"`
|
||||
// ErrorCodes holds the value of the "error_codes" field.
|
||||
ErrorCodes []int `json:"error_codes,omitempty"`
|
||||
// Keywords holds the value of the "keywords" field.
|
||||
Keywords []string `json:"keywords,omitempty"`
|
||||
// MatchMode holds the value of the "match_mode" field.
|
||||
MatchMode string `json:"match_mode,omitempty"`
|
||||
// Platforms holds the value of the "platforms" field.
|
||||
Platforms []string `json:"platforms,omitempty"`
|
||||
// PassthroughCode holds the value of the "passthrough_code" field.
|
||||
PassthroughCode bool `json:"passthrough_code,omitempty"`
|
||||
// ResponseCode holds the value of the "response_code" field.
|
||||
ResponseCode *int `json:"response_code,omitempty"`
|
||||
// PassthroughBody holds the value of the "passthrough_body" field.
|
||||
PassthroughBody bool `json:"passthrough_body,omitempty"`
|
||||
// CustomMessage holds the value of the "custom_message" field.
|
||||
CustomMessage *string `json:"custom_message,omitempty"`
|
||||
// Description holds the value of the "description" field.
|
||||
Description *string `json:"description,omitempty"`
|
||||
selectValues sql.SelectValues
|
||||
}
|
||||
|
||||
// scanValues returns the types for scanning values from sql.Rows.
|
||||
func (*ErrorPassthroughRule) scanValues(columns []string) ([]any, error) {
|
||||
values := make([]any, len(columns))
|
||||
for i := range columns {
|
||||
switch columns[i] {
|
||||
case errorpassthroughrule.FieldErrorCodes, errorpassthroughrule.FieldKeywords, errorpassthroughrule.FieldPlatforms:
|
||||
values[i] = new([]byte)
|
||||
case errorpassthroughrule.FieldEnabled, errorpassthroughrule.FieldPassthroughCode, errorpassthroughrule.FieldPassthroughBody:
|
||||
values[i] = new(sql.NullBool)
|
||||
case errorpassthroughrule.FieldID, errorpassthroughrule.FieldPriority, errorpassthroughrule.FieldResponseCode:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case errorpassthroughrule.FieldName, errorpassthroughrule.FieldMatchMode, errorpassthroughrule.FieldCustomMessage, errorpassthroughrule.FieldDescription:
|
||||
values[i] = new(sql.NullString)
|
||||
case errorpassthroughrule.FieldCreatedAt, errorpassthroughrule.FieldUpdatedAt:
|
||||
values[i] = new(sql.NullTime)
|
||||
default:
|
||||
values[i] = new(sql.UnknownType)
|
||||
}
|
||||
}
|
||||
return values, nil
|
||||
}
|
||||
|
||||
// assignValues assigns the values that were returned from sql.Rows (after scanning)
|
||||
// to the ErrorPassthroughRule fields.
|
||||
func (_m *ErrorPassthroughRule) assignValues(columns []string, values []any) error {
|
||||
if m, n := len(values), len(columns); m < n {
|
||||
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
|
||||
}
|
||||
for i := range columns {
|
||||
switch columns[i] {
|
||||
case errorpassthroughrule.FieldID:
|
||||
value, ok := values[i].(*sql.NullInt64)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field id", value)
|
||||
}
|
||||
_m.ID = int64(value.Int64)
|
||||
case errorpassthroughrule.FieldCreatedAt:
|
||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field created_at", values[i])
|
||||
} else if value.Valid {
|
||||
_m.CreatedAt = value.Time
|
||||
}
|
||||
case errorpassthroughrule.FieldUpdatedAt:
|
||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field updated_at", values[i])
|
||||
} else if value.Valid {
|
||||
_m.UpdatedAt = value.Time
|
||||
}
|
||||
case errorpassthroughrule.FieldName:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field name", values[i])
|
||||
} else if value.Valid {
|
||||
_m.Name = value.String
|
||||
}
|
||||
case errorpassthroughrule.FieldEnabled:
|
||||
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field enabled", values[i])
|
||||
} else if value.Valid {
|
||||
_m.Enabled = value.Bool
|
||||
}
|
||||
case errorpassthroughrule.FieldPriority:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field priority", values[i])
|
||||
} else if value.Valid {
|
||||
_m.Priority = int(value.Int64)
|
||||
}
|
||||
case errorpassthroughrule.FieldErrorCodes:
|
||||
if value, ok := values[i].(*[]byte); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field error_codes", values[i])
|
||||
} else if value != nil && len(*value) > 0 {
|
||||
if err := json.Unmarshal(*value, &_m.ErrorCodes); err != nil {
|
||||
return fmt.Errorf("unmarshal field error_codes: %w", err)
|
||||
}
|
||||
}
|
||||
case errorpassthroughrule.FieldKeywords:
|
||||
if value, ok := values[i].(*[]byte); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field keywords", values[i])
|
||||
} else if value != nil && len(*value) > 0 {
|
||||
if err := json.Unmarshal(*value, &_m.Keywords); err != nil {
|
||||
return fmt.Errorf("unmarshal field keywords: %w", err)
|
||||
}
|
||||
}
|
||||
case errorpassthroughrule.FieldMatchMode:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field match_mode", values[i])
|
||||
} else if value.Valid {
|
||||
_m.MatchMode = value.String
|
||||
}
|
||||
case errorpassthroughrule.FieldPlatforms:
|
||||
if value, ok := values[i].(*[]byte); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field platforms", values[i])
|
||||
} else if value != nil && len(*value) > 0 {
|
||||
if err := json.Unmarshal(*value, &_m.Platforms); err != nil {
|
||||
return fmt.Errorf("unmarshal field platforms: %w", err)
|
||||
}
|
||||
}
|
||||
case errorpassthroughrule.FieldPassthroughCode:
|
||||
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field passthrough_code", values[i])
|
||||
} else if value.Valid {
|
||||
_m.PassthroughCode = value.Bool
|
||||
}
|
||||
case errorpassthroughrule.FieldResponseCode:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field response_code", values[i])
|
||||
} else if value.Valid {
|
||||
_m.ResponseCode = new(int)
|
||||
*_m.ResponseCode = int(value.Int64)
|
||||
}
|
||||
case errorpassthroughrule.FieldPassthroughBody:
|
||||
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field passthrough_body", values[i])
|
||||
} else if value.Valid {
|
||||
_m.PassthroughBody = value.Bool
|
||||
}
|
||||
case errorpassthroughrule.FieldCustomMessage:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field custom_message", values[i])
|
||||
} else if value.Valid {
|
||||
_m.CustomMessage = new(string)
|
||||
*_m.CustomMessage = value.String
|
||||
}
|
||||
case errorpassthroughrule.FieldDescription:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field description", values[i])
|
||||
} else if value.Valid {
|
||||
_m.Description = new(string)
|
||||
*_m.Description = value.String
|
||||
}
|
||||
default:
|
||||
_m.selectValues.Set(columns[i], values[i])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value returns the ent.Value that was dynamically selected and assigned to the ErrorPassthroughRule.
|
||||
// This includes values selected through modifiers, order, etc.
|
||||
func (_m *ErrorPassthroughRule) Value(name string) (ent.Value, error) {
|
||||
return _m.selectValues.Get(name)
|
||||
}
|
||||
|
||||
// Update returns a builder for updating this ErrorPassthroughRule.
|
||||
// Note that you need to call ErrorPassthroughRule.Unwrap() before calling this method if this ErrorPassthroughRule
|
||||
// was returned from a transaction, and the transaction was committed or rolled back.
|
||||
func (_m *ErrorPassthroughRule) Update() *ErrorPassthroughRuleUpdateOne {
|
||||
return NewErrorPassthroughRuleClient(_m.config).UpdateOne(_m)
|
||||
}
|
||||
|
||||
// Unwrap unwraps the ErrorPassthroughRule entity that was returned from a transaction after it was closed,
|
||||
// so that all future queries will be executed through the driver which created the transaction.
|
||||
func (_m *ErrorPassthroughRule) Unwrap() *ErrorPassthroughRule {
|
||||
_tx, ok := _m.config.driver.(*txDriver)
|
||||
if !ok {
|
||||
panic("ent: ErrorPassthroughRule is not a transactional entity")
|
||||
}
|
||||
_m.config.driver = _tx.drv
|
||||
return _m
|
||||
}
|
||||
|
||||
// String implements the fmt.Stringer.
|
||||
func (_m *ErrorPassthroughRule) String() string {
|
||||
var builder strings.Builder
|
||||
builder.WriteString("ErrorPassthroughRule(")
|
||||
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
|
||||
builder.WriteString("created_at=")
|
||||
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("updated_at=")
|
||||
builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("name=")
|
||||
builder.WriteString(_m.Name)
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("enabled=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.Enabled))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("priority=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.Priority))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("error_codes=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.ErrorCodes))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("keywords=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.Keywords))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("match_mode=")
|
||||
builder.WriteString(_m.MatchMode)
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("platforms=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.Platforms))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("passthrough_code=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.PassthroughCode))
|
||||
builder.WriteString(", ")
|
||||
if v := _m.ResponseCode; v != nil {
|
||||
builder.WriteString("response_code=")
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("passthrough_body=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.PassthroughBody))
|
||||
builder.WriteString(", ")
|
||||
if v := _m.CustomMessage; v != nil {
|
||||
builder.WriteString("custom_message=")
|
||||
builder.WriteString(*v)
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.Description; v != nil {
|
||||
builder.WriteString("description=")
|
||||
builder.WriteString(*v)
|
||||
}
|
||||
builder.WriteByte(')')
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// ErrorPassthroughRules is a parsable slice of ErrorPassthroughRule.
|
||||
type ErrorPassthroughRules []*ErrorPassthroughRule
|
||||
161
backend/ent/errorpassthroughrule/errorpassthroughrule.go
Normal file
161
backend/ent/errorpassthroughrule/errorpassthroughrule.go
Normal file
@@ -0,0 +1,161 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package errorpassthroughrule
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
)
|
||||
|
||||
const (
|
||||
// Label holds the string label denoting the errorpassthroughrule type in the database.
|
||||
Label = "error_passthrough_rule"
|
||||
// FieldID holds the string denoting the id field in the database.
|
||||
FieldID = "id"
|
||||
// FieldCreatedAt holds the string denoting the created_at field in the database.
|
||||
FieldCreatedAt = "created_at"
|
||||
// FieldUpdatedAt holds the string denoting the updated_at field in the database.
|
||||
FieldUpdatedAt = "updated_at"
|
||||
// FieldName holds the string denoting the name field in the database.
|
||||
FieldName = "name"
|
||||
// FieldEnabled holds the string denoting the enabled field in the database.
|
||||
FieldEnabled = "enabled"
|
||||
// FieldPriority holds the string denoting the priority field in the database.
|
||||
FieldPriority = "priority"
|
||||
// FieldErrorCodes holds the string denoting the error_codes field in the database.
|
||||
FieldErrorCodes = "error_codes"
|
||||
// FieldKeywords holds the string denoting the keywords field in the database.
|
||||
FieldKeywords = "keywords"
|
||||
// FieldMatchMode holds the string denoting the match_mode field in the database.
|
||||
FieldMatchMode = "match_mode"
|
||||
// FieldPlatforms holds the string denoting the platforms field in the database.
|
||||
FieldPlatforms = "platforms"
|
||||
// FieldPassthroughCode holds the string denoting the passthrough_code field in the database.
|
||||
FieldPassthroughCode = "passthrough_code"
|
||||
// FieldResponseCode holds the string denoting the response_code field in the database.
|
||||
FieldResponseCode = "response_code"
|
||||
// FieldPassthroughBody holds the string denoting the passthrough_body field in the database.
|
||||
FieldPassthroughBody = "passthrough_body"
|
||||
// FieldCustomMessage holds the string denoting the custom_message field in the database.
|
||||
FieldCustomMessage = "custom_message"
|
||||
// FieldDescription holds the string denoting the description field in the database.
|
||||
FieldDescription = "description"
|
||||
// Table holds the table name of the errorpassthroughrule in the database.
|
||||
Table = "error_passthrough_rules"
|
||||
)
|
||||
|
||||
// Columns holds all SQL columns for errorpassthroughrule fields.
|
||||
var Columns = []string{
|
||||
FieldID,
|
||||
FieldCreatedAt,
|
||||
FieldUpdatedAt,
|
||||
FieldName,
|
||||
FieldEnabled,
|
||||
FieldPriority,
|
||||
FieldErrorCodes,
|
||||
FieldKeywords,
|
||||
FieldMatchMode,
|
||||
FieldPlatforms,
|
||||
FieldPassthroughCode,
|
||||
FieldResponseCode,
|
||||
FieldPassthroughBody,
|
||||
FieldCustomMessage,
|
||||
FieldDescription,
|
||||
}
|
||||
|
||||
// ValidColumn reports if the column name is valid (part of the table columns).
|
||||
func ValidColumn(column string) bool {
|
||||
for i := range Columns {
|
||||
if column == Columns[i] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var (
|
||||
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
|
||||
DefaultCreatedAt func() time.Time
|
||||
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
|
||||
DefaultUpdatedAt func() time.Time
|
||||
// UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
|
||||
UpdateDefaultUpdatedAt func() time.Time
|
||||
// NameValidator is a validator for the "name" field. It is called by the builders before save.
|
||||
NameValidator func(string) error
|
||||
// DefaultEnabled holds the default value on creation for the "enabled" field.
|
||||
DefaultEnabled bool
|
||||
// DefaultPriority holds the default value on creation for the "priority" field.
|
||||
DefaultPriority int
|
||||
// DefaultMatchMode holds the default value on creation for the "match_mode" field.
|
||||
DefaultMatchMode string
|
||||
// MatchModeValidator is a validator for the "match_mode" field. It is called by the builders before save.
|
||||
MatchModeValidator func(string) error
|
||||
// DefaultPassthroughCode holds the default value on creation for the "passthrough_code" field.
|
||||
DefaultPassthroughCode bool
|
||||
// DefaultPassthroughBody holds the default value on creation for the "passthrough_body" field.
|
||||
DefaultPassthroughBody bool
|
||||
)
|
||||
|
||||
// OrderOption defines the ordering options for the ErrorPassthroughRule queries.
|
||||
type OrderOption func(*sql.Selector)
|
||||
|
||||
// ByID orders the results by the id field.
|
||||
func ByID(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldID, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByCreatedAt orders the results by the created_at field.
|
||||
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByUpdatedAt orders the results by the updated_at field.
|
||||
func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByName orders the results by the name field.
|
||||
func ByName(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldName, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByEnabled orders the results by the enabled field.
|
||||
func ByEnabled(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldEnabled, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByPriority orders the results by the priority field.
|
||||
func ByPriority(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldPriority, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByMatchMode orders the results by the match_mode field.
|
||||
func ByMatchMode(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldMatchMode, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByPassthroughCode orders the results by the passthrough_code field.
|
||||
func ByPassthroughCode(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldPassthroughCode, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByResponseCode orders the results by the response_code field.
|
||||
func ByResponseCode(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldResponseCode, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByPassthroughBody orders the results by the passthrough_body field.
|
||||
func ByPassthroughBody(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldPassthroughBody, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByCustomMessage orders the results by the custom_message field.
|
||||
func ByCustomMessage(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldCustomMessage, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByDescription orders the results by the description field.
|
||||
func ByDescription(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldDescription, opts...).ToFunc()
|
||||
}
|
||||
635
backend/ent/errorpassthroughrule/where.go
Normal file
635
backend/ent/errorpassthroughrule/where.go
Normal file
@@ -0,0 +1,635 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package errorpassthroughrule
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
)
|
||||
|
||||
// ID filters vertices based on their ID field.
|
||||
func ID(id int64) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldID, id))
|
||||
}
|
||||
|
||||
// IDEQ applies the EQ predicate on the ID field.
|
||||
func IDEQ(id int64) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldID, id))
|
||||
}
|
||||
|
||||
// IDNEQ applies the NEQ predicate on the ID field.
|
||||
func IDNEQ(id int64) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldID, id))
|
||||
}
|
||||
|
||||
// IDIn applies the In predicate on the ID field.
|
||||
func IDIn(ids ...int64) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldIn(FieldID, ids...))
|
||||
}
|
||||
|
||||
// IDNotIn applies the NotIn predicate on the ID field.
|
||||
func IDNotIn(ids ...int64) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldID, ids...))
|
||||
}
|
||||
|
||||
// IDGT applies the GT predicate on the ID field.
|
||||
func IDGT(id int64) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldGT(FieldID, id))
|
||||
}
|
||||
|
||||
// IDGTE applies the GTE predicate on the ID field.
|
||||
func IDGTE(id int64) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldID, id))
|
||||
}
|
||||
|
||||
// IDLT applies the LT predicate on the ID field.
|
||||
func IDLT(id int64) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldLT(FieldID, id))
|
||||
}
|
||||
|
||||
// IDLTE applies the LTE predicate on the ID field.
|
||||
func IDLTE(id int64) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldID, id))
|
||||
}
|
||||
|
||||
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
|
||||
func CreatedAt(v time.Time) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
|
||||
func UpdatedAt(v time.Time) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// Name applies equality check predicate on the "name" field. It's identical to NameEQ.
|
||||
func Name(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldName, v))
|
||||
}
|
||||
|
||||
// Enabled applies equality check predicate on the "enabled" field. It's identical to EnabledEQ.
|
||||
func Enabled(v bool) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldEnabled, v))
|
||||
}
|
||||
|
||||
// Priority applies equality check predicate on the "priority" field. It's identical to PriorityEQ.
|
||||
func Priority(v int) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldPriority, v))
|
||||
}
|
||||
|
||||
// MatchMode applies equality check predicate on the "match_mode" field. It's identical to MatchModeEQ.
|
||||
func MatchMode(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldMatchMode, v))
|
||||
}
|
||||
|
||||
// PassthroughCode applies equality check predicate on the "passthrough_code" field. It's identical to PassthroughCodeEQ.
|
||||
func PassthroughCode(v bool) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldPassthroughCode, v))
|
||||
}
|
||||
|
||||
// ResponseCode applies equality check predicate on the "response_code" field. It's identical to ResponseCodeEQ.
|
||||
func ResponseCode(v int) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldResponseCode, v))
|
||||
}
|
||||
|
||||
// PassthroughBody applies equality check predicate on the "passthrough_body" field. It's identical to PassthroughBodyEQ.
|
||||
func PassthroughBody(v bool) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldPassthroughBody, v))
|
||||
}
|
||||
|
||||
// CustomMessage applies equality check predicate on the "custom_message" field. It's identical to CustomMessageEQ.
|
||||
func CustomMessage(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldCustomMessage, v))
|
||||
}
|
||||
|
||||
// Description applies equality check predicate on the "description" field. It's identical to DescriptionEQ.
|
||||
func Description(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldDescription, v))
|
||||
}
|
||||
|
||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||
func CreatedAtEQ(v time.Time) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
|
||||
func CreatedAtNEQ(v time.Time) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// CreatedAtIn applies the In predicate on the "created_at" field.
|
||||
func CreatedAtIn(vs ...time.Time) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldIn(FieldCreatedAt, vs...))
|
||||
}
|
||||
|
||||
// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
|
||||
func CreatedAtNotIn(vs ...time.Time) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldCreatedAt, vs...))
|
||||
}
|
||||
|
||||
// CreatedAtGT applies the GT predicate on the "created_at" field.
|
||||
func CreatedAtGT(v time.Time) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldGT(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// CreatedAtGTE applies the GTE predicate on the "created_at" field.
|
||||
func CreatedAtGTE(v time.Time) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// CreatedAtLT applies the LT predicate on the "created_at" field.
|
||||
func CreatedAtLT(v time.Time) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldLT(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// CreatedAtLTE applies the LTE predicate on the "created_at" field.
|
||||
func CreatedAtLTE(v time.Time) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
|
||||
func UpdatedAtEQ(v time.Time) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
|
||||
func UpdatedAtNEQ(v time.Time) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAtIn applies the In predicate on the "updated_at" field.
|
||||
func UpdatedAtIn(vs ...time.Time) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldIn(FieldUpdatedAt, vs...))
|
||||
}
|
||||
|
||||
// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
|
||||
func UpdatedAtNotIn(vs ...time.Time) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldUpdatedAt, vs...))
|
||||
}
|
||||
|
||||
// UpdatedAtGT applies the GT predicate on the "updated_at" field.
|
||||
func UpdatedAtGT(v time.Time) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldGT(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
|
||||
func UpdatedAtGTE(v time.Time) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAtLT applies the LT predicate on the "updated_at" field.
|
||||
func UpdatedAtLT(v time.Time) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldLT(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
|
||||
func UpdatedAtLTE(v time.Time) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// NameEQ applies the EQ predicate on the "name" field.
|
||||
func NameEQ(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldName, v))
|
||||
}
|
||||
|
||||
// NameNEQ applies the NEQ predicate on the "name" field.
|
||||
func NameNEQ(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldName, v))
|
||||
}
|
||||
|
||||
// NameIn applies the In predicate on the "name" field.
|
||||
func NameIn(vs ...string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldIn(FieldName, vs...))
|
||||
}
|
||||
|
||||
// NameNotIn applies the NotIn predicate on the "name" field.
|
||||
func NameNotIn(vs ...string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldName, vs...))
|
||||
}
|
||||
|
||||
// NameGT applies the GT predicate on the "name" field.
|
||||
func NameGT(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldGT(FieldName, v))
|
||||
}
|
||||
|
||||
// NameGTE applies the GTE predicate on the "name" field.
|
||||
func NameGTE(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldName, v))
|
||||
}
|
||||
|
||||
// NameLT applies the LT predicate on the "name" field.
|
||||
func NameLT(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldLT(FieldName, v))
|
||||
}
|
||||
|
||||
// NameLTE applies the LTE predicate on the "name" field.
|
||||
func NameLTE(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldName, v))
|
||||
}
|
||||
|
||||
// NameContains applies the Contains predicate on the "name" field.
|
||||
func NameContains(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldContains(FieldName, v))
|
||||
}
|
||||
|
||||
// NameHasPrefix applies the HasPrefix predicate on the "name" field.
|
||||
func NameHasPrefix(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldHasPrefix(FieldName, v))
|
||||
}
|
||||
|
||||
// NameHasSuffix applies the HasSuffix predicate on the "name" field.
|
||||
func NameHasSuffix(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldHasSuffix(FieldName, v))
|
||||
}
|
||||
|
||||
// NameEqualFold applies the EqualFold predicate on the "name" field.
|
||||
func NameEqualFold(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldEqualFold(FieldName, v))
|
||||
}
|
||||
|
||||
// NameContainsFold applies the ContainsFold predicate on the "name" field.
|
||||
func NameContainsFold(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldContainsFold(FieldName, v))
|
||||
}
|
||||
|
||||
// EnabledEQ applies the EQ predicate on the "enabled" field.
|
||||
func EnabledEQ(v bool) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldEnabled, v))
|
||||
}
|
||||
|
||||
// EnabledNEQ applies the NEQ predicate on the "enabled" field.
|
||||
func EnabledNEQ(v bool) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldEnabled, v))
|
||||
}
|
||||
|
||||
// PriorityEQ applies the EQ predicate on the "priority" field.
|
||||
func PriorityEQ(v int) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldPriority, v))
|
||||
}
|
||||
|
||||
// PriorityNEQ applies the NEQ predicate on the "priority" field.
|
||||
func PriorityNEQ(v int) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldPriority, v))
|
||||
}
|
||||
|
||||
// PriorityIn applies the In predicate on the "priority" field.
|
||||
func PriorityIn(vs ...int) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldIn(FieldPriority, vs...))
|
||||
}
|
||||
|
||||
// PriorityNotIn applies the NotIn predicate on the "priority" field.
|
||||
func PriorityNotIn(vs ...int) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldPriority, vs...))
|
||||
}
|
||||
|
||||
// PriorityGT applies the GT predicate on the "priority" field.
|
||||
func PriorityGT(v int) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldGT(FieldPriority, v))
|
||||
}
|
||||
|
||||
// PriorityGTE applies the GTE predicate on the "priority" field.
|
||||
func PriorityGTE(v int) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldPriority, v))
|
||||
}
|
||||
|
||||
// PriorityLT applies the LT predicate on the "priority" field.
|
||||
func PriorityLT(v int) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldLT(FieldPriority, v))
|
||||
}
|
||||
|
||||
// PriorityLTE applies the LTE predicate on the "priority" field.
|
||||
func PriorityLTE(v int) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldPriority, v))
|
||||
}
|
||||
|
||||
// ErrorCodesIsNil applies the IsNil predicate on the "error_codes" field.
|
||||
func ErrorCodesIsNil() predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldIsNull(FieldErrorCodes))
|
||||
}
|
||||
|
||||
// ErrorCodesNotNil applies the NotNil predicate on the "error_codes" field.
|
||||
func ErrorCodesNotNil() predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldNotNull(FieldErrorCodes))
|
||||
}
|
||||
|
||||
// KeywordsIsNil applies the IsNil predicate on the "keywords" field.
|
||||
func KeywordsIsNil() predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldIsNull(FieldKeywords))
|
||||
}
|
||||
|
||||
// KeywordsNotNil applies the NotNil predicate on the "keywords" field.
|
||||
func KeywordsNotNil() predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldNotNull(FieldKeywords))
|
||||
}
|
||||
|
||||
// MatchModeEQ applies the EQ predicate on the "match_mode" field.
|
||||
func MatchModeEQ(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldMatchMode, v))
|
||||
}
|
||||
|
||||
// MatchModeNEQ applies the NEQ predicate on the "match_mode" field.
|
||||
func MatchModeNEQ(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldMatchMode, v))
|
||||
}
|
||||
|
||||
// MatchModeIn applies the In predicate on the "match_mode" field.
|
||||
func MatchModeIn(vs ...string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldIn(FieldMatchMode, vs...))
|
||||
}
|
||||
|
||||
// MatchModeNotIn applies the NotIn predicate on the "match_mode" field.
|
||||
func MatchModeNotIn(vs ...string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldMatchMode, vs...))
|
||||
}
|
||||
|
||||
// MatchModeGT applies the GT predicate on the "match_mode" field.
|
||||
func MatchModeGT(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldGT(FieldMatchMode, v))
|
||||
}
|
||||
|
||||
// MatchModeGTE applies the GTE predicate on the "match_mode" field.
|
||||
func MatchModeGTE(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldMatchMode, v))
|
||||
}
|
||||
|
||||
// MatchModeLT applies the LT predicate on the "match_mode" field.
|
||||
func MatchModeLT(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldLT(FieldMatchMode, v))
|
||||
}
|
||||
|
||||
// MatchModeLTE applies the LTE predicate on the "match_mode" field.
|
||||
func MatchModeLTE(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldMatchMode, v))
|
||||
}
|
||||
|
||||
// MatchModeContains applies the Contains predicate on the "match_mode" field.
|
||||
func MatchModeContains(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldContains(FieldMatchMode, v))
|
||||
}
|
||||
|
||||
// MatchModeHasPrefix applies the HasPrefix predicate on the "match_mode" field.
|
||||
func MatchModeHasPrefix(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldHasPrefix(FieldMatchMode, v))
|
||||
}
|
||||
|
||||
// MatchModeHasSuffix applies the HasSuffix predicate on the "match_mode" field.
|
||||
func MatchModeHasSuffix(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldHasSuffix(FieldMatchMode, v))
|
||||
}
|
||||
|
||||
// MatchModeEqualFold applies the EqualFold predicate on the "match_mode" field.
|
||||
func MatchModeEqualFold(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldEqualFold(FieldMatchMode, v))
|
||||
}
|
||||
|
||||
// MatchModeContainsFold applies the ContainsFold predicate on the "match_mode" field.
|
||||
func MatchModeContainsFold(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldContainsFold(FieldMatchMode, v))
|
||||
}
|
||||
|
||||
// PlatformsIsNil applies the IsNil predicate on the "platforms" field.
|
||||
func PlatformsIsNil() predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldIsNull(FieldPlatforms))
|
||||
}
|
||||
|
||||
// PlatformsNotNil applies the NotNil predicate on the "platforms" field.
|
||||
func PlatformsNotNil() predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldNotNull(FieldPlatforms))
|
||||
}
|
||||
|
||||
// PassthroughCodeEQ applies the EQ predicate on the "passthrough_code" field.
|
||||
func PassthroughCodeEQ(v bool) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldPassthroughCode, v))
|
||||
}
|
||||
|
||||
// PassthroughCodeNEQ applies the NEQ predicate on the "passthrough_code" field.
|
||||
func PassthroughCodeNEQ(v bool) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldPassthroughCode, v))
|
||||
}
|
||||
|
||||
// ResponseCodeEQ applies the EQ predicate on the "response_code" field.
|
||||
func ResponseCodeEQ(v int) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldResponseCode, v))
|
||||
}
|
||||
|
||||
// ResponseCodeNEQ applies the NEQ predicate on the "response_code" field.
|
||||
func ResponseCodeNEQ(v int) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldResponseCode, v))
|
||||
}
|
||||
|
||||
// ResponseCodeIn applies the In predicate on the "response_code" field.
|
||||
func ResponseCodeIn(vs ...int) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldIn(FieldResponseCode, vs...))
|
||||
}
|
||||
|
||||
// ResponseCodeNotIn applies the NotIn predicate on the "response_code" field.
|
||||
func ResponseCodeNotIn(vs ...int) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldResponseCode, vs...))
|
||||
}
|
||||
|
||||
// ResponseCodeGT applies the GT predicate on the "response_code" field.
|
||||
func ResponseCodeGT(v int) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldGT(FieldResponseCode, v))
|
||||
}
|
||||
|
||||
// ResponseCodeGTE applies the GTE predicate on the "response_code" field.
|
||||
func ResponseCodeGTE(v int) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldResponseCode, v))
|
||||
}
|
||||
|
||||
// ResponseCodeLT applies the LT predicate on the "response_code" field.
|
||||
func ResponseCodeLT(v int) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldLT(FieldResponseCode, v))
|
||||
}
|
||||
|
||||
// ResponseCodeLTE applies the LTE predicate on the "response_code" field.
|
||||
func ResponseCodeLTE(v int) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldResponseCode, v))
|
||||
}
|
||||
|
||||
// ResponseCodeIsNil applies the IsNil predicate on the "response_code" field.
|
||||
func ResponseCodeIsNil() predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldIsNull(FieldResponseCode))
|
||||
}
|
||||
|
||||
// ResponseCodeNotNil applies the NotNil predicate on the "response_code" field.
|
||||
func ResponseCodeNotNil() predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldNotNull(FieldResponseCode))
|
||||
}
|
||||
|
||||
// PassthroughBodyEQ applies the EQ predicate on the "passthrough_body" field.
|
||||
func PassthroughBodyEQ(v bool) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldPassthroughBody, v))
|
||||
}
|
||||
|
||||
// PassthroughBodyNEQ applies the NEQ predicate on the "passthrough_body" field.
|
||||
func PassthroughBodyNEQ(v bool) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldPassthroughBody, v))
|
||||
}
|
||||
|
||||
// CustomMessageEQ applies the EQ predicate on the "custom_message" field.
|
||||
func CustomMessageEQ(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldCustomMessage, v))
|
||||
}
|
||||
|
||||
// CustomMessageNEQ applies the NEQ predicate on the "custom_message" field.
|
||||
func CustomMessageNEQ(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldCustomMessage, v))
|
||||
}
|
||||
|
||||
// CustomMessageIn applies the In predicate on the "custom_message" field.
|
||||
func CustomMessageIn(vs ...string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldIn(FieldCustomMessage, vs...))
|
||||
}
|
||||
|
||||
// CustomMessageNotIn applies the NotIn predicate on the "custom_message" field.
|
||||
func CustomMessageNotIn(vs ...string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldCustomMessage, vs...))
|
||||
}
|
||||
|
||||
// CustomMessageGT applies the GT predicate on the "custom_message" field.
|
||||
func CustomMessageGT(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldGT(FieldCustomMessage, v))
|
||||
}
|
||||
|
||||
// CustomMessageGTE applies the GTE predicate on the "custom_message" field.
|
||||
func CustomMessageGTE(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldCustomMessage, v))
|
||||
}
|
||||
|
||||
// CustomMessageLT applies the LT predicate on the "custom_message" field.
|
||||
func CustomMessageLT(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldLT(FieldCustomMessage, v))
|
||||
}
|
||||
|
||||
// CustomMessageLTE applies the LTE predicate on the "custom_message" field.
|
||||
func CustomMessageLTE(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldCustomMessage, v))
|
||||
}
|
||||
|
||||
// CustomMessageContains applies the Contains predicate on the "custom_message" field.
|
||||
func CustomMessageContains(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldContains(FieldCustomMessage, v))
|
||||
}
|
||||
|
||||
// CustomMessageHasPrefix applies the HasPrefix predicate on the "custom_message" field.
|
||||
func CustomMessageHasPrefix(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldHasPrefix(FieldCustomMessage, v))
|
||||
}
|
||||
|
||||
// CustomMessageHasSuffix applies the HasSuffix predicate on the "custom_message" field.
|
||||
func CustomMessageHasSuffix(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldHasSuffix(FieldCustomMessage, v))
|
||||
}
|
||||
|
||||
// CustomMessageIsNil applies the IsNil predicate on the "custom_message" field.
|
||||
func CustomMessageIsNil() predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldIsNull(FieldCustomMessage))
|
||||
}
|
||||
|
||||
// CustomMessageNotNil applies the NotNil predicate on the "custom_message" field.
|
||||
func CustomMessageNotNil() predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldNotNull(FieldCustomMessage))
|
||||
}
|
||||
|
||||
// CustomMessageEqualFold applies the EqualFold predicate on the "custom_message" field.
|
||||
func CustomMessageEqualFold(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldEqualFold(FieldCustomMessage, v))
|
||||
}
|
||||
|
||||
// CustomMessageContainsFold applies the ContainsFold predicate on the "custom_message" field.
|
||||
func CustomMessageContainsFold(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldContainsFold(FieldCustomMessage, v))
|
||||
}
|
||||
|
||||
// DescriptionEQ applies the EQ predicate on the "description" field.
|
||||
func DescriptionEQ(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldDescription, v))
|
||||
}
|
||||
|
||||
// DescriptionNEQ applies the NEQ predicate on the "description" field.
|
||||
func DescriptionNEQ(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldDescription, v))
|
||||
}
|
||||
|
||||
// DescriptionIn applies the In predicate on the "description" field.
|
||||
func DescriptionIn(vs ...string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldIn(FieldDescription, vs...))
|
||||
}
|
||||
|
||||
// DescriptionNotIn applies the NotIn predicate on the "description" field.
|
||||
func DescriptionNotIn(vs ...string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldDescription, vs...))
|
||||
}
|
||||
|
||||
// DescriptionGT applies the GT predicate on the "description" field.
|
||||
func DescriptionGT(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldGT(FieldDescription, v))
|
||||
}
|
||||
|
||||
// DescriptionGTE applies the GTE predicate on the "description" field.
|
||||
func DescriptionGTE(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldDescription, v))
|
||||
}
|
||||
|
||||
// DescriptionLT applies the LT predicate on the "description" field.
|
||||
func DescriptionLT(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldLT(FieldDescription, v))
|
||||
}
|
||||
|
||||
// DescriptionLTE applies the LTE predicate on the "description" field.
|
||||
func DescriptionLTE(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldDescription, v))
|
||||
}
|
||||
|
||||
// DescriptionContains applies the Contains predicate on the "description" field.
|
||||
func DescriptionContains(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldContains(FieldDescription, v))
|
||||
}
|
||||
|
||||
// DescriptionHasPrefix applies the HasPrefix predicate on the "description" field.
|
||||
func DescriptionHasPrefix(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldHasPrefix(FieldDescription, v))
|
||||
}
|
||||
|
||||
// DescriptionHasSuffix applies the HasSuffix predicate on the "description" field.
|
||||
func DescriptionHasSuffix(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldHasSuffix(FieldDescription, v))
|
||||
}
|
||||
|
||||
// DescriptionIsNil applies the IsNil predicate on the "description" field.
|
||||
func DescriptionIsNil() predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldIsNull(FieldDescription))
|
||||
}
|
||||
|
||||
// DescriptionNotNil applies the NotNil predicate on the "description" field.
|
||||
func DescriptionNotNil() predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldNotNull(FieldDescription))
|
||||
}
|
||||
|
||||
// DescriptionEqualFold applies the EqualFold predicate on the "description" field.
|
||||
func DescriptionEqualFold(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldEqualFold(FieldDescription, v))
|
||||
}
|
||||
|
||||
// DescriptionContainsFold applies the ContainsFold predicate on the "description" field.
|
||||
func DescriptionContainsFold(v string) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.FieldContainsFold(FieldDescription, v))
|
||||
}
|
||||
|
||||
// And groups predicates with the AND operator between them.
|
||||
func And(predicates ...predicate.ErrorPassthroughRule) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.AndPredicates(predicates...))
|
||||
}
|
||||
|
||||
// Or groups predicates with the OR operator between them.
|
||||
func Or(predicates ...predicate.ErrorPassthroughRule) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.OrPredicates(predicates...))
|
||||
}
|
||||
|
||||
// Not applies the not operator on the given predicate.
|
||||
func Not(p predicate.ErrorPassthroughRule) predicate.ErrorPassthroughRule {
|
||||
return predicate.ErrorPassthroughRule(sql.NotPredicates(p))
|
||||
}
|
||||
1382
backend/ent/errorpassthroughrule_create.go
Normal file
1382
backend/ent/errorpassthroughrule_create.go
Normal file
File diff suppressed because it is too large
Load Diff
88
backend/ent/errorpassthroughrule_delete.go
Normal file
88
backend/ent/errorpassthroughrule_delete.go
Normal file
@@ -0,0 +1,88 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package ent
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/schema/field"
|
||||
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
)
|
||||
|
||||
// ErrorPassthroughRuleDelete is the builder for deleting a ErrorPassthroughRule entity.
|
||||
type ErrorPassthroughRuleDelete struct {
|
||||
config
|
||||
hooks []Hook
|
||||
mutation *ErrorPassthroughRuleMutation
|
||||
}
|
||||
|
||||
// Where appends a list predicates to the ErrorPassthroughRuleDelete builder.
|
||||
func (_d *ErrorPassthroughRuleDelete) Where(ps ...predicate.ErrorPassthroughRule) *ErrorPassthroughRuleDelete {
|
||||
_d.mutation.Where(ps...)
|
||||
return _d
|
||||
}
|
||||
|
||||
// Exec executes the deletion query and returns how many vertices were deleted.
|
||||
func (_d *ErrorPassthroughRuleDelete) Exec(ctx context.Context) (int, error) {
|
||||
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
|
||||
}
|
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_d *ErrorPassthroughRuleDelete) ExecX(ctx context.Context) int {
|
||||
n, err := _d.Exec(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func (_d *ErrorPassthroughRuleDelete) sqlExec(ctx context.Context) (int, error) {
|
||||
_spec := sqlgraph.NewDeleteSpec(errorpassthroughrule.Table, sqlgraph.NewFieldSpec(errorpassthroughrule.FieldID, field.TypeInt64))
|
||||
if ps := _d.mutation.predicates; len(ps) > 0 {
|
||||
_spec.Predicate = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
|
||||
if err != nil && sqlgraph.IsConstraintError(err) {
|
||||
err = &ConstraintError{msg: err.Error(), wrap: err}
|
||||
}
|
||||
_d.mutation.done = true
|
||||
return affected, err
|
||||
}
|
||||
|
||||
// ErrorPassthroughRuleDeleteOne is the builder for deleting a single ErrorPassthroughRule entity.
|
||||
type ErrorPassthroughRuleDeleteOne struct {
|
||||
_d *ErrorPassthroughRuleDelete
|
||||
}
|
||||
|
||||
// Where appends a list predicates to the ErrorPassthroughRuleDelete builder.
|
||||
func (_d *ErrorPassthroughRuleDeleteOne) Where(ps ...predicate.ErrorPassthroughRule) *ErrorPassthroughRuleDeleteOne {
|
||||
_d._d.mutation.Where(ps...)
|
||||
return _d
|
||||
}
|
||||
|
||||
// Exec executes the deletion query.
|
||||
func (_d *ErrorPassthroughRuleDeleteOne) Exec(ctx context.Context) error {
|
||||
n, err := _d._d.Exec(ctx)
|
||||
switch {
|
||||
case err != nil:
|
||||
return err
|
||||
case n == 0:
|
||||
return &NotFoundError{errorpassthroughrule.Label}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_d *ErrorPassthroughRuleDeleteOne) ExecX(ctx context.Context) {
|
||||
if err := _d.Exec(ctx); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
564
backend/ent/errorpassthroughrule_query.go
Normal file
564
backend/ent/errorpassthroughrule_query.go
Normal file
@@ -0,0 +1,564 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package ent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/schema/field"
|
||||
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
)
|
||||
|
||||
// ErrorPassthroughRuleQuery is the builder for querying ErrorPassthroughRule entities.
|
||||
type ErrorPassthroughRuleQuery struct {
|
||||
config
|
||||
ctx *QueryContext
|
||||
order []errorpassthroughrule.OrderOption
|
||||
inters []Interceptor
|
||||
predicates []predicate.ErrorPassthroughRule
|
||||
modifiers []func(*sql.Selector)
|
||||
// intermediate query (i.e. traversal path).
|
||||
sql *sql.Selector
|
||||
path func(context.Context) (*sql.Selector, error)
|
||||
}
|
||||
|
||||
// Where adds a new predicate for the ErrorPassthroughRuleQuery builder.
|
||||
func (_q *ErrorPassthroughRuleQuery) Where(ps ...predicate.ErrorPassthroughRule) *ErrorPassthroughRuleQuery {
|
||||
_q.predicates = append(_q.predicates, ps...)
|
||||
return _q
|
||||
}
|
||||
|
||||
// Limit the number of records to be returned by this query.
|
||||
func (_q *ErrorPassthroughRuleQuery) Limit(limit int) *ErrorPassthroughRuleQuery {
|
||||
_q.ctx.Limit = &limit
|
||||
return _q
|
||||
}
|
||||
|
||||
// Offset to start from.
|
||||
func (_q *ErrorPassthroughRuleQuery) Offset(offset int) *ErrorPassthroughRuleQuery {
|
||||
_q.ctx.Offset = &offset
|
||||
return _q
|
||||
}
|
||||
|
||||
// Unique configures the query builder to filter duplicate records on query.
|
||||
// By default, unique is set to true, and can be disabled using this method.
|
||||
func (_q *ErrorPassthroughRuleQuery) Unique(unique bool) *ErrorPassthroughRuleQuery {
|
||||
_q.ctx.Unique = &unique
|
||||
return _q
|
||||
}
|
||||
|
||||
// Order specifies how the records should be ordered.
|
||||
func (_q *ErrorPassthroughRuleQuery) Order(o ...errorpassthroughrule.OrderOption) *ErrorPassthroughRuleQuery {
|
||||
_q.order = append(_q.order, o...)
|
||||
return _q
|
||||
}
|
||||
|
||||
// First returns the first ErrorPassthroughRule entity from the query.
|
||||
// Returns a *NotFoundError when no ErrorPassthroughRule was found.
|
||||
func (_q *ErrorPassthroughRuleQuery) First(ctx context.Context) (*ErrorPassthroughRule, error) {
|
||||
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(nodes) == 0 {
|
||||
return nil, &NotFoundError{errorpassthroughrule.Label}
|
||||
}
|
||||
return nodes[0], nil
|
||||
}
|
||||
|
||||
// FirstX is like First, but panics if an error occurs.
|
||||
func (_q *ErrorPassthroughRuleQuery) FirstX(ctx context.Context) *ErrorPassthroughRule {
|
||||
node, err := _q.First(ctx)
|
||||
if err != nil && !IsNotFound(err) {
|
||||
panic(err)
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
// FirstID returns the first ErrorPassthroughRule ID from the query.
|
||||
// Returns a *NotFoundError when no ErrorPassthroughRule ID was found.
|
||||
func (_q *ErrorPassthroughRuleQuery) FirstID(ctx context.Context) (id int64, err error) {
|
||||
var ids []int64
|
||||
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
|
||||
return
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
err = &NotFoundError{errorpassthroughrule.Label}
|
||||
return
|
||||
}
|
||||
return ids[0], nil
|
||||
}
|
||||
|
||||
// FirstIDX is like FirstID, but panics if an error occurs.
|
||||
func (_q *ErrorPassthroughRuleQuery) FirstIDX(ctx context.Context) int64 {
|
||||
id, err := _q.FirstID(ctx)
|
||||
if err != nil && !IsNotFound(err) {
|
||||
panic(err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// Only returns a single ErrorPassthroughRule entity found by the query, ensuring it only returns one.
|
||||
// Returns a *NotSingularError when more than one ErrorPassthroughRule entity is found.
|
||||
// Returns a *NotFoundError when no ErrorPassthroughRule entities are found.
|
||||
func (_q *ErrorPassthroughRuleQuery) Only(ctx context.Context) (*ErrorPassthroughRule, error) {
|
||||
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch len(nodes) {
|
||||
case 1:
|
||||
return nodes[0], nil
|
||||
case 0:
|
||||
return nil, &NotFoundError{errorpassthroughrule.Label}
|
||||
default:
|
||||
return nil, &NotSingularError{errorpassthroughrule.Label}
|
||||
}
|
||||
}
|
||||
|
||||
// OnlyX is like Only, but panics if an error occurs.
|
||||
func (_q *ErrorPassthroughRuleQuery) OnlyX(ctx context.Context) *ErrorPassthroughRule {
|
||||
node, err := _q.Only(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
// OnlyID is like Only, but returns the only ErrorPassthroughRule ID in the query.
|
||||
// Returns a *NotSingularError when more than one ErrorPassthroughRule ID is found.
|
||||
// Returns a *NotFoundError when no entities are found.
|
||||
func (_q *ErrorPassthroughRuleQuery) OnlyID(ctx context.Context) (id int64, err error) {
|
||||
var ids []int64
|
||||
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
|
||||
return
|
||||
}
|
||||
switch len(ids) {
|
||||
case 1:
|
||||
id = ids[0]
|
||||
case 0:
|
||||
err = &NotFoundError{errorpassthroughrule.Label}
|
||||
default:
|
||||
err = &NotSingularError{errorpassthroughrule.Label}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// OnlyIDX is like OnlyID, but panics if an error occurs.
|
||||
func (_q *ErrorPassthroughRuleQuery) OnlyIDX(ctx context.Context) int64 {
|
||||
id, err := _q.OnlyID(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// All executes the query and returns a list of ErrorPassthroughRules.
|
||||
func (_q *ErrorPassthroughRuleQuery) All(ctx context.Context) ([]*ErrorPassthroughRule, error) {
|
||||
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
|
||||
if err := _q.prepareQuery(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
qr := querierAll[[]*ErrorPassthroughRule, *ErrorPassthroughRuleQuery]()
|
||||
return withInterceptors[[]*ErrorPassthroughRule](ctx, _q, qr, _q.inters)
|
||||
}
|
||||
|
||||
// AllX is like All, but panics if an error occurs.
|
||||
func (_q *ErrorPassthroughRuleQuery) AllX(ctx context.Context) []*ErrorPassthroughRule {
|
||||
nodes, err := _q.All(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return nodes
|
||||
}
|
||||
|
||||
// IDs executes the query and returns a list of ErrorPassthroughRule IDs.
|
||||
func (_q *ErrorPassthroughRuleQuery) IDs(ctx context.Context) (ids []int64, err error) {
|
||||
if _q.ctx.Unique == nil && _q.path != nil {
|
||||
_q.Unique(true)
|
||||
}
|
||||
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
|
||||
if err = _q.Select(errorpassthroughrule.FieldID).Scan(ctx, &ids); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// IDsX is like IDs, but panics if an error occurs.
|
||||
func (_q *ErrorPassthroughRuleQuery) IDsX(ctx context.Context) []int64 {
|
||||
ids, err := _q.IDs(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// Count returns the count of the given query.
|
||||
func (_q *ErrorPassthroughRuleQuery) Count(ctx context.Context) (int, error) {
|
||||
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
|
||||
if err := _q.prepareQuery(ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return withInterceptors[int](ctx, _q, querierCount[*ErrorPassthroughRuleQuery](), _q.inters)
|
||||
}
|
||||
|
||||
// CountX is like Count, but panics if an error occurs.
|
||||
func (_q *ErrorPassthroughRuleQuery) CountX(ctx context.Context) int {
|
||||
count, err := _q.Count(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// Exist returns true if the query has elements in the graph.
|
||||
func (_q *ErrorPassthroughRuleQuery) Exist(ctx context.Context) (bool, error) {
|
||||
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
|
||||
switch _, err := _q.FirstID(ctx); {
|
||||
case IsNotFound(err):
|
||||
return false, nil
|
||||
case err != nil:
|
||||
return false, fmt.Errorf("ent: check existence: %w", err)
|
||||
default:
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
// ExistX is like Exist, but panics if an error occurs.
|
||||
func (_q *ErrorPassthroughRuleQuery) ExistX(ctx context.Context) bool {
|
||||
exist, err := _q.Exist(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return exist
|
||||
}
|
||||
|
||||
// Clone returns a duplicate of the ErrorPassthroughRuleQuery builder, including all associated steps. It can be
|
||||
// used to prepare common query builders and use them differently after the clone is made.
|
||||
func (_q *ErrorPassthroughRuleQuery) Clone() *ErrorPassthroughRuleQuery {
|
||||
if _q == nil {
|
||||
return nil
|
||||
}
|
||||
return &ErrorPassthroughRuleQuery{
|
||||
config: _q.config,
|
||||
ctx: _q.ctx.Clone(),
|
||||
order: append([]errorpassthroughrule.OrderOption{}, _q.order...),
|
||||
inters: append([]Interceptor{}, _q.inters...),
|
||||
predicates: append([]predicate.ErrorPassthroughRule{}, _q.predicates...),
|
||||
// clone intermediate query.
|
||||
sql: _q.sql.Clone(),
|
||||
path: _q.path,
|
||||
}
|
||||
}
|
||||
|
||||
// GroupBy is used to group vertices by one or more fields/columns.
|
||||
// It is often used with aggregate functions, like: count, max, mean, min, sum.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// var v []struct {
|
||||
// CreatedAt time.Time `json:"created_at,omitempty"`
|
||||
// Count int `json:"count,omitempty"`
|
||||
// }
|
||||
//
|
||||
// client.ErrorPassthroughRule.Query().
|
||||
// GroupBy(errorpassthroughrule.FieldCreatedAt).
|
||||
// Aggregate(ent.Count()).
|
||||
// Scan(ctx, &v)
|
||||
func (_q *ErrorPassthroughRuleQuery) GroupBy(field string, fields ...string) *ErrorPassthroughRuleGroupBy {
|
||||
_q.ctx.Fields = append([]string{field}, fields...)
|
||||
grbuild := &ErrorPassthroughRuleGroupBy{build: _q}
|
||||
grbuild.flds = &_q.ctx.Fields
|
||||
grbuild.label = errorpassthroughrule.Label
|
||||
grbuild.scan = grbuild.Scan
|
||||
return grbuild
|
||||
}
|
||||
|
||||
// Select allows the selection one or more fields/columns for the given query,
|
||||
// instead of selecting all fields in the entity.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// var v []struct {
|
||||
// CreatedAt time.Time `json:"created_at,omitempty"`
|
||||
// }
|
||||
//
|
||||
// client.ErrorPassthroughRule.Query().
|
||||
// Select(errorpassthroughrule.FieldCreatedAt).
|
||||
// Scan(ctx, &v)
|
||||
func (_q *ErrorPassthroughRuleQuery) Select(fields ...string) *ErrorPassthroughRuleSelect {
|
||||
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
|
||||
sbuild := &ErrorPassthroughRuleSelect{ErrorPassthroughRuleQuery: _q}
|
||||
sbuild.label = errorpassthroughrule.Label
|
||||
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
|
||||
return sbuild
|
||||
}
|
||||
|
||||
// Aggregate returns a ErrorPassthroughRuleSelect configured with the given aggregations.
|
||||
func (_q *ErrorPassthroughRuleQuery) Aggregate(fns ...AggregateFunc) *ErrorPassthroughRuleSelect {
|
||||
return _q.Select().Aggregate(fns...)
|
||||
}
|
||||
|
||||
func (_q *ErrorPassthroughRuleQuery) prepareQuery(ctx context.Context) error {
|
||||
for _, inter := range _q.inters {
|
||||
if inter == nil {
|
||||
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
|
||||
}
|
||||
if trv, ok := inter.(Traverser); ok {
|
||||
if err := trv.Traverse(ctx, _q); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, f := range _q.ctx.Fields {
|
||||
if !errorpassthroughrule.ValidColumn(f) {
|
||||
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
|
||||
}
|
||||
}
|
||||
if _q.path != nil {
|
||||
prev, err := _q.path(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_q.sql = prev
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (_q *ErrorPassthroughRuleQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*ErrorPassthroughRule, error) {
|
||||
var (
|
||||
nodes = []*ErrorPassthroughRule{}
|
||||
_spec = _q.querySpec()
|
||||
)
|
||||
_spec.ScanValues = func(columns []string) ([]any, error) {
|
||||
return (*ErrorPassthroughRule).scanValues(nil, columns)
|
||||
}
|
||||
_spec.Assign = func(columns []string, values []any) error {
|
||||
node := &ErrorPassthroughRule{config: _q.config}
|
||||
nodes = append(nodes, node)
|
||||
return node.assignValues(columns, values)
|
||||
}
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
for i := range hooks {
|
||||
hooks[i](ctx, _spec)
|
||||
}
|
||||
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(nodes) == 0 {
|
||||
return nodes, nil
|
||||
}
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
func (_q *ErrorPassthroughRuleQuery) sqlCount(ctx context.Context) (int, error) {
|
||||
_spec := _q.querySpec()
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
_spec.Node.Columns = _q.ctx.Fields
|
||||
if len(_q.ctx.Fields) > 0 {
|
||||
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
|
||||
}
|
||||
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
|
||||
}
|
||||
|
||||
func (_q *ErrorPassthroughRuleQuery) querySpec() *sqlgraph.QuerySpec {
|
||||
_spec := sqlgraph.NewQuerySpec(errorpassthroughrule.Table, errorpassthroughrule.Columns, sqlgraph.NewFieldSpec(errorpassthroughrule.FieldID, field.TypeInt64))
|
||||
_spec.From = _q.sql
|
||||
if unique := _q.ctx.Unique; unique != nil {
|
||||
_spec.Unique = *unique
|
||||
} else if _q.path != nil {
|
||||
_spec.Unique = true
|
||||
}
|
||||
if fields := _q.ctx.Fields; len(fields) > 0 {
|
||||
_spec.Node.Columns = make([]string, 0, len(fields))
|
||||
_spec.Node.Columns = append(_spec.Node.Columns, errorpassthroughrule.FieldID)
|
||||
for i := range fields {
|
||||
if fields[i] != errorpassthroughrule.FieldID {
|
||||
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
if ps := _q.predicates; len(ps) > 0 {
|
||||
_spec.Predicate = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
if limit := _q.ctx.Limit; limit != nil {
|
||||
_spec.Limit = *limit
|
||||
}
|
||||
if offset := _q.ctx.Offset; offset != nil {
|
||||
_spec.Offset = *offset
|
||||
}
|
||||
if ps := _q.order; len(ps) > 0 {
|
||||
_spec.Order = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
return _spec
|
||||
}
|
||||
|
||||
func (_q *ErrorPassthroughRuleQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
builder := sql.Dialect(_q.driver.Dialect())
|
||||
t1 := builder.Table(errorpassthroughrule.Table)
|
||||
columns := _q.ctx.Fields
|
||||
if len(columns) == 0 {
|
||||
columns = errorpassthroughrule.Columns
|
||||
}
|
||||
selector := builder.Select(t1.Columns(columns...)...).From(t1)
|
||||
if _q.sql != nil {
|
||||
selector = _q.sql
|
||||
selector.Select(selector.Columns(columns...)...)
|
||||
}
|
||||
if _q.ctx.Unique != nil && *_q.ctx.Unique {
|
||||
selector.Distinct()
|
||||
}
|
||||
for _, m := range _q.modifiers {
|
||||
m(selector)
|
||||
}
|
||||
for _, p := range _q.predicates {
|
||||
p(selector)
|
||||
}
|
||||
for _, p := range _q.order {
|
||||
p(selector)
|
||||
}
|
||||
if offset := _q.ctx.Offset; offset != nil {
|
||||
// limit is mandatory for offset clause. We start
|
||||
// with default value, and override it below if needed.
|
||||
selector.Offset(*offset).Limit(math.MaxInt32)
|
||||
}
|
||||
if limit := _q.ctx.Limit; limit != nil {
|
||||
selector.Limit(*limit)
|
||||
}
|
||||
return selector
|
||||
}
|
||||
|
||||
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
|
||||
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
|
||||
// either committed or rolled-back.
|
||||
func (_q *ErrorPassthroughRuleQuery) ForUpdate(opts ...sql.LockOption) *ErrorPassthroughRuleQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForUpdate(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
|
||||
// on any rows that are read. Other sessions can read the rows, but cannot modify them
|
||||
// until your transaction commits.
|
||||
func (_q *ErrorPassthroughRuleQuery) ForShare(opts ...sql.LockOption) *ErrorPassthroughRuleQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForShare(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// ErrorPassthroughRuleGroupBy is the group-by builder for ErrorPassthroughRule entities.
|
||||
type ErrorPassthroughRuleGroupBy struct {
|
||||
selector
|
||||
build *ErrorPassthroughRuleQuery
|
||||
}
|
||||
|
||||
// Aggregate adds the given aggregation functions to the group-by query.
|
||||
func (_g *ErrorPassthroughRuleGroupBy) Aggregate(fns ...AggregateFunc) *ErrorPassthroughRuleGroupBy {
|
||||
_g.fns = append(_g.fns, fns...)
|
||||
return _g
|
||||
}
|
||||
|
||||
// Scan applies the selector query and scans the result into the given value.
|
||||
func (_g *ErrorPassthroughRuleGroupBy) Scan(ctx context.Context, v any) error {
|
||||
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
|
||||
if err := _g.build.prepareQuery(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return scanWithInterceptors[*ErrorPassthroughRuleQuery, *ErrorPassthroughRuleGroupBy](ctx, _g.build, _g, _g.build.inters, v)
|
||||
}
|
||||
|
||||
func (_g *ErrorPassthroughRuleGroupBy) sqlScan(ctx context.Context, root *ErrorPassthroughRuleQuery, v any) error {
|
||||
selector := root.sqlQuery(ctx).Select()
|
||||
aggregation := make([]string, 0, len(_g.fns))
|
||||
for _, fn := range _g.fns {
|
||||
aggregation = append(aggregation, fn(selector))
|
||||
}
|
||||
if len(selector.SelectedColumns()) == 0 {
|
||||
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
|
||||
for _, f := range *_g.flds {
|
||||
columns = append(columns, selector.C(f))
|
||||
}
|
||||
columns = append(columns, aggregation...)
|
||||
selector.Select(columns...)
|
||||
}
|
||||
selector.GroupBy(selector.Columns(*_g.flds...)...)
|
||||
if err := selector.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
rows := &sql.Rows{}
|
||||
query, args := selector.Query()
|
||||
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
return sql.ScanSlice(rows, v)
|
||||
}
|
||||
|
||||
// ErrorPassthroughRuleSelect is the builder for selecting fields of ErrorPassthroughRule entities.
|
||||
type ErrorPassthroughRuleSelect struct {
|
||||
*ErrorPassthroughRuleQuery
|
||||
selector
|
||||
}
|
||||
|
||||
// Aggregate adds the given aggregation functions to the selector query.
|
||||
func (_s *ErrorPassthroughRuleSelect) Aggregate(fns ...AggregateFunc) *ErrorPassthroughRuleSelect {
|
||||
_s.fns = append(_s.fns, fns...)
|
||||
return _s
|
||||
}
|
||||
|
||||
// Scan applies the selector query and scans the result into the given value.
|
||||
func (_s *ErrorPassthroughRuleSelect) Scan(ctx context.Context, v any) error {
|
||||
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
|
||||
if err := _s.prepareQuery(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return scanWithInterceptors[*ErrorPassthroughRuleQuery, *ErrorPassthroughRuleSelect](ctx, _s.ErrorPassthroughRuleQuery, _s, _s.inters, v)
|
||||
}
|
||||
|
||||
func (_s *ErrorPassthroughRuleSelect) sqlScan(ctx context.Context, root *ErrorPassthroughRuleQuery, v any) error {
|
||||
selector := root.sqlQuery(ctx)
|
||||
aggregation := make([]string, 0, len(_s.fns))
|
||||
for _, fn := range _s.fns {
|
||||
aggregation = append(aggregation, fn(selector))
|
||||
}
|
||||
switch n := len(*_s.selector.flds); {
|
||||
case n == 0 && len(aggregation) > 0:
|
||||
selector.Select(aggregation...)
|
||||
case n != 0 && len(aggregation) > 0:
|
||||
selector.AppendSelect(aggregation...)
|
||||
}
|
||||
rows := &sql.Rows{}
|
||||
query, args := selector.Query()
|
||||
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
return sql.ScanSlice(rows, v)
|
||||
}
|
||||
823
backend/ent/errorpassthroughrule_update.go
Normal file
823
backend/ent/errorpassthroughrule_update.go
Normal file
@@ -0,0 +1,823 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package ent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/dialect/sql/sqljson"
|
||||
"entgo.io/ent/schema/field"
|
||||
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
)
|
||||
|
||||
// ErrorPassthroughRuleUpdate is the builder for updating ErrorPassthroughRule entities.
|
||||
type ErrorPassthroughRuleUpdate struct {
|
||||
config
|
||||
hooks []Hook
|
||||
mutation *ErrorPassthroughRuleMutation
|
||||
}
|
||||
|
||||
// Where appends a list predicates to the ErrorPassthroughRuleUpdate builder.
|
||||
func (_u *ErrorPassthroughRuleUpdate) Where(ps ...predicate.ErrorPassthroughRule) *ErrorPassthroughRuleUpdate {
|
||||
_u.mutation.Where(ps...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUpdatedAt sets the "updated_at" field.
|
||||
func (_u *ErrorPassthroughRuleUpdate) SetUpdatedAt(v time.Time) *ErrorPassthroughRuleUpdate {
|
||||
_u.mutation.SetUpdatedAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetName sets the "name" field.
|
||||
func (_u *ErrorPassthroughRuleUpdate) SetName(v string) *ErrorPassthroughRuleUpdate {
|
||||
_u.mutation.SetName(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableName sets the "name" field if the given value is not nil.
|
||||
func (_u *ErrorPassthroughRuleUpdate) SetNillableName(v *string) *ErrorPassthroughRuleUpdate {
|
||||
if v != nil {
|
||||
_u.SetName(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetEnabled sets the "enabled" field.
|
||||
func (_u *ErrorPassthroughRuleUpdate) SetEnabled(v bool) *ErrorPassthroughRuleUpdate {
|
||||
_u.mutation.SetEnabled(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableEnabled sets the "enabled" field if the given value is not nil.
|
||||
func (_u *ErrorPassthroughRuleUpdate) SetNillableEnabled(v *bool) *ErrorPassthroughRuleUpdate {
|
||||
if v != nil {
|
||||
_u.SetEnabled(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetPriority sets the "priority" field.
|
||||
func (_u *ErrorPassthroughRuleUpdate) SetPriority(v int) *ErrorPassthroughRuleUpdate {
|
||||
_u.mutation.ResetPriority()
|
||||
_u.mutation.SetPriority(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillablePriority sets the "priority" field if the given value is not nil.
|
||||
func (_u *ErrorPassthroughRuleUpdate) SetNillablePriority(v *int) *ErrorPassthroughRuleUpdate {
|
||||
if v != nil {
|
||||
_u.SetPriority(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddPriority adds value to the "priority" field.
|
||||
func (_u *ErrorPassthroughRuleUpdate) AddPriority(v int) *ErrorPassthroughRuleUpdate {
|
||||
_u.mutation.AddPriority(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetErrorCodes sets the "error_codes" field.
|
||||
func (_u *ErrorPassthroughRuleUpdate) SetErrorCodes(v []int) *ErrorPassthroughRuleUpdate {
|
||||
_u.mutation.SetErrorCodes(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AppendErrorCodes appends value to the "error_codes" field.
|
||||
func (_u *ErrorPassthroughRuleUpdate) AppendErrorCodes(v []int) *ErrorPassthroughRuleUpdate {
|
||||
_u.mutation.AppendErrorCodes(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearErrorCodes clears the value of the "error_codes" field.
|
||||
func (_u *ErrorPassthroughRuleUpdate) ClearErrorCodes() *ErrorPassthroughRuleUpdate {
|
||||
_u.mutation.ClearErrorCodes()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetKeywords sets the "keywords" field.
|
||||
func (_u *ErrorPassthroughRuleUpdate) SetKeywords(v []string) *ErrorPassthroughRuleUpdate {
|
||||
_u.mutation.SetKeywords(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AppendKeywords appends value to the "keywords" field.
|
||||
func (_u *ErrorPassthroughRuleUpdate) AppendKeywords(v []string) *ErrorPassthroughRuleUpdate {
|
||||
_u.mutation.AppendKeywords(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearKeywords clears the value of the "keywords" field.
|
||||
func (_u *ErrorPassthroughRuleUpdate) ClearKeywords() *ErrorPassthroughRuleUpdate {
|
||||
_u.mutation.ClearKeywords()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetMatchMode sets the "match_mode" field.
|
||||
func (_u *ErrorPassthroughRuleUpdate) SetMatchMode(v string) *ErrorPassthroughRuleUpdate {
|
||||
_u.mutation.SetMatchMode(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableMatchMode sets the "match_mode" field if the given value is not nil.
|
||||
func (_u *ErrorPassthroughRuleUpdate) SetNillableMatchMode(v *string) *ErrorPassthroughRuleUpdate {
|
||||
if v != nil {
|
||||
_u.SetMatchMode(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetPlatforms sets the "platforms" field.
|
||||
func (_u *ErrorPassthroughRuleUpdate) SetPlatforms(v []string) *ErrorPassthroughRuleUpdate {
|
||||
_u.mutation.SetPlatforms(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AppendPlatforms appends value to the "platforms" field.
|
||||
func (_u *ErrorPassthroughRuleUpdate) AppendPlatforms(v []string) *ErrorPassthroughRuleUpdate {
|
||||
_u.mutation.AppendPlatforms(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearPlatforms clears the value of the "platforms" field.
|
||||
func (_u *ErrorPassthroughRuleUpdate) ClearPlatforms() *ErrorPassthroughRuleUpdate {
|
||||
_u.mutation.ClearPlatforms()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetPassthroughCode sets the "passthrough_code" field.
|
||||
func (_u *ErrorPassthroughRuleUpdate) SetPassthroughCode(v bool) *ErrorPassthroughRuleUpdate {
|
||||
_u.mutation.SetPassthroughCode(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillablePassthroughCode sets the "passthrough_code" field if the given value is not nil.
|
||||
func (_u *ErrorPassthroughRuleUpdate) SetNillablePassthroughCode(v *bool) *ErrorPassthroughRuleUpdate {
|
||||
if v != nil {
|
||||
_u.SetPassthroughCode(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetResponseCode sets the "response_code" field.
|
||||
func (_u *ErrorPassthroughRuleUpdate) SetResponseCode(v int) *ErrorPassthroughRuleUpdate {
|
||||
_u.mutation.ResetResponseCode()
|
||||
_u.mutation.SetResponseCode(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableResponseCode sets the "response_code" field if the given value is not nil.
|
||||
func (_u *ErrorPassthroughRuleUpdate) SetNillableResponseCode(v *int) *ErrorPassthroughRuleUpdate {
|
||||
if v != nil {
|
||||
_u.SetResponseCode(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddResponseCode adds value to the "response_code" field.
|
||||
func (_u *ErrorPassthroughRuleUpdate) AddResponseCode(v int) *ErrorPassthroughRuleUpdate {
|
||||
_u.mutation.AddResponseCode(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearResponseCode clears the value of the "response_code" field.
|
||||
func (_u *ErrorPassthroughRuleUpdate) ClearResponseCode() *ErrorPassthroughRuleUpdate {
|
||||
_u.mutation.ClearResponseCode()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetPassthroughBody sets the "passthrough_body" field.
|
||||
func (_u *ErrorPassthroughRuleUpdate) SetPassthroughBody(v bool) *ErrorPassthroughRuleUpdate {
|
||||
_u.mutation.SetPassthroughBody(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillablePassthroughBody sets the "passthrough_body" field if the given value is not nil.
|
||||
func (_u *ErrorPassthroughRuleUpdate) SetNillablePassthroughBody(v *bool) *ErrorPassthroughRuleUpdate {
|
||||
if v != nil {
|
||||
_u.SetPassthroughBody(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetCustomMessage sets the "custom_message" field.
|
||||
func (_u *ErrorPassthroughRuleUpdate) SetCustomMessage(v string) *ErrorPassthroughRuleUpdate {
|
||||
_u.mutation.SetCustomMessage(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableCustomMessage sets the "custom_message" field if the given value is not nil.
|
||||
func (_u *ErrorPassthroughRuleUpdate) SetNillableCustomMessage(v *string) *ErrorPassthroughRuleUpdate {
|
||||
if v != nil {
|
||||
_u.SetCustomMessage(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearCustomMessage clears the value of the "custom_message" field.
|
||||
func (_u *ErrorPassthroughRuleUpdate) ClearCustomMessage() *ErrorPassthroughRuleUpdate {
|
||||
_u.mutation.ClearCustomMessage()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDescription sets the "description" field.
|
||||
func (_u *ErrorPassthroughRuleUpdate) SetDescription(v string) *ErrorPassthroughRuleUpdate {
|
||||
_u.mutation.SetDescription(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableDescription sets the "description" field if the given value is not nil.
|
||||
func (_u *ErrorPassthroughRuleUpdate) SetNillableDescription(v *string) *ErrorPassthroughRuleUpdate {
|
||||
if v != nil {
|
||||
_u.SetDescription(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearDescription clears the value of the "description" field.
|
||||
func (_u *ErrorPassthroughRuleUpdate) ClearDescription() *ErrorPassthroughRuleUpdate {
|
||||
_u.mutation.ClearDescription()
|
||||
return _u
|
||||
}
|
||||
|
||||
// Mutation returns the ErrorPassthroughRuleMutation object of the builder.
|
||||
func (_u *ErrorPassthroughRuleUpdate) Mutation() *ErrorPassthroughRuleMutation {
|
||||
return _u.mutation
|
||||
}
|
||||
|
||||
// Save executes the query and returns the number of nodes affected by the update operation.
|
||||
func (_u *ErrorPassthroughRuleUpdate) Save(ctx context.Context) (int, error) {
|
||||
_u.defaults()
|
||||
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
|
||||
}
|
||||
|
||||
// SaveX is like Save, but panics if an error occurs.
|
||||
func (_u *ErrorPassthroughRuleUpdate) SaveX(ctx context.Context) int {
|
||||
affected, err := _u.Save(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return affected
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (_u *ErrorPassthroughRuleUpdate) Exec(ctx context.Context) error {
|
||||
_, err := _u.Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_u *ErrorPassthroughRuleUpdate) ExecX(ctx context.Context) {
|
||||
if err := _u.Exec(ctx); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// defaults sets the default values of the builder before save.
|
||||
func (_u *ErrorPassthroughRuleUpdate) defaults() {
|
||||
if _, ok := _u.mutation.UpdatedAt(); !ok {
|
||||
v := errorpassthroughrule.UpdateDefaultUpdatedAt()
|
||||
_u.mutation.SetUpdatedAt(v)
|
||||
}
|
||||
}
|
||||
|
||||
// check runs all checks and user-defined validators on the builder.
|
||||
func (_u *ErrorPassthroughRuleUpdate) check() error {
|
||||
if v, ok := _u.mutation.Name(); ok {
|
||||
if err := errorpassthroughrule.NameValidator(v); err != nil {
|
||||
return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ErrorPassthroughRule.name": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.MatchMode(); ok {
|
||||
if err := errorpassthroughrule.MatchModeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "match_mode", err: fmt.Errorf(`ent: validator failed for field "ErrorPassthroughRule.match_mode": %w`, err)}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (_u *ErrorPassthroughRuleUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if err := _u.check(); err != nil {
|
||||
return _node, err
|
||||
}
|
||||
_spec := sqlgraph.NewUpdateSpec(errorpassthroughrule.Table, errorpassthroughrule.Columns, sqlgraph.NewFieldSpec(errorpassthroughrule.FieldID, field.TypeInt64))
|
||||
if ps := _u.mutation.predicates; len(ps) > 0 {
|
||||
_spec.Predicate = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
if value, ok := _u.mutation.UpdatedAt(); ok {
|
||||
_spec.SetField(errorpassthroughrule.FieldUpdatedAt, field.TypeTime, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Name(); ok {
|
||||
_spec.SetField(errorpassthroughrule.FieldName, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Enabled(); ok {
|
||||
_spec.SetField(errorpassthroughrule.FieldEnabled, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Priority(); ok {
|
||||
_spec.SetField(errorpassthroughrule.FieldPriority, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedPriority(); ok {
|
||||
_spec.AddField(errorpassthroughrule.FieldPriority, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.ErrorCodes(); ok {
|
||||
_spec.SetField(errorpassthroughrule.FieldErrorCodes, field.TypeJSON, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AppendedErrorCodes(); ok {
|
||||
_spec.AddModifier(func(u *sql.UpdateBuilder) {
|
||||
sqljson.Append(u, errorpassthroughrule.FieldErrorCodes, value)
|
||||
})
|
||||
}
|
||||
if _u.mutation.ErrorCodesCleared() {
|
||||
_spec.ClearField(errorpassthroughrule.FieldErrorCodes, field.TypeJSON)
|
||||
}
|
||||
if value, ok := _u.mutation.Keywords(); ok {
|
||||
_spec.SetField(errorpassthroughrule.FieldKeywords, field.TypeJSON, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AppendedKeywords(); ok {
|
||||
_spec.AddModifier(func(u *sql.UpdateBuilder) {
|
||||
sqljson.Append(u, errorpassthroughrule.FieldKeywords, value)
|
||||
})
|
||||
}
|
||||
if _u.mutation.KeywordsCleared() {
|
||||
_spec.ClearField(errorpassthroughrule.FieldKeywords, field.TypeJSON)
|
||||
}
|
||||
if value, ok := _u.mutation.MatchMode(); ok {
|
||||
_spec.SetField(errorpassthroughrule.FieldMatchMode, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Platforms(); ok {
|
||||
_spec.SetField(errorpassthroughrule.FieldPlatforms, field.TypeJSON, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AppendedPlatforms(); ok {
|
||||
_spec.AddModifier(func(u *sql.UpdateBuilder) {
|
||||
sqljson.Append(u, errorpassthroughrule.FieldPlatforms, value)
|
||||
})
|
||||
}
|
||||
if _u.mutation.PlatformsCleared() {
|
||||
_spec.ClearField(errorpassthroughrule.FieldPlatforms, field.TypeJSON)
|
||||
}
|
||||
if value, ok := _u.mutation.PassthroughCode(); ok {
|
||||
_spec.SetField(errorpassthroughrule.FieldPassthroughCode, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.ResponseCode(); ok {
|
||||
_spec.SetField(errorpassthroughrule.FieldResponseCode, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedResponseCode(); ok {
|
||||
_spec.AddField(errorpassthroughrule.FieldResponseCode, field.TypeInt, value)
|
||||
}
|
||||
if _u.mutation.ResponseCodeCleared() {
|
||||
_spec.ClearField(errorpassthroughrule.FieldResponseCode, field.TypeInt)
|
||||
}
|
||||
if value, ok := _u.mutation.PassthroughBody(); ok {
|
||||
_spec.SetField(errorpassthroughrule.FieldPassthroughBody, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.CustomMessage(); ok {
|
||||
_spec.SetField(errorpassthroughrule.FieldCustomMessage, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.CustomMessageCleared() {
|
||||
_spec.ClearField(errorpassthroughrule.FieldCustomMessage, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.Description(); ok {
|
||||
_spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.DescriptionCleared() {
|
||||
_spec.ClearField(errorpassthroughrule.FieldDescription, field.TypeString)
|
||||
}
|
||||
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
|
||||
if _, ok := err.(*sqlgraph.NotFoundError); ok {
|
||||
err = &NotFoundError{errorpassthroughrule.Label}
|
||||
} else if sqlgraph.IsConstraintError(err) {
|
||||
err = &ConstraintError{msg: err.Error(), wrap: err}
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
_u.mutation.done = true
|
||||
return _node, nil
|
||||
}
|
||||
|
||||
// ErrorPassthroughRuleUpdateOne is the builder for updating a single ErrorPassthroughRule entity.
|
||||
type ErrorPassthroughRuleUpdateOne struct {
|
||||
config
|
||||
fields []string
|
||||
hooks []Hook
|
||||
mutation *ErrorPassthroughRuleMutation
|
||||
}
|
||||
|
||||
// SetUpdatedAt sets the "updated_at" field.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) SetUpdatedAt(v time.Time) *ErrorPassthroughRuleUpdateOne {
|
||||
_u.mutation.SetUpdatedAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetName sets the "name" field.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) SetName(v string) *ErrorPassthroughRuleUpdateOne {
|
||||
_u.mutation.SetName(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableName sets the "name" field if the given value is not nil.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) SetNillableName(v *string) *ErrorPassthroughRuleUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetName(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetEnabled sets the "enabled" field.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) SetEnabled(v bool) *ErrorPassthroughRuleUpdateOne {
|
||||
_u.mutation.SetEnabled(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableEnabled sets the "enabled" field if the given value is not nil.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) SetNillableEnabled(v *bool) *ErrorPassthroughRuleUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetEnabled(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetPriority sets the "priority" field.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) SetPriority(v int) *ErrorPassthroughRuleUpdateOne {
|
||||
_u.mutation.ResetPriority()
|
||||
_u.mutation.SetPriority(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillablePriority sets the "priority" field if the given value is not nil.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) SetNillablePriority(v *int) *ErrorPassthroughRuleUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetPriority(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddPriority adds value to the "priority" field.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) AddPriority(v int) *ErrorPassthroughRuleUpdateOne {
|
||||
_u.mutation.AddPriority(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetErrorCodes sets the "error_codes" field.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) SetErrorCodes(v []int) *ErrorPassthroughRuleUpdateOne {
|
||||
_u.mutation.SetErrorCodes(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AppendErrorCodes appends value to the "error_codes" field.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) AppendErrorCodes(v []int) *ErrorPassthroughRuleUpdateOne {
|
||||
_u.mutation.AppendErrorCodes(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearErrorCodes clears the value of the "error_codes" field.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) ClearErrorCodes() *ErrorPassthroughRuleUpdateOne {
|
||||
_u.mutation.ClearErrorCodes()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetKeywords sets the "keywords" field.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) SetKeywords(v []string) *ErrorPassthroughRuleUpdateOne {
|
||||
_u.mutation.SetKeywords(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AppendKeywords appends value to the "keywords" field.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) AppendKeywords(v []string) *ErrorPassthroughRuleUpdateOne {
|
||||
_u.mutation.AppendKeywords(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearKeywords clears the value of the "keywords" field.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) ClearKeywords() *ErrorPassthroughRuleUpdateOne {
|
||||
_u.mutation.ClearKeywords()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetMatchMode sets the "match_mode" field.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) SetMatchMode(v string) *ErrorPassthroughRuleUpdateOne {
|
||||
_u.mutation.SetMatchMode(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableMatchMode sets the "match_mode" field if the given value is not nil.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) SetNillableMatchMode(v *string) *ErrorPassthroughRuleUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetMatchMode(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetPlatforms sets the "platforms" field.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) SetPlatforms(v []string) *ErrorPassthroughRuleUpdateOne {
|
||||
_u.mutation.SetPlatforms(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AppendPlatforms appends value to the "platforms" field.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) AppendPlatforms(v []string) *ErrorPassthroughRuleUpdateOne {
|
||||
_u.mutation.AppendPlatforms(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearPlatforms clears the value of the "platforms" field.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) ClearPlatforms() *ErrorPassthroughRuleUpdateOne {
|
||||
_u.mutation.ClearPlatforms()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetPassthroughCode sets the "passthrough_code" field.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) SetPassthroughCode(v bool) *ErrorPassthroughRuleUpdateOne {
|
||||
_u.mutation.SetPassthroughCode(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillablePassthroughCode sets the "passthrough_code" field if the given value is not nil.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) SetNillablePassthroughCode(v *bool) *ErrorPassthroughRuleUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetPassthroughCode(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetResponseCode sets the "response_code" field.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) SetResponseCode(v int) *ErrorPassthroughRuleUpdateOne {
|
||||
_u.mutation.ResetResponseCode()
|
||||
_u.mutation.SetResponseCode(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableResponseCode sets the "response_code" field if the given value is not nil.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) SetNillableResponseCode(v *int) *ErrorPassthroughRuleUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetResponseCode(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddResponseCode adds value to the "response_code" field.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) AddResponseCode(v int) *ErrorPassthroughRuleUpdateOne {
|
||||
_u.mutation.AddResponseCode(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearResponseCode clears the value of the "response_code" field.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) ClearResponseCode() *ErrorPassthroughRuleUpdateOne {
|
||||
_u.mutation.ClearResponseCode()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetPassthroughBody sets the "passthrough_body" field.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) SetPassthroughBody(v bool) *ErrorPassthroughRuleUpdateOne {
|
||||
_u.mutation.SetPassthroughBody(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillablePassthroughBody sets the "passthrough_body" field if the given value is not nil.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) SetNillablePassthroughBody(v *bool) *ErrorPassthroughRuleUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetPassthroughBody(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetCustomMessage sets the "custom_message" field.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) SetCustomMessage(v string) *ErrorPassthroughRuleUpdateOne {
|
||||
_u.mutation.SetCustomMessage(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableCustomMessage sets the "custom_message" field if the given value is not nil.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) SetNillableCustomMessage(v *string) *ErrorPassthroughRuleUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetCustomMessage(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearCustomMessage clears the value of the "custom_message" field.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) ClearCustomMessage() *ErrorPassthroughRuleUpdateOne {
|
||||
_u.mutation.ClearCustomMessage()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDescription sets the "description" field.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) SetDescription(v string) *ErrorPassthroughRuleUpdateOne {
|
||||
_u.mutation.SetDescription(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableDescription sets the "description" field if the given value is not nil.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) SetNillableDescription(v *string) *ErrorPassthroughRuleUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetDescription(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearDescription clears the value of the "description" field.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) ClearDescription() *ErrorPassthroughRuleUpdateOne {
|
||||
_u.mutation.ClearDescription()
|
||||
return _u
|
||||
}
|
||||
|
||||
// Mutation returns the ErrorPassthroughRuleMutation object of the builder.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) Mutation() *ErrorPassthroughRuleMutation {
|
||||
return _u.mutation
|
||||
}
|
||||
|
||||
// Where appends a list predicates to the ErrorPassthroughRuleUpdate builder.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) Where(ps ...predicate.ErrorPassthroughRule) *ErrorPassthroughRuleUpdateOne {
|
||||
_u.mutation.Where(ps...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// Select allows selecting one or more fields (columns) of the returned entity.
|
||||
// The default is selecting all fields defined in the entity schema.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) Select(field string, fields ...string) *ErrorPassthroughRuleUpdateOne {
|
||||
_u.fields = append([]string{field}, fields...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// Save executes the query and returns the updated ErrorPassthroughRule entity.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) Save(ctx context.Context) (*ErrorPassthroughRule, error) {
|
||||
_u.defaults()
|
||||
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
|
||||
}
|
||||
|
||||
// SaveX is like Save, but panics if an error occurs.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) SaveX(ctx context.Context) *ErrorPassthroughRule {
|
||||
node, err := _u.Save(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
// Exec executes the query on the entity.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) Exec(ctx context.Context) error {
|
||||
_, err := _u.Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) ExecX(ctx context.Context) {
|
||||
if err := _u.Exec(ctx); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// defaults sets the default values of the builder before save.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) defaults() {
|
||||
if _, ok := _u.mutation.UpdatedAt(); !ok {
|
||||
v := errorpassthroughrule.UpdateDefaultUpdatedAt()
|
||||
_u.mutation.SetUpdatedAt(v)
|
||||
}
|
||||
}
|
||||
|
||||
// check runs all checks and user-defined validators on the builder.
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) check() error {
|
||||
if v, ok := _u.mutation.Name(); ok {
|
||||
if err := errorpassthroughrule.NameValidator(v); err != nil {
|
||||
return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ErrorPassthroughRule.name": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.MatchMode(); ok {
|
||||
if err := errorpassthroughrule.MatchModeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "match_mode", err: fmt.Errorf(`ent: validator failed for field "ErrorPassthroughRule.match_mode": %w`, err)}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (_u *ErrorPassthroughRuleUpdateOne) sqlSave(ctx context.Context) (_node *ErrorPassthroughRule, err error) {
|
||||
if err := _u.check(); err != nil {
|
||||
return _node, err
|
||||
}
|
||||
_spec := sqlgraph.NewUpdateSpec(errorpassthroughrule.Table, errorpassthroughrule.Columns, sqlgraph.NewFieldSpec(errorpassthroughrule.FieldID, field.TypeInt64))
|
||||
id, ok := _u.mutation.ID()
|
||||
if !ok {
|
||||
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "ErrorPassthroughRule.id" for update`)}
|
||||
}
|
||||
_spec.Node.ID.Value = id
|
||||
if fields := _u.fields; len(fields) > 0 {
|
||||
_spec.Node.Columns = make([]string, 0, len(fields))
|
||||
_spec.Node.Columns = append(_spec.Node.Columns, errorpassthroughrule.FieldID)
|
||||
for _, f := range fields {
|
||||
if !errorpassthroughrule.ValidColumn(f) {
|
||||
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
|
||||
}
|
||||
if f != errorpassthroughrule.FieldID {
|
||||
_spec.Node.Columns = append(_spec.Node.Columns, f)
|
||||
}
|
||||
}
|
||||
}
|
||||
if ps := _u.mutation.predicates; len(ps) > 0 {
|
||||
_spec.Predicate = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
if value, ok := _u.mutation.UpdatedAt(); ok {
|
||||
_spec.SetField(errorpassthroughrule.FieldUpdatedAt, field.TypeTime, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Name(); ok {
|
||||
_spec.SetField(errorpassthroughrule.FieldName, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Enabled(); ok {
|
||||
_spec.SetField(errorpassthroughrule.FieldEnabled, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Priority(); ok {
|
||||
_spec.SetField(errorpassthroughrule.FieldPriority, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedPriority(); ok {
|
||||
_spec.AddField(errorpassthroughrule.FieldPriority, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.ErrorCodes(); ok {
|
||||
_spec.SetField(errorpassthroughrule.FieldErrorCodes, field.TypeJSON, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AppendedErrorCodes(); ok {
|
||||
_spec.AddModifier(func(u *sql.UpdateBuilder) {
|
||||
sqljson.Append(u, errorpassthroughrule.FieldErrorCodes, value)
|
||||
})
|
||||
}
|
||||
if _u.mutation.ErrorCodesCleared() {
|
||||
_spec.ClearField(errorpassthroughrule.FieldErrorCodes, field.TypeJSON)
|
||||
}
|
||||
if value, ok := _u.mutation.Keywords(); ok {
|
||||
_spec.SetField(errorpassthroughrule.FieldKeywords, field.TypeJSON, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AppendedKeywords(); ok {
|
||||
_spec.AddModifier(func(u *sql.UpdateBuilder) {
|
||||
sqljson.Append(u, errorpassthroughrule.FieldKeywords, value)
|
||||
})
|
||||
}
|
||||
if _u.mutation.KeywordsCleared() {
|
||||
_spec.ClearField(errorpassthroughrule.FieldKeywords, field.TypeJSON)
|
||||
}
|
||||
if value, ok := _u.mutation.MatchMode(); ok {
|
||||
_spec.SetField(errorpassthroughrule.FieldMatchMode, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Platforms(); ok {
|
||||
_spec.SetField(errorpassthroughrule.FieldPlatforms, field.TypeJSON, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AppendedPlatforms(); ok {
|
||||
_spec.AddModifier(func(u *sql.UpdateBuilder) {
|
||||
sqljson.Append(u, errorpassthroughrule.FieldPlatforms, value)
|
||||
})
|
||||
}
|
||||
if _u.mutation.PlatformsCleared() {
|
||||
_spec.ClearField(errorpassthroughrule.FieldPlatforms, field.TypeJSON)
|
||||
}
|
||||
if value, ok := _u.mutation.PassthroughCode(); ok {
|
||||
_spec.SetField(errorpassthroughrule.FieldPassthroughCode, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.ResponseCode(); ok {
|
||||
_spec.SetField(errorpassthroughrule.FieldResponseCode, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedResponseCode(); ok {
|
||||
_spec.AddField(errorpassthroughrule.FieldResponseCode, field.TypeInt, value)
|
||||
}
|
||||
if _u.mutation.ResponseCodeCleared() {
|
||||
_spec.ClearField(errorpassthroughrule.FieldResponseCode, field.TypeInt)
|
||||
}
|
||||
if value, ok := _u.mutation.PassthroughBody(); ok {
|
||||
_spec.SetField(errorpassthroughrule.FieldPassthroughBody, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.CustomMessage(); ok {
|
||||
_spec.SetField(errorpassthroughrule.FieldCustomMessage, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.CustomMessageCleared() {
|
||||
_spec.ClearField(errorpassthroughrule.FieldCustomMessage, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.Description(); ok {
|
||||
_spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.DescriptionCleared() {
|
||||
_spec.ClearField(errorpassthroughrule.FieldDescription, field.TypeString)
|
||||
}
|
||||
_node = &ErrorPassthroughRule{config: _u.config}
|
||||
_spec.Assign = _node.assignValues
|
||||
_spec.ScanValues = _node.scanValues
|
||||
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
|
||||
if _, ok := err.(*sqlgraph.NotFoundError); ok {
|
||||
err = &NotFoundError{errorpassthroughrule.Label}
|
||||
} else if sqlgraph.IsConstraintError(err) {
|
||||
err = &ConstraintError{msg: err.Error(), wrap: err}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
_u.mutation.done = true
|
||||
return _node, nil
|
||||
}
|
||||
@@ -69,6 +69,18 @@ func (f AnnouncementReadFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.V
|
||||
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AnnouncementReadMutation", m)
|
||||
}
|
||||
|
||||
// The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary
|
||||
// function as ErrorPassthroughRule mutator.
|
||||
type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleMutation) (ent.Value, error)
|
||||
|
||||
// Mutate calls f(ctx, m).
|
||||
func (f ErrorPassthroughRuleFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
|
||||
if mv, ok := m.(*ent.ErrorPassthroughRuleMutation); ok {
|
||||
return f(ctx, mv)
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.ErrorPassthroughRuleMutation", m)
|
||||
}
|
||||
|
||||
// The GroupFunc type is an adapter to allow the use of ordinary
|
||||
// function as Group mutator.
|
||||
type GroupFunc func(context.Context, *ent.GroupMutation) (ent.Value, error)
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/announcement"
|
||||
"github.com/Wei-Shaw/sub2api/ent/announcementread"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
|
||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocode"
|
||||
@@ -220,6 +221,33 @@ func (f TraverseAnnouncementRead) Traverse(ctx context.Context, q ent.Query) err
|
||||
return fmt.Errorf("unexpected query type %T. expect *ent.AnnouncementReadQuery", q)
|
||||
}
|
||||
|
||||
// The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary function as a Querier.
|
||||
type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleQuery) (ent.Value, error)
|
||||
|
||||
// Query calls f(ctx, q).
|
||||
func (f ErrorPassthroughRuleFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
|
||||
if q, ok := q.(*ent.ErrorPassthroughRuleQuery); ok {
|
||||
return f(ctx, q)
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected query type %T. expect *ent.ErrorPassthroughRuleQuery", q)
|
||||
}
|
||||
|
||||
// The TraverseErrorPassthroughRule type is an adapter to allow the use of ordinary function as Traverser.
|
||||
type TraverseErrorPassthroughRule func(context.Context, *ent.ErrorPassthroughRuleQuery) error
|
||||
|
||||
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
|
||||
func (f TraverseErrorPassthroughRule) Intercept(next ent.Querier) ent.Querier {
|
||||
return next
|
||||
}
|
||||
|
||||
// Traverse calls f(ctx, q).
|
||||
func (f TraverseErrorPassthroughRule) Traverse(ctx context.Context, q ent.Query) error {
|
||||
if q, ok := q.(*ent.ErrorPassthroughRuleQuery); ok {
|
||||
return f(ctx, q)
|
||||
}
|
||||
return fmt.Errorf("unexpected query type %T. expect *ent.ErrorPassthroughRuleQuery", q)
|
||||
}
|
||||
|
||||
// The GroupFunc type is an adapter to allow the use of ordinary function as a Querier.
|
||||
type GroupFunc func(context.Context, *ent.GroupQuery) (ent.Value, error)
|
||||
|
||||
@@ -584,6 +612,8 @@ func NewQuery(q ent.Query) (Query, error) {
|
||||
return &query[*ent.AnnouncementQuery, predicate.Announcement, announcement.OrderOption]{typ: ent.TypeAnnouncement, tq: q}, nil
|
||||
case *ent.AnnouncementReadQuery:
|
||||
return &query[*ent.AnnouncementReadQuery, predicate.AnnouncementRead, announcementread.OrderOption]{typ: ent.TypeAnnouncementRead, tq: q}, nil
|
||||
case *ent.ErrorPassthroughRuleQuery:
|
||||
return &query[*ent.ErrorPassthroughRuleQuery, predicate.ErrorPassthroughRule, errorpassthroughrule.OrderOption]{typ: ent.TypeErrorPassthroughRule, tq: q}, nil
|
||||
case *ent.GroupQuery:
|
||||
return &query[*ent.GroupQuery, predicate.Group, group.OrderOption]{typ: ent.TypeGroup, tq: q}, nil
|
||||
case *ent.PromoCodeQuery:
|
||||
|
||||
@@ -309,6 +309,42 @@ var (
|
||||
},
|
||||
},
|
||||
}
|
||||
// ErrorPassthroughRulesColumns holds the columns for the "error_passthrough_rules" table.
|
||||
ErrorPassthroughRulesColumns = []*schema.Column{
|
||||
{Name: "id", Type: field.TypeInt64, Increment: true},
|
||||
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "name", Type: field.TypeString, Size: 100},
|
||||
{Name: "enabled", Type: field.TypeBool, Default: true},
|
||||
{Name: "priority", Type: field.TypeInt, Default: 0},
|
||||
{Name: "error_codes", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "keywords", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "match_mode", Type: field.TypeString, Size: 10, Default: "any"},
|
||||
{Name: "platforms", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "passthrough_code", Type: field.TypeBool, Default: true},
|
||||
{Name: "response_code", Type: field.TypeInt, Nullable: true},
|
||||
{Name: "passthrough_body", Type: field.TypeBool, Default: true},
|
||||
{Name: "custom_message", Type: field.TypeString, Nullable: true, Size: 2147483647},
|
||||
{Name: "description", Type: field.TypeString, Nullable: true, Size: 2147483647},
|
||||
}
|
||||
// ErrorPassthroughRulesTable holds the schema information for the "error_passthrough_rules" table.
|
||||
ErrorPassthroughRulesTable = &schema.Table{
|
||||
Name: "error_passthrough_rules",
|
||||
Columns: ErrorPassthroughRulesColumns,
|
||||
PrimaryKey: []*schema.Column{ErrorPassthroughRulesColumns[0]},
|
||||
Indexes: []*schema.Index{
|
||||
{
|
||||
Name: "errorpassthroughrule_enabled",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{ErrorPassthroughRulesColumns[4]},
|
||||
},
|
||||
{
|
||||
Name: "errorpassthroughrule_priority",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{ErrorPassthroughRulesColumns[5]},
|
||||
},
|
||||
},
|
||||
}
|
||||
// GroupsColumns holds the columns for the "groups" table.
|
||||
GroupsColumns = []*schema.Column{
|
||||
{Name: "id", Type: field.TypeInt64, Increment: true},
|
||||
@@ -950,6 +986,7 @@ var (
|
||||
AccountGroupsTable,
|
||||
AnnouncementsTable,
|
||||
AnnouncementReadsTable,
|
||||
ErrorPassthroughRulesTable,
|
||||
GroupsTable,
|
||||
PromoCodesTable,
|
||||
PromoCodeUsagesTable,
|
||||
@@ -989,6 +1026,9 @@ func init() {
|
||||
AnnouncementReadsTable.Annotation = &entsql.Annotation{
|
||||
Table: "announcement_reads",
|
||||
}
|
||||
ErrorPassthroughRulesTable.Annotation = &entsql.Annotation{
|
||||
Table: "error_passthrough_rules",
|
||||
}
|
||||
GroupsTable.Annotation = &entsql.Annotation{
|
||||
Table: "groups",
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -21,6 +21,9 @@ type Announcement func(*sql.Selector)
|
||||
// AnnouncementRead is the predicate function for announcementread builders.
|
||||
type AnnouncementRead func(*sql.Selector)
|
||||
|
||||
// ErrorPassthroughRule is the predicate function for errorpassthroughrule builders.
|
||||
type ErrorPassthroughRule func(*sql.Selector)
|
||||
|
||||
// Group is the predicate function for group builders.
|
||||
type Group func(*sql.Selector)
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/announcement"
|
||||
"github.com/Wei-Shaw/sub2api/ent/announcementread"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
|
||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
|
||||
@@ -270,6 +271,61 @@ func init() {
|
||||
announcementreadDescCreatedAt := announcementreadFields[3].Descriptor()
|
||||
// announcementread.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||
announcementread.DefaultCreatedAt = announcementreadDescCreatedAt.Default.(func() time.Time)
|
||||
errorpassthroughruleMixin := schema.ErrorPassthroughRule{}.Mixin()
|
||||
errorpassthroughruleMixinFields0 := errorpassthroughruleMixin[0].Fields()
|
||||
_ = errorpassthroughruleMixinFields0
|
||||
errorpassthroughruleFields := schema.ErrorPassthroughRule{}.Fields()
|
||||
_ = errorpassthroughruleFields
|
||||
// errorpassthroughruleDescCreatedAt is the schema descriptor for created_at field.
|
||||
errorpassthroughruleDescCreatedAt := errorpassthroughruleMixinFields0[0].Descriptor()
|
||||
// errorpassthroughrule.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||
errorpassthroughrule.DefaultCreatedAt = errorpassthroughruleDescCreatedAt.Default.(func() time.Time)
|
||||
// errorpassthroughruleDescUpdatedAt is the schema descriptor for updated_at field.
|
||||
errorpassthroughruleDescUpdatedAt := errorpassthroughruleMixinFields0[1].Descriptor()
|
||||
// errorpassthroughrule.DefaultUpdatedAt holds the default value on creation for the updated_at field.
|
||||
errorpassthroughrule.DefaultUpdatedAt = errorpassthroughruleDescUpdatedAt.Default.(func() time.Time)
|
||||
// errorpassthroughrule.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
|
||||
errorpassthroughrule.UpdateDefaultUpdatedAt = errorpassthroughruleDescUpdatedAt.UpdateDefault.(func() time.Time)
|
||||
// errorpassthroughruleDescName is the schema descriptor for name field.
|
||||
errorpassthroughruleDescName := errorpassthroughruleFields[0].Descriptor()
|
||||
// errorpassthroughrule.NameValidator is a validator for the "name" field. It is called by the builders before save.
|
||||
errorpassthroughrule.NameValidator = func() func(string) error {
|
||||
validators := errorpassthroughruleDescName.Validators
|
||||
fns := [...]func(string) error{
|
||||
validators[0].(func(string) error),
|
||||
validators[1].(func(string) error),
|
||||
}
|
||||
return func(name string) error {
|
||||
for _, fn := range fns {
|
||||
if err := fn(name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}()
|
||||
// errorpassthroughruleDescEnabled is the schema descriptor for enabled field.
|
||||
errorpassthroughruleDescEnabled := errorpassthroughruleFields[1].Descriptor()
|
||||
// errorpassthroughrule.DefaultEnabled holds the default value on creation for the enabled field.
|
||||
errorpassthroughrule.DefaultEnabled = errorpassthroughruleDescEnabled.Default.(bool)
|
||||
// errorpassthroughruleDescPriority is the schema descriptor for priority field.
|
||||
errorpassthroughruleDescPriority := errorpassthroughruleFields[2].Descriptor()
|
||||
// errorpassthroughrule.DefaultPriority holds the default value on creation for the priority field.
|
||||
errorpassthroughrule.DefaultPriority = errorpassthroughruleDescPriority.Default.(int)
|
||||
// errorpassthroughruleDescMatchMode is the schema descriptor for match_mode field.
|
||||
errorpassthroughruleDescMatchMode := errorpassthroughruleFields[5].Descriptor()
|
||||
// errorpassthroughrule.DefaultMatchMode holds the default value on creation for the match_mode field.
|
||||
errorpassthroughrule.DefaultMatchMode = errorpassthroughruleDescMatchMode.Default.(string)
|
||||
// errorpassthroughrule.MatchModeValidator is a validator for the "match_mode" field. It is called by the builders before save.
|
||||
errorpassthroughrule.MatchModeValidator = errorpassthroughruleDescMatchMode.Validators[0].(func(string) error)
|
||||
// errorpassthroughruleDescPassthroughCode is the schema descriptor for passthrough_code field.
|
||||
errorpassthroughruleDescPassthroughCode := errorpassthroughruleFields[7].Descriptor()
|
||||
// errorpassthroughrule.DefaultPassthroughCode holds the default value on creation for the passthrough_code field.
|
||||
errorpassthroughrule.DefaultPassthroughCode = errorpassthroughruleDescPassthroughCode.Default.(bool)
|
||||
// errorpassthroughruleDescPassthroughBody is the schema descriptor for passthrough_body field.
|
||||
errorpassthroughruleDescPassthroughBody := errorpassthroughruleFields[9].Descriptor()
|
||||
// errorpassthroughrule.DefaultPassthroughBody holds the default value on creation for the passthrough_body field.
|
||||
errorpassthroughrule.DefaultPassthroughBody = errorpassthroughruleDescPassthroughBody.Default.(bool)
|
||||
groupMixin := schema.Group{}.Mixin()
|
||||
groupMixinHooks1 := groupMixin[1].Hooks()
|
||||
group.Hooks[0] = groupMixinHooks1[0]
|
||||
|
||||
121
backend/ent/schema/error_passthrough_rule.go
Normal file
121
backend/ent/schema/error_passthrough_rule.go
Normal file
@@ -0,0 +1,121 @@
|
||||
// Package schema 定义 Ent ORM 的数据库 schema。
|
||||
package schema
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/entsql"
|
||||
"entgo.io/ent/schema"
|
||||
"entgo.io/ent/schema/field"
|
||||
"entgo.io/ent/schema/index"
|
||||
)
|
||||
|
||||
// ErrorPassthroughRule 定义全局错误透传规则的 schema。
|
||||
//
|
||||
// 错误透传规则用于控制上游错误如何返回给客户端:
|
||||
// - 匹配条件:错误码 + 关键词组合
|
||||
// - 响应行为:透传原始信息 或 自定义错误信息
|
||||
// - 响应状态码:可指定返回给客户端的状态码
|
||||
// - 平台范围:规则适用的平台(Anthropic、OpenAI、Gemini、Antigravity)
|
||||
type ErrorPassthroughRule struct {
|
||||
ent.Schema
|
||||
}
|
||||
|
||||
// Annotations 返回 schema 的注解配置。
|
||||
func (ErrorPassthroughRule) Annotations() []schema.Annotation {
|
||||
return []schema.Annotation{
|
||||
entsql.Annotation{Table: "error_passthrough_rules"},
|
||||
}
|
||||
}
|
||||
|
||||
// Mixin 返回该 schema 使用的混入组件。
|
||||
func (ErrorPassthroughRule) Mixin() []ent.Mixin {
|
||||
return []ent.Mixin{
|
||||
mixins.TimeMixin{},
|
||||
}
|
||||
}
|
||||
|
||||
// Fields 定义错误透传规则实体的所有字段。
|
||||
func (ErrorPassthroughRule) Fields() []ent.Field {
|
||||
return []ent.Field{
|
||||
// name: 规则名称,用于在界面中标识规则
|
||||
field.String("name").
|
||||
MaxLen(100).
|
||||
NotEmpty(),
|
||||
|
||||
// enabled: 是否启用该规则
|
||||
field.Bool("enabled").
|
||||
Default(true),
|
||||
|
||||
// priority: 规则优先级,数值越小优先级越高
|
||||
// 匹配时按优先级顺序检查,命中第一个匹配的规则
|
||||
field.Int("priority").
|
||||
Default(0),
|
||||
|
||||
// error_codes: 匹配的错误码列表(OR关系)
|
||||
// 例如:[422, 400] 表示匹配 422 或 400 错误码
|
||||
field.JSON("error_codes", []int{}).
|
||||
Optional().
|
||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
|
||||
|
||||
// keywords: 匹配的关键词列表(OR关系)
|
||||
// 例如:["context limit", "model not supported"]
|
||||
// 关键词匹配不区分大小写
|
||||
field.JSON("keywords", []string{}).
|
||||
Optional().
|
||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
|
||||
|
||||
// match_mode: 匹配模式
|
||||
// - "any": 错误码匹配 OR 关键词匹配(任一条件满足即可)
|
||||
// - "all": 错误码匹配 AND 关键词匹配(所有条件都必须满足)
|
||||
field.String("match_mode").
|
||||
MaxLen(10).
|
||||
Default("any"),
|
||||
|
||||
// platforms: 适用平台列表
|
||||
// 例如:["anthropic", "openai", "gemini", "antigravity"]
|
||||
// 空列表表示适用于所有平台
|
||||
field.JSON("platforms", []string{}).
|
||||
Optional().
|
||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
|
||||
|
||||
// passthrough_code: 是否透传上游原始状态码
|
||||
// true: 使用上游返回的状态码
|
||||
// false: 使用 response_code 指定的状态码
|
||||
field.Bool("passthrough_code").
|
||||
Default(true),
|
||||
|
||||
// response_code: 自定义响应状态码
|
||||
// 当 passthrough_code=false 时使用此状态码
|
||||
field.Int("response_code").
|
||||
Optional().
|
||||
Nillable(),
|
||||
|
||||
// passthrough_body: 是否透传上游原始错误信息
|
||||
// true: 使用上游返回的错误信息
|
||||
// false: 使用 custom_message 指定的错误信息
|
||||
field.Bool("passthrough_body").
|
||||
Default(true),
|
||||
|
||||
// custom_message: 自定义错误信息
|
||||
// 当 passthrough_body=false 时使用此错误信息
|
||||
field.Text("custom_message").
|
||||
Optional().
|
||||
Nillable(),
|
||||
|
||||
// description: 规则描述,用于说明规则的用途
|
||||
field.Text("description").
|
||||
Optional().
|
||||
Nillable(),
|
||||
}
|
||||
}
|
||||
|
||||
// Indexes 定义数据库索引,优化查询性能。
|
||||
func (ErrorPassthroughRule) Indexes() []ent.Index {
|
||||
return []ent.Index{
|
||||
index.Fields("enabled"), // 筛选启用的规则
|
||||
index.Fields("priority"), // 按优先级排序
|
||||
}
|
||||
}
|
||||
@@ -24,6 +24,8 @@ type Tx struct {
|
||||
Announcement *AnnouncementClient
|
||||
// AnnouncementRead is the client for interacting with the AnnouncementRead builders.
|
||||
AnnouncementRead *AnnouncementReadClient
|
||||
// ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders.
|
||||
ErrorPassthroughRule *ErrorPassthroughRuleClient
|
||||
// Group is the client for interacting with the Group builders.
|
||||
Group *GroupClient
|
||||
// PromoCode is the client for interacting with the PromoCode builders.
|
||||
@@ -186,6 +188,7 @@ func (tx *Tx) init() {
|
||||
tx.AccountGroup = NewAccountGroupClient(tx.config)
|
||||
tx.Announcement = NewAnnouncementClient(tx.config)
|
||||
tx.AnnouncementRead = NewAnnouncementReadClient(tx.config)
|
||||
tx.ErrorPassthroughRule = NewErrorPassthroughRuleClient(tx.config)
|
||||
tx.Group = NewGroupClient(tx.config)
|
||||
tx.PromoCode = NewPromoCodeClient(tx.config)
|
||||
tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
module github.com/Wei-Shaw/sub2api
|
||||
|
||||
go 1.25.6
|
||||
go 1.25.7
|
||||
|
||||
require (
|
||||
entgo.io/ent v0.14.5
|
||||
@@ -25,10 +25,10 @@ require (
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
github.com/zeromicro/go-zero v1.9.4
|
||||
golang.org/x/crypto v0.46.0
|
||||
golang.org/x/net v0.48.0
|
||||
golang.org/x/crypto v0.47.0
|
||||
golang.org/x/net v0.49.0
|
||||
golang.org/x/sync v0.19.0
|
||||
golang.org/x/term v0.38.0
|
||||
golang.org/x/term v0.39.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
modernc.org/sqlite v1.44.3
|
||||
)
|
||||
@@ -75,12 +75,10 @@ require (
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/google/go-querystring v1.1.0 // indirect
|
||||
github.com/google/subcommands v1.2.0 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
github.com/hashicorp/hcl/v2 v2.18.1 // indirect
|
||||
github.com/icholy/digest v1.1.0 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/compress v1.18.2 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
||||
@@ -89,7 +87,6 @@ require (
|
||||
github.com/magiconair/properties v1.8.10 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.15 // indirect
|
||||
github.com/mdelapenya/tlscert v0.2.0 // indirect
|
||||
github.com/mitchellh/go-wordwrap v1.0.1 // indirect
|
||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||
@@ -104,7 +101,6 @@ require (
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/morikuni/aec v1.0.0 // indirect
|
||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||
github.com/olekukonko/tablewriter v0.0.5 // indirect
|
||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||
github.com/opencontainers/image-spec v1.1.1 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||
@@ -114,7 +110,6 @@ require (
|
||||
github.com/quic-go/qpack v0.6.0 // indirect
|
||||
github.com/quic-go/quic-go v0.57.1 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/rivo/uniseg v0.2.0 // indirect
|
||||
github.com/sagikazarmark/locafero v0.4.0 // indirect
|
||||
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
|
||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||
@@ -122,7 +117,6 @@ require (
|
||||
github.com/spaolacci/murmur3 v1.1.0 // indirect
|
||||
github.com/spf13/afero v1.11.0 // indirect
|
||||
github.com/spf13/cast v1.6.0 // indirect
|
||||
github.com/spf13/cobra v1.7.0 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/testcontainers/testcontainers-go v0.40.0 // indirect
|
||||
@@ -146,10 +140,9 @@ require (
|
||||
go.uber.org/multierr v1.9.0 // indirect
|
||||
golang.org/x/arch v0.3.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||
golang.org/x/mod v0.30.0 // indirect
|
||||
golang.org/x/sys v0.39.0 // indirect
|
||||
golang.org/x/text v0.32.0 // indirect
|
||||
golang.org/x/tools v0.39.0 // indirect
|
||||
golang.org/x/mod v0.31.0 // indirect
|
||||
golang.org/x/sys v0.40.0 // indirect
|
||||
golang.org/x/text v0.33.0 // indirect
|
||||
google.golang.org/grpc v1.75.1 // indirect
|
||||
google.golang.org/protobuf v1.36.10 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
|
||||
@@ -46,7 +46,6 @@ github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpS
|
||||
github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw=
|
||||
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
|
||||
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
|
||||
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
@@ -117,8 +116,6 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
|
||||
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
|
||||
@@ -138,8 +135,6 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
|
||||
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
|
||||
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
|
||||
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
@@ -175,9 +170,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
|
||||
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
|
||||
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
||||
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
||||
@@ -211,8 +203,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
||||
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||
@@ -240,13 +230,10 @@ github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkr
|
||||
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
|
||||
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
|
||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ=
|
||||
github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4=
|
||||
github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE=
|
||||
@@ -265,8 +252,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
|
||||
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
||||
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
|
||||
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
|
||||
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
|
||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
|
||||
@@ -350,14 +335,14 @@ go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTV
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
||||
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
|
||||
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
|
||||
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
|
||||
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
||||
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
|
||||
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
|
||||
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
|
||||
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
|
||||
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
|
||||
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
|
||||
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
|
||||
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -369,20 +354,16 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
|
||||
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
|
||||
golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
|
||||
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
|
||||
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
|
||||
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
|
||||
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY=
|
||||
golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww=
|
||||
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
|
||||
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
|
||||
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
||||
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
|
||||
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
|
||||
golang.org/x/tools/go/expect v0.1.1-deprecated h1:jpBZDwmgPhXsKZC6WhL20P4b/wmnpsEAGHaNy0n/rJM=
|
||||
golang.org/x/tools/go/expect v0.1.1-deprecated/go.mod h1:eihoPOH+FgIqa3FpoTwguz/bVUSGBlGQU67vpBeOrBY=
|
||||
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated h1:1h2MnaIAIXISqTFKdENegdpAgUXz6NrPEsbIeWaBRvM=
|
||||
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated/go.mod h1:RVAQXBGNv1ib0J382/DPCRS/BPnsGebyM1Gj5VSDpG8=
|
||||
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
|
||||
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU=
|
||||
|
||||
@@ -144,12 +144,24 @@ type PricingConfig struct {
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
Mode string `mapstructure:"mode"` // debug/release
|
||||
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
|
||||
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
|
||||
TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表(CIDR/IP)
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
Mode string `mapstructure:"mode"` // debug/release
|
||||
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
|
||||
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
|
||||
TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表(CIDR/IP)
|
||||
MaxRequestBodySize int64 `mapstructure:"max_request_body_size"` // 全局最大请求体限制
|
||||
H2C H2CConfig `mapstructure:"h2c"` // HTTP/2 Cleartext 配置
|
||||
}
|
||||
|
||||
// H2CConfig HTTP/2 Cleartext 配置
|
||||
type H2CConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"` // 是否启用 H2C
|
||||
MaxConcurrentStreams uint32 `mapstructure:"max_concurrent_streams"` // 最大并发流数量
|
||||
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲超时(秒)
|
||||
MaxReadFrameSize int `mapstructure:"max_read_frame_size"` // 最大帧大小(字节)
|
||||
MaxUploadBufferPerConnection int `mapstructure:"max_upload_buffer_per_connection"` // 每个连接的上传缓冲区(字节)
|
||||
MaxUploadBufferPerStream int `mapstructure:"max_upload_buffer_per_stream"` // 每个流的上传缓冲区(字节)
|
||||
}
|
||||
|
||||
type CORSConfig struct {
|
||||
@@ -467,6 +479,13 @@ type OpsMetricsCollectorCacheConfig struct {
|
||||
type JWTConfig struct {
|
||||
Secret string `mapstructure:"secret"`
|
||||
ExpireHour int `mapstructure:"expire_hour"`
|
||||
// AccessTokenExpireMinutes: Access Token有效期(分钟),默认15分钟
|
||||
// 短有效期减少被盗用风险,配合Refresh Token实现无感续期
|
||||
AccessTokenExpireMinutes int `mapstructure:"access_token_expire_minutes"`
|
||||
// RefreshTokenExpireDays: Refresh Token有效期(天),默认30天
|
||||
RefreshTokenExpireDays int `mapstructure:"refresh_token_expire_days"`
|
||||
// RefreshWindowMinutes: 刷新窗口(分钟),在Access Token过期前多久开始允许刷新
|
||||
RefreshWindowMinutes int `mapstructure:"refresh_window_minutes"`
|
||||
}
|
||||
|
||||
// TotpConfig TOTP 双因素认证配置
|
||||
@@ -687,6 +706,14 @@ func setDefaults() {
|
||||
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
|
||||
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
|
||||
viper.SetDefault("server.trusted_proxies", []string{})
|
||||
viper.SetDefault("server.max_request_body_size", int64(100*1024*1024))
|
||||
// H2C 默认配置
|
||||
viper.SetDefault("server.h2c.enabled", false)
|
||||
viper.SetDefault("server.h2c.max_concurrent_streams", uint32(50)) // 50 个并发流
|
||||
viper.SetDefault("server.h2c.idle_timeout", 75) // 75 秒
|
||||
viper.SetDefault("server.h2c.max_read_frame_size", 1<<20) // 1MB(够用)
|
||||
viper.SetDefault("server.h2c.max_upload_buffer_per_connection", 2<<20) // 2MB
|
||||
viper.SetDefault("server.h2c.max_upload_buffer_per_stream", 512<<10) // 512KB
|
||||
|
||||
// CORS
|
||||
viper.SetDefault("cors.allowed_origins", []string{})
|
||||
@@ -783,6 +810,9 @@ func setDefaults() {
|
||||
// JWT
|
||||
viper.SetDefault("jwt.secret", "")
|
||||
viper.SetDefault("jwt.expire_hour", 24)
|
||||
viper.SetDefault("jwt.access_token_expire_minutes", 360) // 6小时Access Token有效期
|
||||
viper.SetDefault("jwt.refresh_token_expire_days", 30) // 30天Refresh Token有效期
|
||||
viper.SetDefault("jwt.refresh_window_minutes", 2) // 过期前2分钟开始允许刷新
|
||||
|
||||
// TOTP
|
||||
viper.SetDefault("totp.encryption_key", "")
|
||||
@@ -912,6 +942,22 @@ func (c *Config) Validate() error {
|
||||
if c.JWT.ExpireHour > 24 {
|
||||
log.Printf("Warning: jwt.expire_hour is %d hours (> 24). Consider shorter expiration for security.", c.JWT.ExpireHour)
|
||||
}
|
||||
// JWT Refresh Token配置验证
|
||||
if c.JWT.AccessTokenExpireMinutes <= 0 {
|
||||
return fmt.Errorf("jwt.access_token_expire_minutes must be positive")
|
||||
}
|
||||
if c.JWT.AccessTokenExpireMinutes > 720 {
|
||||
log.Printf("Warning: jwt.access_token_expire_minutes is %d (> 720). Consider shorter expiration for security.", c.JWT.AccessTokenExpireMinutes)
|
||||
}
|
||||
if c.JWT.RefreshTokenExpireDays <= 0 {
|
||||
return fmt.Errorf("jwt.refresh_token_expire_days must be positive")
|
||||
}
|
||||
if c.JWT.RefreshTokenExpireDays > 90 {
|
||||
log.Printf("Warning: jwt.refresh_token_expire_days is %d (> 90). Consider shorter expiration for security.", c.JWT.RefreshTokenExpireDays)
|
||||
}
|
||||
if c.JWT.RefreshWindowMinutes < 0 {
|
||||
return fmt.Errorf("jwt.refresh_window_minutes must be non-negative")
|
||||
}
|
||||
if c.Security.CSP.Enabled && strings.TrimSpace(c.Security.CSP.Policy) == "" {
|
||||
return fmt.Errorf("security.csp.policy is required when CSP is enabled")
|
||||
}
|
||||
|
||||
544
backend/internal/handler/admin/account_data.go
Normal file
544
backend/internal/handler/admin/account_data.go
Normal file
@@ -0,0 +1,544 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
dataType = "sub2api-data"
|
||||
legacyDataType = "sub2api-bundle"
|
||||
dataVersion = 1
|
||||
dataPageCap = 1000
|
||||
)
|
||||
|
||||
type DataPayload struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Version int `json:"version,omitempty"`
|
||||
ExportedAt string `json:"exported_at"`
|
||||
Proxies []DataProxy `json:"proxies"`
|
||||
Accounts []DataAccount `json:"accounts"`
|
||||
}
|
||||
|
||||
type DataProxy struct {
|
||||
ProxyKey string `json:"proxy_key"`
|
||||
Name string `json:"name"`
|
||||
Protocol string `json:"protocol"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
Username string `json:"username,omitempty"`
|
||||
Password string `json:"password,omitempty"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
type DataAccount struct {
|
||||
Name string `json:"name"`
|
||||
Notes *string `json:"notes,omitempty"`
|
||||
Platform string `json:"platform"`
|
||||
Type string `json:"type"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra,omitempty"`
|
||||
ProxyKey *string `json:"proxy_key,omitempty"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier,omitempty"`
|
||||
ExpiresAt *int64 `json:"expires_at,omitempty"`
|
||||
AutoPauseOnExpired *bool `json:"auto_pause_on_expired,omitempty"`
|
||||
}
|
||||
|
||||
type DataImportRequest struct {
|
||||
Data DataPayload `json:"data"`
|
||||
SkipDefaultGroupBind *bool `json:"skip_default_group_bind"`
|
||||
}
|
||||
|
||||
type DataImportResult struct {
|
||||
ProxyCreated int `json:"proxy_created"`
|
||||
ProxyReused int `json:"proxy_reused"`
|
||||
ProxyFailed int `json:"proxy_failed"`
|
||||
AccountCreated int `json:"account_created"`
|
||||
AccountFailed int `json:"account_failed"`
|
||||
Errors []DataImportError `json:"errors,omitempty"`
|
||||
}
|
||||
|
||||
type DataImportError struct {
|
||||
Kind string `json:"kind"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ProxyKey string `json:"proxy_key,omitempty"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func buildProxyKey(protocol, host string, port int, username, password string) string {
|
||||
return fmt.Sprintf("%s|%s|%d|%s|%s", strings.TrimSpace(protocol), strings.TrimSpace(host), port, strings.TrimSpace(username), strings.TrimSpace(password))
|
||||
}
|
||||
|
||||
func (h *AccountHandler) ExportData(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
selectedIDs, err := parseAccountIDs(c)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
accounts, err := h.resolveExportAccounts(ctx, selectedIDs, c)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
includeProxies, err := parseIncludeProxies(c)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var proxies []service.Proxy
|
||||
if includeProxies {
|
||||
proxies, err = h.resolveExportProxies(ctx, accounts)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
proxies = []service.Proxy{}
|
||||
}
|
||||
|
||||
proxyKeyByID := make(map[int64]string, len(proxies))
|
||||
dataProxies := make([]DataProxy, 0, len(proxies))
|
||||
for i := range proxies {
|
||||
p := proxies[i]
|
||||
key := buildProxyKey(p.Protocol, p.Host, p.Port, p.Username, p.Password)
|
||||
proxyKeyByID[p.ID] = key
|
||||
dataProxies = append(dataProxies, DataProxy{
|
||||
ProxyKey: key,
|
||||
Name: p.Name,
|
||||
Protocol: p.Protocol,
|
||||
Host: p.Host,
|
||||
Port: p.Port,
|
||||
Username: p.Username,
|
||||
Password: p.Password,
|
||||
Status: p.Status,
|
||||
})
|
||||
}
|
||||
|
||||
dataAccounts := make([]DataAccount, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
acc := accounts[i]
|
||||
var proxyKey *string
|
||||
if acc.ProxyID != nil {
|
||||
if key, ok := proxyKeyByID[*acc.ProxyID]; ok {
|
||||
proxyKey = &key
|
||||
}
|
||||
}
|
||||
var expiresAt *int64
|
||||
if acc.ExpiresAt != nil {
|
||||
v := acc.ExpiresAt.Unix()
|
||||
expiresAt = &v
|
||||
}
|
||||
dataAccounts = append(dataAccounts, DataAccount{
|
||||
Name: acc.Name,
|
||||
Notes: acc.Notes,
|
||||
Platform: acc.Platform,
|
||||
Type: acc.Type,
|
||||
Credentials: acc.Credentials,
|
||||
Extra: acc.Extra,
|
||||
ProxyKey: proxyKey,
|
||||
Concurrency: acc.Concurrency,
|
||||
Priority: acc.Priority,
|
||||
RateMultiplier: acc.RateMultiplier,
|
||||
ExpiresAt: expiresAt,
|
||||
AutoPauseOnExpired: &acc.AutoPauseOnExpired,
|
||||
})
|
||||
}
|
||||
|
||||
payload := DataPayload{
|
||||
ExportedAt: time.Now().UTC().Format(time.RFC3339),
|
||||
Proxies: dataProxies,
|
||||
Accounts: dataAccounts,
|
||||
}
|
||||
|
||||
response.Success(c, payload)
|
||||
}
|
||||
|
||||
func (h *AccountHandler) ImportData(c *gin.Context) {
|
||||
var req DataImportRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
dataPayload := req.Data
|
||||
if err := validateDataHeader(dataPayload); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
skipDefaultGroupBind := true
|
||||
if req.SkipDefaultGroupBind != nil {
|
||||
skipDefaultGroupBind = *req.SkipDefaultGroupBind
|
||||
}
|
||||
|
||||
result := DataImportResult{}
|
||||
existingProxies, err := h.listAllProxies(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
proxyKeyToID := make(map[string]int64, len(existingProxies))
|
||||
for i := range existingProxies {
|
||||
p := existingProxies[i]
|
||||
key := buildProxyKey(p.Protocol, p.Host, p.Port, p.Username, p.Password)
|
||||
proxyKeyToID[key] = p.ID
|
||||
}
|
||||
|
||||
for i := range dataPayload.Proxies {
|
||||
item := dataPayload.Proxies[i]
|
||||
key := item.ProxyKey
|
||||
if key == "" {
|
||||
key = buildProxyKey(item.Protocol, item.Host, item.Port, item.Username, item.Password)
|
||||
}
|
||||
if err := validateDataProxy(item); err != nil {
|
||||
result.ProxyFailed++
|
||||
result.Errors = append(result.Errors, DataImportError{
|
||||
Kind: "proxy",
|
||||
Name: item.Name,
|
||||
ProxyKey: key,
|
||||
Message: err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
normalizedStatus := normalizeProxyStatus(item.Status)
|
||||
if existingID, ok := proxyKeyToID[key]; ok {
|
||||
proxyKeyToID[key] = existingID
|
||||
result.ProxyReused++
|
||||
if normalizedStatus != "" {
|
||||
if proxy, err := h.adminService.GetProxy(c.Request.Context(), existingID); err == nil && proxy != nil && proxy.Status != normalizedStatus {
|
||||
_, _ = h.adminService.UpdateProxy(c.Request.Context(), existingID, &service.UpdateProxyInput{
|
||||
Status: normalizedStatus,
|
||||
})
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
created, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
|
||||
Name: defaultProxyName(item.Name),
|
||||
Protocol: item.Protocol,
|
||||
Host: item.Host,
|
||||
Port: item.Port,
|
||||
Username: item.Username,
|
||||
Password: item.Password,
|
||||
})
|
||||
if err != nil {
|
||||
result.ProxyFailed++
|
||||
result.Errors = append(result.Errors, DataImportError{
|
||||
Kind: "proxy",
|
||||
Name: item.Name,
|
||||
ProxyKey: key,
|
||||
Message: err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
proxyKeyToID[key] = created.ID
|
||||
result.ProxyCreated++
|
||||
|
||||
if normalizedStatus != "" && normalizedStatus != created.Status {
|
||||
_, _ = h.adminService.UpdateProxy(c.Request.Context(), created.ID, &service.UpdateProxyInput{
|
||||
Status: normalizedStatus,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for i := range dataPayload.Accounts {
|
||||
item := dataPayload.Accounts[i]
|
||||
if err := validateDataAccount(item); err != nil {
|
||||
result.AccountFailed++
|
||||
result.Errors = append(result.Errors, DataImportError{
|
||||
Kind: "account",
|
||||
Name: item.Name,
|
||||
Message: err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
var proxyID *int64
|
||||
if item.ProxyKey != nil && *item.ProxyKey != "" {
|
||||
if id, ok := proxyKeyToID[*item.ProxyKey]; ok {
|
||||
proxyID = &id
|
||||
} else {
|
||||
result.AccountFailed++
|
||||
result.Errors = append(result.Errors, DataImportError{
|
||||
Kind: "account",
|
||||
Name: item.Name,
|
||||
ProxyKey: *item.ProxyKey,
|
||||
Message: "proxy_key not found",
|
||||
})
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
accountInput := &service.CreateAccountInput{
|
||||
Name: item.Name,
|
||||
Notes: item.Notes,
|
||||
Platform: item.Platform,
|
||||
Type: item.Type,
|
||||
Credentials: item.Credentials,
|
||||
Extra: item.Extra,
|
||||
ProxyID: proxyID,
|
||||
Concurrency: item.Concurrency,
|
||||
Priority: item.Priority,
|
||||
RateMultiplier: item.RateMultiplier,
|
||||
GroupIDs: nil,
|
||||
ExpiresAt: item.ExpiresAt,
|
||||
AutoPauseOnExpired: item.AutoPauseOnExpired,
|
||||
SkipDefaultGroupBind: skipDefaultGroupBind,
|
||||
}
|
||||
|
||||
if _, err := h.adminService.CreateAccount(c.Request.Context(), accountInput); err != nil {
|
||||
result.AccountFailed++
|
||||
result.Errors = append(result.Errors, DataImportError{
|
||||
Kind: "account",
|
||||
Name: item.Name,
|
||||
Message: err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
result.AccountCreated++
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
func (h *AccountHandler) listAllProxies(ctx context.Context) ([]service.Proxy, error) {
|
||||
page := 1
|
||||
pageSize := dataPageCap
|
||||
var out []service.Proxy
|
||||
for {
|
||||
items, total, err := h.adminService.ListProxies(ctx, page, pageSize, "", "", "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, items...)
|
||||
if len(out) >= int(total) || len(items) == 0 {
|
||||
break
|
||||
}
|
||||
page++
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, accountType, status, search string) ([]service.Account, error) {
|
||||
page := 1
|
||||
pageSize := dataPageCap
|
||||
var out []service.Account
|
||||
for {
|
||||
items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, items...)
|
||||
if len(out) >= int(total) || len(items) == 0 {
|
||||
break
|
||||
}
|
||||
page++
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (h *AccountHandler) resolveExportAccounts(ctx context.Context, ids []int64, c *gin.Context) ([]service.Account, error) {
|
||||
if len(ids) > 0 {
|
||||
accounts, err := h.adminService.GetAccountsByIDs(ctx, ids)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]service.Account, 0, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
if acc == nil {
|
||||
continue
|
||||
}
|
||||
out = append(out, *acc)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
platform := c.Query("platform")
|
||||
accountType := c.Query("type")
|
||||
status := c.Query("status")
|
||||
search := strings.TrimSpace(c.Query("search"))
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
}
|
||||
return h.listAccountsFiltered(ctx, platform, accountType, status, search)
|
||||
}
|
||||
|
||||
func (h *AccountHandler) resolveExportProxies(ctx context.Context, accounts []service.Account) ([]service.Proxy, error) {
|
||||
if len(accounts) == 0 {
|
||||
return []service.Proxy{}, nil
|
||||
}
|
||||
|
||||
seen := make(map[int64]struct{})
|
||||
ids := make([]int64, 0)
|
||||
for i := range accounts {
|
||||
if accounts[i].ProxyID == nil {
|
||||
continue
|
||||
}
|
||||
id := *accounts[i].ProxyID
|
||||
if id <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return []service.Proxy{}, nil
|
||||
}
|
||||
|
||||
return h.adminService.GetProxiesByIDs(ctx, ids)
|
||||
}
|
||||
|
||||
func parseAccountIDs(c *gin.Context) ([]int64, error) {
|
||||
values := c.QueryArray("ids")
|
||||
if len(values) == 0 {
|
||||
raw := strings.TrimSpace(c.Query("ids"))
|
||||
if raw != "" {
|
||||
values = []string{raw}
|
||||
}
|
||||
}
|
||||
if len(values) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
ids := make([]int64, 0, len(values))
|
||||
for _, item := range values {
|
||||
for _, part := range strings.Split(item, ",") {
|
||||
part = strings.TrimSpace(part)
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
id, err := strconv.ParseInt(part, 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
return nil, fmt.Errorf("invalid account id: %s", part)
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func parseIncludeProxies(c *gin.Context) (bool, error) {
|
||||
raw := strings.TrimSpace(strings.ToLower(c.Query("include_proxies")))
|
||||
if raw == "" {
|
||||
return true, nil
|
||||
}
|
||||
switch raw {
|
||||
case "1", "true", "yes", "on":
|
||||
return true, nil
|
||||
case "0", "false", "no", "off":
|
||||
return false, nil
|
||||
default:
|
||||
return true, fmt.Errorf("invalid include_proxies value: %s", raw)
|
||||
}
|
||||
}
|
||||
|
||||
func validateDataHeader(payload DataPayload) error {
|
||||
if payload.Type != "" && payload.Type != dataType && payload.Type != legacyDataType {
|
||||
return fmt.Errorf("unsupported data type: %s", payload.Type)
|
||||
}
|
||||
if payload.Version != 0 && payload.Version != dataVersion {
|
||||
return fmt.Errorf("unsupported data version: %d", payload.Version)
|
||||
}
|
||||
if payload.Proxies == nil {
|
||||
return errors.New("proxies is required")
|
||||
}
|
||||
if payload.Accounts == nil {
|
||||
return errors.New("accounts is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateDataProxy(item DataProxy) error {
|
||||
if strings.TrimSpace(item.Protocol) == "" {
|
||||
return errors.New("proxy protocol is required")
|
||||
}
|
||||
if strings.TrimSpace(item.Host) == "" {
|
||||
return errors.New("proxy host is required")
|
||||
}
|
||||
if item.Port <= 0 || item.Port > 65535 {
|
||||
return errors.New("proxy port is invalid")
|
||||
}
|
||||
switch item.Protocol {
|
||||
case "http", "https", "socks5", "socks5h":
|
||||
default:
|
||||
return fmt.Errorf("proxy protocol is invalid: %s", item.Protocol)
|
||||
}
|
||||
if item.Status != "" {
|
||||
normalizedStatus := normalizeProxyStatus(item.Status)
|
||||
if normalizedStatus != service.StatusActive && normalizedStatus != "inactive" {
|
||||
return fmt.Errorf("proxy status is invalid: %s", item.Status)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateDataAccount(item DataAccount) error {
|
||||
if strings.TrimSpace(item.Name) == "" {
|
||||
return errors.New("account name is required")
|
||||
}
|
||||
if strings.TrimSpace(item.Platform) == "" {
|
||||
return errors.New("account platform is required")
|
||||
}
|
||||
if strings.TrimSpace(item.Type) == "" {
|
||||
return errors.New("account type is required")
|
||||
}
|
||||
if len(item.Credentials) == 0 {
|
||||
return errors.New("account credentials is required")
|
||||
}
|
||||
switch item.Type {
|
||||
case service.AccountTypeOAuth, service.AccountTypeSetupToken, service.AccountTypeAPIKey, service.AccountTypeUpstream:
|
||||
default:
|
||||
return fmt.Errorf("account type is invalid: %s", item.Type)
|
||||
}
|
||||
if item.RateMultiplier != nil && *item.RateMultiplier < 0 {
|
||||
return errors.New("rate_multiplier must be >= 0")
|
||||
}
|
||||
if item.Concurrency < 0 {
|
||||
return errors.New("concurrency must be >= 0")
|
||||
}
|
||||
if item.Priority < 0 {
|
||||
return errors.New("priority must be >= 0")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func defaultProxyName(name string) string {
|
||||
if strings.TrimSpace(name) == "" {
|
||||
return "imported-proxy"
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
func normalizeProxyStatus(status string) string {
|
||||
normalized := strings.TrimSpace(strings.ToLower(status))
|
||||
switch normalized {
|
||||
case "":
|
||||
return ""
|
||||
case service.StatusActive:
|
||||
return service.StatusActive
|
||||
case "inactive", service.StatusDisabled:
|
||||
return "inactive"
|
||||
default:
|
||||
return normalized
|
||||
}
|
||||
}
|
||||
231
backend/internal/handler/admin/account_data_handler_test.go
Normal file
231
backend/internal/handler/admin/account_data_handler_test.go
Normal file
@@ -0,0 +1,231 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type dataResponse struct {
|
||||
Code int `json:"code"`
|
||||
Data dataPayload `json:"data"`
|
||||
}
|
||||
|
||||
type dataPayload struct {
|
||||
Type string `json:"type"`
|
||||
Version int `json:"version"`
|
||||
Proxies []dataProxy `json:"proxies"`
|
||||
Accounts []dataAccount `json:"accounts"`
|
||||
}
|
||||
|
||||
type dataProxy struct {
|
||||
ProxyKey string `json:"proxy_key"`
|
||||
Name string `json:"name"`
|
||||
Protocol string `json:"protocol"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
type dataAccount struct {
|
||||
Name string `json:"name"`
|
||||
Platform string `json:"platform"`
|
||||
Type string `json:"type"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyKey *string `json:"proxy_key"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
}
|
||||
|
||||
func setupAccountDataRouter() (*gin.Engine, *stubAdminService) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
adminSvc := newStubAdminService()
|
||||
|
||||
h := NewAccountHandler(
|
||||
adminSvc,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
router.GET("/api/v1/admin/accounts/data", h.ExportData)
|
||||
router.POST("/api/v1/admin/accounts/data", h.ImportData)
|
||||
return router, adminSvc
|
||||
}
|
||||
|
||||
func TestExportDataIncludesSecrets(t *testing.T) {
|
||||
router, adminSvc := setupAccountDataRouter()
|
||||
|
||||
proxyID := int64(11)
|
||||
adminSvc.proxies = []service.Proxy{
|
||||
{
|
||||
ID: proxyID,
|
||||
Name: "proxy",
|
||||
Protocol: "http",
|
||||
Host: "127.0.0.1",
|
||||
Port: 8080,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
Status: service.StatusActive,
|
||||
},
|
||||
{
|
||||
ID: 12,
|
||||
Name: "orphan",
|
||||
Protocol: "https",
|
||||
Host: "10.0.0.1",
|
||||
Port: 443,
|
||||
Username: "o",
|
||||
Password: "p",
|
||||
Status: service.StatusActive,
|
||||
},
|
||||
}
|
||||
adminSvc.accounts = []service.Account{
|
||||
{
|
||||
ID: 21,
|
||||
Name: "account",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Credentials: map[string]any{"token": "secret"},
|
||||
Extra: map[string]any{"note": "x"},
|
||||
ProxyID: &proxyID,
|
||||
Concurrency: 3,
|
||||
Priority: 50,
|
||||
Status: service.StatusDisabled,
|
||||
},
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/data", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp dataResponse
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, 0, resp.Code)
|
||||
require.Empty(t, resp.Data.Type)
|
||||
require.Equal(t, 0, resp.Data.Version)
|
||||
require.Len(t, resp.Data.Proxies, 1)
|
||||
require.Equal(t, "pass", resp.Data.Proxies[0].Password)
|
||||
require.Len(t, resp.Data.Accounts, 1)
|
||||
require.Equal(t, "secret", resp.Data.Accounts[0].Credentials["token"])
|
||||
}
|
||||
|
||||
func TestExportDataWithoutProxies(t *testing.T) {
|
||||
router, adminSvc := setupAccountDataRouter()
|
||||
|
||||
proxyID := int64(11)
|
||||
adminSvc.proxies = []service.Proxy{
|
||||
{
|
||||
ID: proxyID,
|
||||
Name: "proxy",
|
||||
Protocol: "http",
|
||||
Host: "127.0.0.1",
|
||||
Port: 8080,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
Status: service.StatusActive,
|
||||
},
|
||||
}
|
||||
adminSvc.accounts = []service.Account{
|
||||
{
|
||||
ID: 21,
|
||||
Name: "account",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Credentials: map[string]any{"token": "secret"},
|
||||
ProxyID: &proxyID,
|
||||
Concurrency: 3,
|
||||
Priority: 50,
|
||||
Status: service.StatusDisabled,
|
||||
},
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/data?include_proxies=false", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp dataResponse
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, 0, resp.Code)
|
||||
require.Len(t, resp.Data.Proxies, 0)
|
||||
require.Len(t, resp.Data.Accounts, 1)
|
||||
require.Nil(t, resp.Data.Accounts[0].ProxyKey)
|
||||
}
|
||||
|
||||
func TestImportDataReusesProxyAndSkipsDefaultGroup(t *testing.T) {
|
||||
router, adminSvc := setupAccountDataRouter()
|
||||
|
||||
adminSvc.proxies = []service.Proxy{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "proxy",
|
||||
Protocol: "socks5",
|
||||
Host: "1.2.3.4",
|
||||
Port: 1080,
|
||||
Username: "u",
|
||||
Password: "p",
|
||||
Status: service.StatusActive,
|
||||
},
|
||||
}
|
||||
|
||||
dataPayload := map[string]any{
|
||||
"data": map[string]any{
|
||||
"type": dataType,
|
||||
"version": dataVersion,
|
||||
"proxies": []map[string]any{
|
||||
{
|
||||
"proxy_key": "socks5|1.2.3.4|1080|u|p",
|
||||
"name": "proxy",
|
||||
"protocol": "socks5",
|
||||
"host": "1.2.3.4",
|
||||
"port": 1080,
|
||||
"username": "u",
|
||||
"password": "p",
|
||||
"status": "active",
|
||||
},
|
||||
},
|
||||
"accounts": []map[string]any{
|
||||
{
|
||||
"name": "acc",
|
||||
"platform": service.PlatformOpenAI,
|
||||
"type": service.AccountTypeOAuth,
|
||||
"credentials": map[string]any{"token": "x"},
|
||||
"proxy_key": "socks5|1.2.3.4|1080|u|p",
|
||||
"concurrency": 3,
|
||||
"priority": 50,
|
||||
},
|
||||
},
|
||||
},
|
||||
"skip_default_group_bind": true,
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(dataPayload)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/data", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
require.Len(t, adminSvc.createdProxies, 0)
|
||||
require.Len(t, adminSvc.createdAccounts, 1)
|
||||
require.True(t, adminSvc.createdAccounts[0].SkipDefaultGroupBind)
|
||||
}
|
||||
@@ -696,11 +696,61 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Return mock data for now
|
||||
ctx := c.Request.Context()
|
||||
success := 0
|
||||
failed := 0
|
||||
results := make([]gin.H, 0, len(req.Accounts))
|
||||
|
||||
for _, item := range req.Accounts {
|
||||
if item.RateMultiplier != nil && *item.RateMultiplier < 0 {
|
||||
failed++
|
||||
results = append(results, gin.H{
|
||||
"name": item.Name,
|
||||
"success": false,
|
||||
"error": "rate_multiplier must be >= 0",
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk
|
||||
|
||||
account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
|
||||
Name: item.Name,
|
||||
Notes: item.Notes,
|
||||
Platform: item.Platform,
|
||||
Type: item.Type,
|
||||
Credentials: item.Credentials,
|
||||
Extra: item.Extra,
|
||||
ProxyID: item.ProxyID,
|
||||
Concurrency: item.Concurrency,
|
||||
Priority: item.Priority,
|
||||
RateMultiplier: item.RateMultiplier,
|
||||
GroupIDs: item.GroupIDs,
|
||||
ExpiresAt: item.ExpiresAt,
|
||||
AutoPauseOnExpired: item.AutoPauseOnExpired,
|
||||
SkipMixedChannelCheck: skipCheck,
|
||||
})
|
||||
if err != nil {
|
||||
failed++
|
||||
results = append(results, gin.H{
|
||||
"name": item.Name,
|
||||
"success": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
success++
|
||||
results = append(results, gin.H{
|
||||
"name": item.Name,
|
||||
"id": account.ID,
|
||||
"success": true,
|
||||
})
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"success": len(req.Accounts),
|
||||
"failed": 0,
|
||||
"results": []gin.H{},
|
||||
"success": success,
|
||||
"failed": failed,
|
||||
"results": results,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -2,19 +2,27 @@ package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type stubAdminService struct {
|
||||
users []service.User
|
||||
apiKeys []service.APIKey
|
||||
groups []service.Group
|
||||
accounts []service.Account
|
||||
proxies []service.Proxy
|
||||
proxyCounts []service.ProxyWithAccountCount
|
||||
redeems []service.RedeemCode
|
||||
users []service.User
|
||||
apiKeys []service.APIKey
|
||||
groups []service.Group
|
||||
accounts []service.Account
|
||||
proxies []service.Proxy
|
||||
proxyCounts []service.ProxyWithAccountCount
|
||||
redeems []service.RedeemCode
|
||||
createdAccounts []*service.CreateAccountInput
|
||||
createdProxies []*service.CreateProxyInput
|
||||
updatedProxyIDs []int64
|
||||
updatedProxies []*service.UpdateProxyInput
|
||||
testedProxyIDs []int64
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newStubAdminService() *stubAdminService {
|
||||
@@ -177,6 +185,9 @@ func (s *stubAdminService) GetAccountsByIDs(ctx context.Context, ids []int64) ([
|
||||
}
|
||||
|
||||
func (s *stubAdminService) CreateAccount(ctx context.Context, input *service.CreateAccountInput) (*service.Account, error) {
|
||||
s.mu.Lock()
|
||||
s.createdAccounts = append(s.createdAccounts, input)
|
||||
s.mu.Unlock()
|
||||
account := service.Account{ID: 300, Name: input.Name, Status: service.StatusActive}
|
||||
return &account, nil
|
||||
}
|
||||
@@ -214,7 +225,25 @@ func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *servic
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) {
|
||||
return s.proxies, int64(len(s.proxies)), nil
|
||||
search = strings.TrimSpace(strings.ToLower(search))
|
||||
filtered := make([]service.Proxy, 0, len(s.proxies))
|
||||
for _, proxy := range s.proxies {
|
||||
if protocol != "" && proxy.Protocol != protocol {
|
||||
continue
|
||||
}
|
||||
if status != "" && proxy.Status != status {
|
||||
continue
|
||||
}
|
||||
if search != "" {
|
||||
name := strings.ToLower(proxy.Name)
|
||||
host := strings.ToLower(proxy.Host)
|
||||
if !strings.Contains(name, search) && !strings.Contains(host, search) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
filtered = append(filtered, proxy)
|
||||
}
|
||||
return filtered, int64(len(filtered)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.ProxyWithAccountCount, int64, error) {
|
||||
@@ -230,16 +259,47 @@ func (s *stubAdminService) GetAllProxiesWithAccountCount(ctx context.Context) ([
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetProxy(ctx context.Context, id int64) (*service.Proxy, error) {
|
||||
for i := range s.proxies {
|
||||
proxy := s.proxies[i]
|
||||
if proxy.ID == id {
|
||||
return &proxy, nil
|
||||
}
|
||||
}
|
||||
proxy := service.Proxy{ID: id, Name: "proxy", Status: service.StatusActive}
|
||||
return &proxy, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetProxiesByIDs(ctx context.Context, ids []int64) ([]service.Proxy, error) {
|
||||
if len(ids) == 0 {
|
||||
return []service.Proxy{}, nil
|
||||
}
|
||||
out := make([]service.Proxy, 0, len(ids))
|
||||
seen := make(map[int64]struct{}, len(ids))
|
||||
for _, id := range ids {
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
for i := range s.proxies {
|
||||
proxy := s.proxies[i]
|
||||
if _, ok := seen[proxy.ID]; ok {
|
||||
out = append(out, proxy)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) CreateProxy(ctx context.Context, input *service.CreateProxyInput) (*service.Proxy, error) {
|
||||
s.mu.Lock()
|
||||
s.createdProxies = append(s.createdProxies, input)
|
||||
s.mu.Unlock()
|
||||
proxy := service.Proxy{ID: 400, Name: input.Name, Status: service.StatusActive}
|
||||
return &proxy, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) UpdateProxy(ctx context.Context, id int64, input *service.UpdateProxyInput) (*service.Proxy, error) {
|
||||
s.mu.Lock()
|
||||
s.updatedProxyIDs = append(s.updatedProxyIDs, id)
|
||||
s.updatedProxies = append(s.updatedProxies, input)
|
||||
s.mu.Unlock()
|
||||
proxy := service.Proxy{ID: id, Name: input.Name, Status: service.StatusActive}
|
||||
return &proxy, nil
|
||||
}
|
||||
@@ -261,6 +321,9 @@ func (s *stubAdminService) CheckProxyExists(ctx context.Context, host string, po
|
||||
}
|
||||
|
||||
func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.ProxyTestResult, error) {
|
||||
s.mu.Lock()
|
||||
s.testedProxyIDs = append(s.testedProxyIDs, id)
|
||||
s.mu.Unlock()
|
||||
return &service.ProxyTestResult{Success: true, Message: "ok"}, nil
|
||||
}
|
||||
|
||||
|
||||
273
backend/internal/handler/admin/error_passthrough_handler.go
Normal file
273
backend/internal/handler/admin/error_passthrough_handler.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ErrorPassthroughHandler 处理错误透传规则的 HTTP 请求
|
||||
type ErrorPassthroughHandler struct {
|
||||
service *service.ErrorPassthroughService
|
||||
}
|
||||
|
||||
// NewErrorPassthroughHandler 创建错误透传规则处理器
|
||||
func NewErrorPassthroughHandler(service *service.ErrorPassthroughService) *ErrorPassthroughHandler {
|
||||
return &ErrorPassthroughHandler{service: service}
|
||||
}
|
||||
|
||||
// CreateErrorPassthroughRuleRequest 创建规则请求
|
||||
type CreateErrorPassthroughRuleRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
Priority int `json:"priority"`
|
||||
ErrorCodes []int `json:"error_codes"`
|
||||
Keywords []string `json:"keywords"`
|
||||
MatchMode string `json:"match_mode"`
|
||||
Platforms []string `json:"platforms"`
|
||||
PassthroughCode *bool `json:"passthrough_code"`
|
||||
ResponseCode *int `json:"response_code"`
|
||||
PassthroughBody *bool `json:"passthrough_body"`
|
||||
CustomMessage *string `json:"custom_message"`
|
||||
Description *string `json:"description"`
|
||||
}
|
||||
|
||||
// UpdateErrorPassthroughRuleRequest 更新规则请求(部分更新,所有字段可选)
|
||||
type UpdateErrorPassthroughRuleRequest struct {
|
||||
Name *string `json:"name"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
Priority *int `json:"priority"`
|
||||
ErrorCodes []int `json:"error_codes"`
|
||||
Keywords []string `json:"keywords"`
|
||||
MatchMode *string `json:"match_mode"`
|
||||
Platforms []string `json:"platforms"`
|
||||
PassthroughCode *bool `json:"passthrough_code"`
|
||||
ResponseCode *int `json:"response_code"`
|
||||
PassthroughBody *bool `json:"passthrough_body"`
|
||||
CustomMessage *string `json:"custom_message"`
|
||||
Description *string `json:"description"`
|
||||
}
|
||||
|
||||
// List 获取所有规则
|
||||
// GET /api/v1/admin/error-passthrough-rules
|
||||
func (h *ErrorPassthroughHandler) List(c *gin.Context) {
|
||||
rules, err := h.service.List(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, rules)
|
||||
}
|
||||
|
||||
// GetByID 根据 ID 获取规则
|
||||
// GET /api/v1/admin/error-passthrough-rules/:id
|
||||
func (h *ErrorPassthroughHandler) GetByID(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid rule ID")
|
||||
return
|
||||
}
|
||||
|
||||
rule, err := h.service.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if rule == nil {
|
||||
response.NotFound(c, "Rule not found")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, rule)
|
||||
}
|
||||
|
||||
// Create 创建规则
|
||||
// POST /api/v1/admin/error-passthrough-rules
|
||||
func (h *ErrorPassthroughHandler) Create(c *gin.Context) {
|
||||
var req CreateErrorPassthroughRuleRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
Name: req.Name,
|
||||
Priority: req.Priority,
|
||||
ErrorCodes: req.ErrorCodes,
|
||||
Keywords: req.Keywords,
|
||||
Platforms: req.Platforms,
|
||||
}
|
||||
|
||||
// 设置默认值
|
||||
if req.Enabled != nil {
|
||||
rule.Enabled = *req.Enabled
|
||||
} else {
|
||||
rule.Enabled = true
|
||||
}
|
||||
if req.MatchMode != "" {
|
||||
rule.MatchMode = req.MatchMode
|
||||
} else {
|
||||
rule.MatchMode = model.MatchModeAny
|
||||
}
|
||||
if req.PassthroughCode != nil {
|
||||
rule.PassthroughCode = *req.PassthroughCode
|
||||
} else {
|
||||
rule.PassthroughCode = true
|
||||
}
|
||||
if req.PassthroughBody != nil {
|
||||
rule.PassthroughBody = *req.PassthroughBody
|
||||
} else {
|
||||
rule.PassthroughBody = true
|
||||
}
|
||||
rule.ResponseCode = req.ResponseCode
|
||||
rule.CustomMessage = req.CustomMessage
|
||||
rule.Description = req.Description
|
||||
|
||||
// 确保切片不为 nil
|
||||
if rule.ErrorCodes == nil {
|
||||
rule.ErrorCodes = []int{}
|
||||
}
|
||||
if rule.Keywords == nil {
|
||||
rule.Keywords = []string{}
|
||||
}
|
||||
if rule.Platforms == nil {
|
||||
rule.Platforms = []string{}
|
||||
}
|
||||
|
||||
created, err := h.service.Create(c.Request.Context(), rule)
|
||||
if err != nil {
|
||||
if _, ok := err.(*model.ValidationError); ok {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, created)
|
||||
}
|
||||
|
||||
// Update 更新规则(支持部分更新)
|
||||
// PUT /api/v1/admin/error-passthrough-rules/:id
|
||||
func (h *ErrorPassthroughHandler) Update(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid rule ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateErrorPassthroughRuleRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 先获取现有规则
|
||||
existing, err := h.service.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if existing == nil {
|
||||
response.NotFound(c, "Rule not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 部分更新:只更新请求中提供的字段
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
ID: id,
|
||||
Name: existing.Name,
|
||||
Enabled: existing.Enabled,
|
||||
Priority: existing.Priority,
|
||||
ErrorCodes: existing.ErrorCodes,
|
||||
Keywords: existing.Keywords,
|
||||
MatchMode: existing.MatchMode,
|
||||
Platforms: existing.Platforms,
|
||||
PassthroughCode: existing.PassthroughCode,
|
||||
ResponseCode: existing.ResponseCode,
|
||||
PassthroughBody: existing.PassthroughBody,
|
||||
CustomMessage: existing.CustomMessage,
|
||||
Description: existing.Description,
|
||||
}
|
||||
|
||||
// 应用请求中提供的更新
|
||||
if req.Name != nil {
|
||||
rule.Name = *req.Name
|
||||
}
|
||||
if req.Enabled != nil {
|
||||
rule.Enabled = *req.Enabled
|
||||
}
|
||||
if req.Priority != nil {
|
||||
rule.Priority = *req.Priority
|
||||
}
|
||||
if req.ErrorCodes != nil {
|
||||
rule.ErrorCodes = req.ErrorCodes
|
||||
}
|
||||
if req.Keywords != nil {
|
||||
rule.Keywords = req.Keywords
|
||||
}
|
||||
if req.MatchMode != nil {
|
||||
rule.MatchMode = *req.MatchMode
|
||||
}
|
||||
if req.Platforms != nil {
|
||||
rule.Platforms = req.Platforms
|
||||
}
|
||||
if req.PassthroughCode != nil {
|
||||
rule.PassthroughCode = *req.PassthroughCode
|
||||
}
|
||||
if req.ResponseCode != nil {
|
||||
rule.ResponseCode = req.ResponseCode
|
||||
}
|
||||
if req.PassthroughBody != nil {
|
||||
rule.PassthroughBody = *req.PassthroughBody
|
||||
}
|
||||
if req.CustomMessage != nil {
|
||||
rule.CustomMessage = req.CustomMessage
|
||||
}
|
||||
if req.Description != nil {
|
||||
rule.Description = req.Description
|
||||
}
|
||||
|
||||
// 确保切片不为 nil
|
||||
if rule.ErrorCodes == nil {
|
||||
rule.ErrorCodes = []int{}
|
||||
}
|
||||
if rule.Keywords == nil {
|
||||
rule.Keywords = []string{}
|
||||
}
|
||||
if rule.Platforms == nil {
|
||||
rule.Platforms = []string{}
|
||||
}
|
||||
|
||||
updated, err := h.service.Update(c.Request.Context(), rule)
|
||||
if err != nil {
|
||||
if _, ok := err.(*model.ValidationError); ok {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, updated)
|
||||
}
|
||||
|
||||
// Delete 删除规则
|
||||
// DELETE /api/v1/admin/error-passthrough-rules/:id
|
||||
func (h *ErrorPassthroughHandler) Delete(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid rule ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.service.Delete(c.Request.Context(), id); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Rule deleted successfully"})
|
||||
}
|
||||
239
backend/internal/handler/admin/proxy_data.go
Normal file
239
backend/internal/handler/admin/proxy_data.go
Normal file
@@ -0,0 +1,239 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ExportData exports proxy-only data for migration.
|
||||
func (h *ProxyHandler) ExportData(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
selectedIDs, err := parseProxyIDs(c)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var proxies []service.Proxy
|
||||
if len(selectedIDs) > 0 {
|
||||
proxies, err = h.getProxiesByIDs(ctx, selectedIDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
protocol := c.Query("protocol")
|
||||
status := c.Query("status")
|
||||
search := strings.TrimSpace(c.Query("search"))
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
}
|
||||
|
||||
proxies, err = h.listProxiesFiltered(ctx, protocol, status, search)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
dataProxies := make([]DataProxy, 0, len(proxies))
|
||||
for i := range proxies {
|
||||
p := proxies[i]
|
||||
key := buildProxyKey(p.Protocol, p.Host, p.Port, p.Username, p.Password)
|
||||
dataProxies = append(dataProxies, DataProxy{
|
||||
ProxyKey: key,
|
||||
Name: p.Name,
|
||||
Protocol: p.Protocol,
|
||||
Host: p.Host,
|
||||
Port: p.Port,
|
||||
Username: p.Username,
|
||||
Password: p.Password,
|
||||
Status: p.Status,
|
||||
})
|
||||
}
|
||||
|
||||
payload := DataPayload{
|
||||
ExportedAt: time.Now().UTC().Format(time.RFC3339),
|
||||
Proxies: dataProxies,
|
||||
Accounts: []DataAccount{},
|
||||
}
|
||||
|
||||
response.Success(c, payload)
|
||||
}
|
||||
|
||||
// ImportData imports proxy-only data for migration.
|
||||
func (h *ProxyHandler) ImportData(c *gin.Context) {
|
||||
type ProxyImportRequest struct {
|
||||
Data DataPayload `json:"data"`
|
||||
}
|
||||
|
||||
var req ProxyImportRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateDataHeader(req.Data); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
result := DataImportResult{}
|
||||
|
||||
existingProxies, err := h.listProxiesFiltered(ctx, "", "", "")
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
proxyByKey := make(map[string]service.Proxy, len(existingProxies))
|
||||
for i := range existingProxies {
|
||||
p := existingProxies[i]
|
||||
key := buildProxyKey(p.Protocol, p.Host, p.Port, p.Username, p.Password)
|
||||
proxyByKey[key] = p
|
||||
}
|
||||
|
||||
latencyProbeIDs := make([]int64, 0, len(req.Data.Proxies))
|
||||
for i := range req.Data.Proxies {
|
||||
item := req.Data.Proxies[i]
|
||||
key := item.ProxyKey
|
||||
if key == "" {
|
||||
key = buildProxyKey(item.Protocol, item.Host, item.Port, item.Username, item.Password)
|
||||
}
|
||||
|
||||
if err := validateDataProxy(item); err != nil {
|
||||
result.ProxyFailed++
|
||||
result.Errors = append(result.Errors, DataImportError{
|
||||
Kind: "proxy",
|
||||
Name: item.Name,
|
||||
ProxyKey: key,
|
||||
Message: err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
normalizedStatus := normalizeProxyStatus(item.Status)
|
||||
if existing, ok := proxyByKey[key]; ok {
|
||||
result.ProxyReused++
|
||||
if normalizedStatus != "" && normalizedStatus != existing.Status {
|
||||
if _, err := h.adminService.UpdateProxy(ctx, existing.ID, &service.UpdateProxyInput{Status: normalizedStatus}); err != nil {
|
||||
result.Errors = append(result.Errors, DataImportError{
|
||||
Kind: "proxy",
|
||||
Name: item.Name,
|
||||
ProxyKey: key,
|
||||
Message: "update status failed: " + err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
latencyProbeIDs = append(latencyProbeIDs, existing.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
created, err := h.adminService.CreateProxy(ctx, &service.CreateProxyInput{
|
||||
Name: defaultProxyName(item.Name),
|
||||
Protocol: item.Protocol,
|
||||
Host: item.Host,
|
||||
Port: item.Port,
|
||||
Username: item.Username,
|
||||
Password: item.Password,
|
||||
})
|
||||
if err != nil {
|
||||
result.ProxyFailed++
|
||||
result.Errors = append(result.Errors, DataImportError{
|
||||
Kind: "proxy",
|
||||
Name: item.Name,
|
||||
ProxyKey: key,
|
||||
Message: err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
result.ProxyCreated++
|
||||
proxyByKey[key] = *created
|
||||
|
||||
if normalizedStatus != "" && normalizedStatus != created.Status {
|
||||
if _, err := h.adminService.UpdateProxy(ctx, created.ID, &service.UpdateProxyInput{Status: normalizedStatus}); err != nil {
|
||||
result.Errors = append(result.Errors, DataImportError{
|
||||
Kind: "proxy",
|
||||
Name: item.Name,
|
||||
ProxyKey: key,
|
||||
Message: "update status failed: " + err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
// CreateProxy already triggers a latency probe, avoid double probing here.
|
||||
}
|
||||
|
||||
if len(latencyProbeIDs) > 0 {
|
||||
ids := append([]int64(nil), latencyProbeIDs...)
|
||||
go func() {
|
||||
for _, id := range ids {
|
||||
_, _ = h.adminService.TestProxy(context.Background(), id)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) getProxiesByIDs(ctx context.Context, ids []int64) ([]service.Proxy, error) {
|
||||
if len(ids) == 0 {
|
||||
return []service.Proxy{}, nil
|
||||
}
|
||||
return h.adminService.GetProxiesByIDs(ctx, ids)
|
||||
}
|
||||
|
||||
func parseProxyIDs(c *gin.Context) ([]int64, error) {
|
||||
values := c.QueryArray("ids")
|
||||
if len(values) == 0 {
|
||||
raw := strings.TrimSpace(c.Query("ids"))
|
||||
if raw != "" {
|
||||
values = []string{raw}
|
||||
}
|
||||
}
|
||||
if len(values) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
ids := make([]int64, 0, len(values))
|
||||
for _, item := range values {
|
||||
for _, part := range strings.Split(item, ",") {
|
||||
part = strings.TrimSpace(part)
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
id, err := strconv.ParseInt(part, 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
return nil, fmt.Errorf("invalid proxy id: %s", part)
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) listProxiesFiltered(ctx context.Context, protocol, status, search string) ([]service.Proxy, error) {
|
||||
page := 1
|
||||
pageSize := dataPageCap
|
||||
var out []service.Proxy
|
||||
for {
|
||||
items, total, err := h.adminService.ListProxies(ctx, page, pageSize, protocol, status, search)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, items...)
|
||||
if len(out) >= int(total) || len(items) == 0 {
|
||||
break
|
||||
}
|
||||
page++
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
188
backend/internal/handler/admin/proxy_data_handler_test.go
Normal file
188
backend/internal/handler/admin/proxy_data_handler_test.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type proxyDataResponse struct {
|
||||
Code int `json:"code"`
|
||||
Data DataPayload `json:"data"`
|
||||
}
|
||||
|
||||
type proxyImportResponse struct {
|
||||
Code int `json:"code"`
|
||||
Data DataImportResult `json:"data"`
|
||||
}
|
||||
|
||||
func setupProxyDataRouter() (*gin.Engine, *stubAdminService) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
adminSvc := newStubAdminService()
|
||||
|
||||
h := NewProxyHandler(adminSvc)
|
||||
router.GET("/api/v1/admin/proxies/data", h.ExportData)
|
||||
router.POST("/api/v1/admin/proxies/data", h.ImportData)
|
||||
|
||||
return router, adminSvc
|
||||
}
|
||||
|
||||
func TestProxyExportDataRespectsFilters(t *testing.T) {
|
||||
router, adminSvc := setupProxyDataRouter()
|
||||
|
||||
adminSvc.proxies = []service.Proxy{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "proxy-a",
|
||||
Protocol: "http",
|
||||
Host: "127.0.0.1",
|
||||
Port: 8080,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
Status: service.StatusActive,
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Name: "proxy-b",
|
||||
Protocol: "https",
|
||||
Host: "10.0.0.2",
|
||||
Port: 443,
|
||||
Username: "u",
|
||||
Password: "p",
|
||||
Status: service.StatusDisabled,
|
||||
},
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/data?protocol=https", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp proxyDataResponse
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, 0, resp.Code)
|
||||
require.Empty(t, resp.Data.Type)
|
||||
require.Equal(t, 0, resp.Data.Version)
|
||||
require.Len(t, resp.Data.Proxies, 1)
|
||||
require.Len(t, resp.Data.Accounts, 0)
|
||||
require.Equal(t, "https", resp.Data.Proxies[0].Protocol)
|
||||
}
|
||||
|
||||
func TestProxyExportDataWithSelectedIDs(t *testing.T) {
|
||||
router, adminSvc := setupProxyDataRouter()
|
||||
|
||||
adminSvc.proxies = []service.Proxy{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "proxy-a",
|
||||
Protocol: "http",
|
||||
Host: "127.0.0.1",
|
||||
Port: 8080,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
Status: service.StatusActive,
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Name: "proxy-b",
|
||||
Protocol: "https",
|
||||
Host: "10.0.0.2",
|
||||
Port: 443,
|
||||
Username: "u",
|
||||
Password: "p",
|
||||
Status: service.StatusDisabled,
|
||||
},
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/data?ids=2", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp proxyDataResponse
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, 0, resp.Code)
|
||||
require.Len(t, resp.Data.Proxies, 1)
|
||||
require.Equal(t, "https", resp.Data.Proxies[0].Protocol)
|
||||
require.Equal(t, "10.0.0.2", resp.Data.Proxies[0].Host)
|
||||
}
|
||||
|
||||
func TestProxyImportDataReusesAndTriggersLatencyProbe(t *testing.T) {
|
||||
router, adminSvc := setupProxyDataRouter()
|
||||
|
||||
adminSvc.proxies = []service.Proxy{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "proxy-a",
|
||||
Protocol: "http",
|
||||
Host: "127.0.0.1",
|
||||
Port: 8080,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
Status: service.StatusActive,
|
||||
},
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"data": map[string]any{
|
||||
"type": dataType,
|
||||
"version": dataVersion,
|
||||
"proxies": []map[string]any{
|
||||
{
|
||||
"proxy_key": "http|127.0.0.1|8080|user|pass",
|
||||
"name": "proxy-a",
|
||||
"protocol": "http",
|
||||
"host": "127.0.0.1",
|
||||
"port": 8080,
|
||||
"username": "user",
|
||||
"password": "pass",
|
||||
"status": "inactive",
|
||||
},
|
||||
{
|
||||
"proxy_key": "https|10.0.0.2|443|u|p",
|
||||
"name": "proxy-b",
|
||||
"protocol": "https",
|
||||
"host": "10.0.0.2",
|
||||
"port": 443,
|
||||
"username": "u",
|
||||
"password": "p",
|
||||
"status": "active",
|
||||
},
|
||||
},
|
||||
"accounts": []map[string]any{},
|
||||
},
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(payload)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/data", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp proxyImportResponse
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, 0, resp.Code)
|
||||
require.Equal(t, 1, resp.Data.ProxyCreated)
|
||||
require.Equal(t, 1, resp.Data.ProxyReused)
|
||||
require.Equal(t, 0, resp.Data.ProxyFailed)
|
||||
|
||||
adminSvc.mu.Lock()
|
||||
updatedIDs := append([]int64(nil), adminSvc.updatedProxyIDs...)
|
||||
adminSvc.mu.Unlock()
|
||||
require.Contains(t, updatedIDs, int64(1))
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
adminSvc.mu.Lock()
|
||||
defer adminSvc.mu.Unlock()
|
||||
return len(adminSvc.testedProxyIDs) == 1
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
}
|
||||
@@ -45,6 +45,9 @@ type UpdateUserRequest struct {
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
|
||||
AllowedGroups *[]int64 `json:"allowed_groups"`
|
||||
// GroupRates 用户专属分组倍率配置
|
||||
// map[groupID]*rate,nil 表示删除该分组的专属倍率
|
||||
GroupRates map[int64]*float64 `json:"group_rates"`
|
||||
}
|
||||
|
||||
// UpdateBalanceRequest represents balance update request
|
||||
@@ -183,6 +186,7 @@ func (h *UserHandler) Update(c *gin.Context) {
|
||||
Concurrency: req.Concurrency,
|
||||
Status: req.Status,
|
||||
AllowedGroups: req.AllowedGroups,
|
||||
GroupRates: req.GroupRates,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
|
||||
@@ -243,3 +243,21 @@ func (h *APIKeyHandler) GetAvailableGroups(c *gin.Context) {
|
||||
}
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
// GetUserGroupRates 获取当前用户的专属分组倍率配置
|
||||
// GET /api/v1/groups/rates
|
||||
func (h *APIKeyHandler) GetUserGroupRates(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
rates, err := h.apiKeyService.GetUserGroupRates(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, rates)
|
||||
}
|
||||
|
||||
@@ -68,9 +68,39 @@ type LoginRequest struct {
|
||||
|
||||
// AuthResponse 认证响应格式(匹配前端期望)
|
||||
type AuthResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
User *dto.User `json:"user"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"` // 新增:Refresh Token
|
||||
ExpiresIn int `json:"expires_in,omitempty"` // 新增:Access Token有效期(秒)
|
||||
TokenType string `json:"token_type"`
|
||||
User *dto.User `json:"user"`
|
||||
}
|
||||
|
||||
// respondWithTokenPair 生成 Token 对并返回认证响应
|
||||
// 如果 Token 对生成失败,回退到只返回 Access Token(向后兼容)
|
||||
func (h *AuthHandler) respondWithTokenPair(c *gin.Context, user *service.User) {
|
||||
tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "")
|
||||
if err != nil {
|
||||
slog.Error("failed to generate token pair", "error", err, "user_id", user.ID)
|
||||
// 回退到只返回Access Token
|
||||
token, tokenErr := h.authService.GenerateToken(user)
|
||||
if tokenErr != nil {
|
||||
response.InternalError(c, "Failed to generate token")
|
||||
return
|
||||
}
|
||||
response.Success(c, AuthResponse{
|
||||
AccessToken: token,
|
||||
TokenType: "Bearer",
|
||||
User: dto.UserFromService(user),
|
||||
})
|
||||
return
|
||||
}
|
||||
response.Success(c, AuthResponse{
|
||||
AccessToken: tokenPair.AccessToken,
|
||||
RefreshToken: tokenPair.RefreshToken,
|
||||
ExpiresIn: tokenPair.ExpiresIn,
|
||||
TokenType: "Bearer",
|
||||
User: dto.UserFromService(user),
|
||||
})
|
||||
}
|
||||
|
||||
// Register handles user registration
|
||||
@@ -90,17 +120,13 @@ func (h *AuthHandler) Register(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode)
|
||||
_, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, AuthResponse{
|
||||
AccessToken: token,
|
||||
TokenType: "Bearer",
|
||||
User: dto.UserFromService(user),
|
||||
})
|
||||
h.respondWithTokenPair(c, user)
|
||||
}
|
||||
|
||||
// SendVerifyCode 发送邮箱验证码
|
||||
@@ -150,6 +176,7 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
_ = token // token 由 authService.Login 返回但此处由 respondWithTokenPair 重新生成
|
||||
|
||||
// Check if TOTP 2FA is enabled for this user
|
||||
if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled {
|
||||
@@ -168,11 +195,7 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, AuthResponse{
|
||||
AccessToken: token,
|
||||
TokenType: "Bearer",
|
||||
User: dto.UserFromService(user),
|
||||
})
|
||||
h.respondWithTokenPair(c, user)
|
||||
}
|
||||
|
||||
// TotpLoginResponse represents the response when 2FA is required
|
||||
@@ -238,18 +261,7 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Generate the JWT token
|
||||
token, err := h.authService.GenerateToken(user)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to generate token")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, AuthResponse{
|
||||
AccessToken: token,
|
||||
TokenType: "Bearer",
|
||||
User: dto.UserFromService(user),
|
||||
})
|
||||
h.respondWithTokenPair(c, user)
|
||||
}
|
||||
|
||||
// GetCurrentUser handles getting current authenticated user
|
||||
@@ -491,3 +503,96 @@ func (h *AuthHandler) ResetPassword(c *gin.Context) {
|
||||
Message: "Your password has been reset successfully. You can now log in with your new password.",
|
||||
})
|
||||
}
|
||||
|
||||
// ==================== Token Refresh Endpoints ====================
|
||||
|
||||
// RefreshTokenRequest 刷新Token请求
|
||||
type RefreshTokenRequest struct {
|
||||
RefreshToken string `json:"refresh_token" binding:"required"`
|
||||
}
|
||||
|
||||
// RefreshTokenResponse 刷新Token响应
|
||||
type RefreshTokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"` // Access Token有效期(秒)
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
// RefreshToken 刷新Token
|
||||
// POST /api/v1/auth/refresh
|
||||
func (h *AuthHandler) RefreshToken(c *gin.Context) {
|
||||
var req RefreshTokenRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, RefreshTokenResponse{
|
||||
AccessToken: tokenPair.AccessToken,
|
||||
RefreshToken: tokenPair.RefreshToken,
|
||||
ExpiresIn: tokenPair.ExpiresIn,
|
||||
TokenType: "Bearer",
|
||||
})
|
||||
}
|
||||
|
||||
// LogoutRequest 登出请求
|
||||
type LogoutRequest struct {
|
||||
RefreshToken string `json:"refresh_token,omitempty"` // 可选:撤销指定的Refresh Token
|
||||
}
|
||||
|
||||
// LogoutResponse 登出响应
|
||||
type LogoutResponse struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// Logout 用户登出
|
||||
// POST /api/v1/auth/logout
|
||||
func (h *AuthHandler) Logout(c *gin.Context) {
|
||||
var req LogoutRequest
|
||||
// 允许空请求体(向后兼容)
|
||||
_ = c.ShouldBindJSON(&req)
|
||||
|
||||
// 如果提供了Refresh Token,撤销它
|
||||
if req.RefreshToken != "" {
|
||||
if err := h.authService.RevokeRefreshToken(c.Request.Context(), req.RefreshToken); err != nil {
|
||||
slog.Debug("failed to revoke refresh token", "error", err)
|
||||
// 不影响登出流程
|
||||
}
|
||||
}
|
||||
|
||||
response.Success(c, LogoutResponse{
|
||||
Message: "Logged out successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// RevokeAllSessionsResponse 撤销所有会话响应
|
||||
type RevokeAllSessionsResponse struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// RevokeAllSessions 撤销当前用户的所有会话
|
||||
// POST /api/v1/auth/revoke-all-sessions
|
||||
func (h *AuthHandler) RevokeAllSessions(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.authService.RevokeAllUserSessions(c.Request.Context(), subject.UserID); err != nil {
|
||||
slog.Error("failed to revoke all sessions", "user_id", subject.UserID, "error", err)
|
||||
response.InternalError(c, "Failed to revoke sessions")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, RevokeAllSessionsResponse{
|
||||
Message: "All sessions have been revoked. Please log in again.",
|
||||
})
|
||||
}
|
||||
|
||||
@@ -211,7 +211,7 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
|
||||
email = linuxDoSyntheticEmail(subject)
|
||||
}
|
||||
|
||||
jwtToken, _, err := h.authService.LoginOrRegisterOAuth(c.Request.Context(), email, username)
|
||||
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username)
|
||||
if err != nil {
|
||||
// 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。
|
||||
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
|
||||
@@ -219,7 +219,9 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
|
||||
}
|
||||
|
||||
fragment := url.Values{}
|
||||
fragment.Set("access_token", jwtToken)
|
||||
fragment.Set("access_token", tokenPair.AccessToken)
|
||||
fragment.Set("refresh_token", tokenPair.RefreshToken)
|
||||
fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn))
|
||||
fragment.Set("token_type", "Bearer")
|
||||
fragment.Set("redirect", redirectTo)
|
||||
redirectWithFragment(c, frontendCallback, fragment)
|
||||
|
||||
@@ -58,8 +58,9 @@ func UserFromServiceAdmin(u *service.User) *AdminUser {
|
||||
return nil
|
||||
}
|
||||
return &AdminUser{
|
||||
User: *base,
|
||||
Notes: u.Notes,
|
||||
User: *base,
|
||||
Notes: u.Notes,
|
||||
GroupRates: u.GroupRates,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -29,6 +29,9 @@ type AdminUser struct {
|
||||
User
|
||||
|
||||
Notes string `json:"notes"`
|
||||
// GroupRates 用户专属分组倍率配置
|
||||
// map[groupID]rateMultiplier
|
||||
GroupRates map[int64]float64 `json:"group_rates,omitempty"`
|
||||
}
|
||||
|
||||
type APIKey struct {
|
||||
|
||||
@@ -33,6 +33,7 @@ type GatewayHandler struct {
|
||||
billingCacheService *service.BillingCacheService
|
||||
usageService *service.UsageService
|
||||
apiKeyService *service.APIKeyService
|
||||
errorPassthroughService *service.ErrorPassthroughService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
maxAccountSwitchesGemini int
|
||||
@@ -48,6 +49,7 @@ func NewGatewayHandler(
|
||||
billingCacheService *service.BillingCacheService,
|
||||
usageService *service.UsageService,
|
||||
apiKeyService *service.APIKeyService,
|
||||
errorPassthroughService *service.ErrorPassthroughService,
|
||||
cfg *config.Config,
|
||||
) *GatewayHandler {
|
||||
pingInterval := time.Duration(0)
|
||||
@@ -70,6 +72,7 @@ func NewGatewayHandler(
|
||||
billingCacheService: billingCacheService,
|
||||
usageService: usageService,
|
||||
apiKeyService: apiKeyService,
|
||||
errorPassthroughService: errorPassthroughService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
|
||||
@@ -201,7 +204,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
|
||||
@@ -210,7 +213,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
if lastFailoverErr != nil {
|
||||
h.handleFailoverExhausted(c, lastFailoverErr, service.PlatformGemini, streamStarted)
|
||||
} else {
|
||||
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
|
||||
}
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
@@ -301,9 +308,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
@@ -352,7 +359,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
retryWithFallback := false
|
||||
|
||||
for {
|
||||
@@ -363,7 +370,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
if lastFailoverErr != nil {
|
||||
h.handleFailoverExhausted(c, lastFailoverErr, platform, streamStarted)
|
||||
} else {
|
||||
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
|
||||
}
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
@@ -487,9 +498,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
@@ -616,10 +627,10 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Best-effort: 获取用量统计,失败不影响基础响应
|
||||
// Best-effort: 获取用量统计(按当前 API Key 过滤),失败不影响基础响应
|
||||
var usageData gin.H
|
||||
if h.usageService != nil {
|
||||
dashStats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), subject.UserID)
|
||||
dashStats, err := h.usageService.GetAPIKeyDashboardStats(c.Request.Context(), apiKey.ID)
|
||||
if err == nil && dashStats != nil {
|
||||
usageData = gin.H{
|
||||
"today": gin.H{
|
||||
@@ -755,7 +766,37 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
|
||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||
}
|
||||
|
||||
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
|
||||
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
|
||||
statusCode := failoverErr.StatusCode
|
||||
responseBody := failoverErr.ResponseBody
|
||||
|
||||
// 先检查透传规则
|
||||
if h.errorPassthroughService != nil && len(responseBody) > 0 {
|
||||
if rule := h.errorPassthroughService.MatchRule(platform, statusCode, responseBody); rule != nil {
|
||||
// 确定响应状态码
|
||||
respCode := statusCode
|
||||
if !rule.PassthroughCode && rule.ResponseCode != nil {
|
||||
respCode = *rule.ResponseCode
|
||||
}
|
||||
|
||||
// 确定响应消息
|
||||
msg := service.ExtractUpstreamErrorMessage(responseBody)
|
||||
if !rule.PassthroughBody && rule.CustomMessage != nil {
|
||||
msg = *rule.CustomMessage
|
||||
}
|
||||
|
||||
h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 使用默认的错误映射
|
||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||
}
|
||||
|
||||
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
|
||||
func (h *GatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
|
||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||
}
|
||||
|
||||
@@ -253,7 +253,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制
|
||||
@@ -262,7 +262,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
handleGeminiFailoverExhausted(c, lastFailoverStatus)
|
||||
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
@@ -353,11 +353,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
handleGeminiFailoverExhausted(c, lastFailoverStatus)
|
||||
lastFailoverErr = failoverErr
|
||||
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
|
||||
return
|
||||
}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
lastFailoverErr = failoverErr
|
||||
switchCount++
|
||||
log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
continue
|
||||
@@ -414,7 +414,36 @@ func parseGeminiModelAction(rest string) (model string, action string, err error
|
||||
return "", "", &pathParseError{"invalid model action path"}
|
||||
}
|
||||
|
||||
func handleGeminiFailoverExhausted(c *gin.Context, statusCode int) {
|
||||
func (h *GatewayHandler) handleGeminiFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError) {
|
||||
if failoverErr == nil {
|
||||
googleError(c, http.StatusBadGateway, "Upstream request failed")
|
||||
return
|
||||
}
|
||||
|
||||
statusCode := failoverErr.StatusCode
|
||||
responseBody := failoverErr.ResponseBody
|
||||
|
||||
// 先检查透传规则
|
||||
if h.errorPassthroughService != nil && len(responseBody) > 0 {
|
||||
if rule := h.errorPassthroughService.MatchRule(service.PlatformGemini, statusCode, responseBody); rule != nil {
|
||||
// 确定响应状态码
|
||||
respCode := statusCode
|
||||
if !rule.PassthroughCode && rule.ResponseCode != nil {
|
||||
respCode = *rule.ResponseCode
|
||||
}
|
||||
|
||||
// 确定响应消息
|
||||
msg := service.ExtractUpstreamErrorMessage(responseBody)
|
||||
if !rule.PassthroughBody && rule.CustomMessage != nil {
|
||||
msg = *rule.CustomMessage
|
||||
}
|
||||
|
||||
googleError(c, respCode, msg)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 使用默认的错误映射
|
||||
status, message := mapGeminiUpstreamError(statusCode)
|
||||
googleError(c, status, message)
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ type AdminHandlers struct {
|
||||
Subscription *admin.SubscriptionHandler
|
||||
Usage *admin.UsageHandler
|
||||
UserAttribute *admin.UserAttributeHandler
|
||||
ErrorPassthrough *admin.ErrorPassthroughHandler
|
||||
}
|
||||
|
||||
// Handlers contains all HTTP handlers
|
||||
|
||||
@@ -22,11 +22,12 @@ import (
|
||||
|
||||
// OpenAIGatewayHandler handles OpenAI API gateway requests
|
||||
type OpenAIGatewayHandler struct {
|
||||
gatewayService *service.OpenAIGatewayService
|
||||
billingCacheService *service.BillingCacheService
|
||||
apiKeyService *service.APIKeyService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
gatewayService *service.OpenAIGatewayService
|
||||
billingCacheService *service.BillingCacheService
|
||||
apiKeyService *service.APIKeyService
|
||||
errorPassthroughService *service.ErrorPassthroughService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
}
|
||||
|
||||
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
||||
@@ -35,6 +36,7 @@ func NewOpenAIGatewayHandler(
|
||||
concurrencyService *service.ConcurrencyService,
|
||||
billingCacheService *service.BillingCacheService,
|
||||
apiKeyService *service.APIKeyService,
|
||||
errorPassthroughService *service.ErrorPassthroughService,
|
||||
cfg *config.Config,
|
||||
) *OpenAIGatewayHandler {
|
||||
pingInterval := time.Duration(0)
|
||||
@@ -46,11 +48,12 @@ func NewOpenAIGatewayHandler(
|
||||
}
|
||||
}
|
||||
return &OpenAIGatewayHandler{
|
||||
gatewayService: gatewayService,
|
||||
billingCacheService: billingCacheService,
|
||||
apiKeyService: apiKeyService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
gatewayService: gatewayService,
|
||||
billingCacheService: billingCacheService,
|
||||
apiKeyService: apiKeyService,
|
||||
errorPassthroughService: errorPassthroughService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -201,7 +204,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
for {
|
||||
// Select account supporting the requested model
|
||||
@@ -213,7 +216,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
if lastFailoverErr != nil {
|
||||
h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
|
||||
} else {
|
||||
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
|
||||
}
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
@@ -278,12 +285,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
switchCount++
|
||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
continue
|
||||
@@ -324,7 +330,37 @@ func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error,
|
||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
|
||||
func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, streamStarted bool) {
|
||||
statusCode := failoverErr.StatusCode
|
||||
responseBody := failoverErr.ResponseBody
|
||||
|
||||
// 先检查透传规则
|
||||
if h.errorPassthroughService != nil && len(responseBody) > 0 {
|
||||
if rule := h.errorPassthroughService.MatchRule("openai", statusCode, responseBody); rule != nil {
|
||||
// 确定响应状态码
|
||||
respCode := statusCode
|
||||
if !rule.PassthroughCode && rule.ResponseCode != nil {
|
||||
respCode = *rule.ResponseCode
|
||||
}
|
||||
|
||||
// 确定响应消息
|
||||
msg := service.ExtractUpstreamErrorMessage(responseBody)
|
||||
if !rule.PassthroughBody && rule.CustomMessage != nil {
|
||||
msg = *rule.CustomMessage
|
||||
}
|
||||
|
||||
h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 使用默认的错误映射
|
||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||
}
|
||||
|
||||
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
|
||||
func (h *OpenAIGatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
|
||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ func ProvideAdminHandlers(
|
||||
subscriptionHandler *admin.SubscriptionHandler,
|
||||
usageHandler *admin.UsageHandler,
|
||||
userAttributeHandler *admin.UserAttributeHandler,
|
||||
errorPassthroughHandler *admin.ErrorPassthroughHandler,
|
||||
) *AdminHandlers {
|
||||
return &AdminHandlers{
|
||||
Dashboard: dashboardHandler,
|
||||
@@ -47,6 +48,7 @@ func ProvideAdminHandlers(
|
||||
Subscription: subscriptionHandler,
|
||||
Usage: usageHandler,
|
||||
UserAttribute: userAttributeHandler,
|
||||
ErrorPassthrough: errorPassthroughHandler,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -125,6 +127,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewSubscriptionHandler,
|
||||
admin.NewUsageHandler,
|
||||
admin.NewUserAttributeHandler,
|
||||
admin.NewErrorPassthroughHandler,
|
||||
|
||||
// AdminHandlers and Handlers constructors
|
||||
ProvideAdminHandlers,
|
||||
|
||||
74
backend/internal/model/error_passthrough_rule.go
Normal file
74
backend/internal/model/error_passthrough_rule.go
Normal file
@@ -0,0 +1,74 @@
|
||||
// Package model 定义服务层使用的数据模型。
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
// ErrorPassthroughRule 全局错误透传规则
|
||||
// 用于控制上游错误如何返回给客户端
|
||||
type ErrorPassthroughRule struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"` // 规则名称
|
||||
Enabled bool `json:"enabled"` // 是否启用
|
||||
Priority int `json:"priority"` // 优先级(数字越小优先级越高)
|
||||
ErrorCodes []int `json:"error_codes"` // 匹配的错误码列表(OR关系)
|
||||
Keywords []string `json:"keywords"` // 匹配的关键词列表(OR关系)
|
||||
MatchMode string `json:"match_mode"` // "any"(任一条件) 或 "all"(所有条件)
|
||||
Platforms []string `json:"platforms"` // 适用平台列表
|
||||
PassthroughCode bool `json:"passthrough_code"` // 是否透传原始状态码
|
||||
ResponseCode *int `json:"response_code"` // 自定义状态码(passthrough_code=false 时使用)
|
||||
PassthroughBody bool `json:"passthrough_body"` // 是否透传原始错误信息
|
||||
CustomMessage *string `json:"custom_message"` // 自定义错误信息(passthrough_body=false 时使用)
|
||||
Description *string `json:"description"` // 规则描述
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// MatchModeAny 表示任一条件匹配即可
|
||||
const MatchModeAny = "any"
|
||||
|
||||
// MatchModeAll 表示所有条件都必须匹配
|
||||
const MatchModeAll = "all"
|
||||
|
||||
// 支持的平台常量
|
||||
const (
|
||||
PlatformAnthropic = "anthropic"
|
||||
PlatformOpenAI = "openai"
|
||||
PlatformGemini = "gemini"
|
||||
PlatformAntigravity = "antigravity"
|
||||
)
|
||||
|
||||
// AllPlatforms 返回所有支持的平台列表
|
||||
func AllPlatforms() []string {
|
||||
return []string{PlatformAnthropic, PlatformOpenAI, PlatformGemini, PlatformAntigravity}
|
||||
}
|
||||
|
||||
// Validate 验证规则配置的有效性
|
||||
func (r *ErrorPassthroughRule) Validate() error {
|
||||
if r.Name == "" {
|
||||
return &ValidationError{Field: "name", Message: "name is required"}
|
||||
}
|
||||
if r.MatchMode != MatchModeAny && r.MatchMode != MatchModeAll {
|
||||
return &ValidationError{Field: "match_mode", Message: "match_mode must be 'any' or 'all'"}
|
||||
}
|
||||
// 至少需要配置一个匹配条件(错误码或关键词)
|
||||
if len(r.ErrorCodes) == 0 && len(r.Keywords) == 0 {
|
||||
return &ValidationError{Field: "conditions", Message: "at least one error_code or keyword is required"}
|
||||
}
|
||||
if !r.PassthroughCode && (r.ResponseCode == nil || *r.ResponseCode <= 0) {
|
||||
return &ValidationError{Field: "response_code", Message: "response_code is required when passthrough_code is false"}
|
||||
}
|
||||
if !r.PassthroughBody && (r.CustomMessage == nil || *r.CustomMessage == "") {
|
||||
return &ValidationError{Field: "custom_message", Message: "custom_message is required when passthrough_body is false"}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidationError 表示验证错误
|
||||
type ValidationError struct {
|
||||
Field string
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *ValidationError) Error() string {
|
||||
return e.Field + ": " + e.Message
|
||||
}
|
||||
@@ -71,6 +71,12 @@ var DefaultModels = []Model{
|
||||
DisplayName: "Claude Opus 4.5",
|
||||
CreatedAt: "2025-11-01T00:00:00Z",
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4-6",
|
||||
Type: "model",
|
||||
DisplayName: "Claude Opus 4.6",
|
||||
CreatedAt: "2026-02-06T00:00:00Z",
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4-5-20250929",
|
||||
Type: "model",
|
||||
|
||||
109
backend/internal/pkg/googleapi/error.go
Normal file
109
backend/internal/pkg/googleapi/error.go
Normal file
@@ -0,0 +1,109 @@
|
||||
// Package googleapi provides helpers for Google-style API responses.
|
||||
package googleapi
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ErrorResponse represents a Google API error response
|
||||
type ErrorResponse struct {
|
||||
Error ErrorDetail `json:"error"`
|
||||
}
|
||||
|
||||
// ErrorDetail contains the error details from Google API
|
||||
type ErrorDetail struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Status string `json:"status"`
|
||||
Details []json.RawMessage `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
// ErrorDetailInfo contains additional error information
|
||||
type ErrorDetailInfo struct {
|
||||
Type string `json:"@type"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
Domain string `json:"domain,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// ErrorHelp contains help links
|
||||
type ErrorHelp struct {
|
||||
Type string `json:"@type"`
|
||||
Links []HelpLink `json:"links,omitempty"`
|
||||
}
|
||||
|
||||
// HelpLink represents a help link
|
||||
type HelpLink struct {
|
||||
Description string `json:"description"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
// ParseError parses a Google API error response and extracts key information
|
||||
func ParseError(body string) (*ErrorResponse, error) {
|
||||
var errResp ErrorResponse
|
||||
if err := json.Unmarshal([]byte(body), &errResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse error response: %w", err)
|
||||
}
|
||||
return &errResp, nil
|
||||
}
|
||||
|
||||
// ExtractActivationURL extracts the API activation URL from error details
|
||||
func ExtractActivationURL(body string) string {
|
||||
var errResp ErrorResponse
|
||||
if err := json.Unmarshal([]byte(body), &errResp); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Check error details for activation URL
|
||||
for _, detailRaw := range errResp.Error.Details {
|
||||
// Parse as ErrorDetailInfo
|
||||
var info ErrorDetailInfo
|
||||
if err := json.Unmarshal(detailRaw, &info); err == nil {
|
||||
if info.Metadata != nil {
|
||||
if activationURL, ok := info.Metadata["activationUrl"]; ok && activationURL != "" {
|
||||
return activationURL
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parse as ErrorHelp
|
||||
var help ErrorHelp
|
||||
if err := json.Unmarshal(detailRaw, &help); err == nil {
|
||||
for _, link := range help.Links {
|
||||
if strings.Contains(link.Description, "activation") ||
|
||||
strings.Contains(link.Description, "API activation") ||
|
||||
strings.Contains(link.URL, "/apis/api/") {
|
||||
return link.URL
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// IsServiceDisabledError checks if the error is a SERVICE_DISABLED error
|
||||
func IsServiceDisabledError(body string) bool {
|
||||
var errResp ErrorResponse
|
||||
if err := json.Unmarshal([]byte(body), &errResp); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if it's a 403 PERMISSION_DENIED with SERVICE_DISABLED reason
|
||||
if errResp.Error.Code != 403 || errResp.Error.Status != "PERMISSION_DENIED" {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, detailRaw := range errResp.Error.Details {
|
||||
var info ErrorDetailInfo
|
||||
if err := json.Unmarshal(detailRaw, &info); err == nil {
|
||||
if info.Reason == "SERVICE_DISABLED" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
143
backend/internal/pkg/googleapi/error_test.go
Normal file
143
backend/internal/pkg/googleapi/error_test.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package googleapi
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExtractActivationURL(t *testing.T) {
|
||||
// Test case from the user's error message
|
||||
errorBody := `{
|
||||
"error": {
|
||||
"code": 403,
|
||||
"message": "Gemini for Google Cloud API has not been used in project project-6eca5881-ab73-4736-843 before or it is disabled. Enable it by visiting https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843 then retry. If you enabled this API recently, wait a few minutes for the action to propagate to our systems and retry.",
|
||||
"status": "PERMISSION_DENIED",
|
||||
"details": [
|
||||
{
|
||||
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
|
||||
"reason": "SERVICE_DISABLED",
|
||||
"domain": "googleapis.com",
|
||||
"metadata": {
|
||||
"service": "cloudaicompanion.googleapis.com",
|
||||
"activationUrl": "https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843",
|
||||
"consumer": "projects/project-6eca5881-ab73-4736-843",
|
||||
"serviceTitle": "Gemini for Google Cloud API",
|
||||
"containerInfo": "project-6eca5881-ab73-4736-843"
|
||||
}
|
||||
},
|
||||
{
|
||||
"@type": "type.googleapis.com/google.rpc.LocalizedMessage",
|
||||
"locale": "en-US",
|
||||
"message": "Gemini for Google Cloud API has not been used in project project-6eca5881-ab73-4736-843 before or it is disabled. Enable it by visiting https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843 then retry. If you enabled this API recently, wait a few minutes for the action to propagate to our systems and retry."
|
||||
},
|
||||
{
|
||||
"@type": "type.googleapis.com/google.rpc.Help",
|
||||
"links": [
|
||||
{
|
||||
"description": "Google developers console API activation",
|
||||
"url": "https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}`
|
||||
|
||||
activationURL := ExtractActivationURL(errorBody)
|
||||
expectedURL := "https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843"
|
||||
|
||||
if activationURL != expectedURL {
|
||||
t.Errorf("Expected activation URL %s, got %s", expectedURL, activationURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsServiceDisabledError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "SERVICE_DISABLED error",
|
||||
body: `{
|
||||
"error": {
|
||||
"code": 403,
|
||||
"status": "PERMISSION_DENIED",
|
||||
"details": [
|
||||
{
|
||||
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
|
||||
"reason": "SERVICE_DISABLED"
|
||||
}
|
||||
]
|
||||
}
|
||||
}`,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Other 403 error",
|
||||
body: `{
|
||||
"error": {
|
||||
"code": 403,
|
||||
"status": "PERMISSION_DENIED",
|
||||
"details": [
|
||||
{
|
||||
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
|
||||
"reason": "OTHER_REASON"
|
||||
}
|
||||
]
|
||||
}
|
||||
}`,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "404 error",
|
||||
body: `{
|
||||
"error": {
|
||||
"code": 404,
|
||||
"status": "NOT_FOUND"
|
||||
}
|
||||
}`,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid JSON",
|
||||
body: `invalid json`,
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := IsServiceDisabledError(tt.body)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseError(t *testing.T) {
|
||||
errorBody := `{
|
||||
"error": {
|
||||
"code": 403,
|
||||
"message": "API not enabled",
|
||||
"status": "PERMISSION_DENIED"
|
||||
}
|
||||
}`
|
||||
|
||||
errResp, err := ParseError(errorBody)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse error: %v", err)
|
||||
}
|
||||
|
||||
if errResp.Error.Code != 403 {
|
||||
t.Errorf("Expected code 403, got %d", errResp.Error.Code)
|
||||
}
|
||||
|
||||
if errResp.Error.Status != "PERMISSION_DENIED" {
|
||||
t.Errorf("Expected status PERMISSION_DENIED, got %s", errResp.Error.Status)
|
||||
}
|
||||
|
||||
if errResp.Error.Message != "API not enabled" {
|
||||
t.Errorf("Expected message 'API not enabled', got %s", errResp.Error.Message)
|
||||
}
|
||||
}
|
||||
@@ -15,6 +15,8 @@ type Model struct {
|
||||
|
||||
// DefaultModels OpenAI models list
|
||||
var DefaultModels = []Model{
|
||||
{ID: "gpt-5.3", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3"},
|
||||
{ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"},
|
||||
{ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"},
|
||||
{ID: "gpt-5.2-codex", Object: "model", Created: 1733011200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2 Codex"},
|
||||
{ID: "gpt-5.1-codex-max", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Max"},
|
||||
|
||||
@@ -1089,8 +1089,9 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
|
||||
result, err := client.ExecContext(
|
||||
ctx,
|
||||
"UPDATE accounts SET extra = COALESCE(extra, '{}'::jsonb) || $1::jsonb, updated_at = NOW() WHERE id = $2 AND deleted_at IS NULL",
|
||||
payload, id,
|
||||
string(payload), id,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
128
backend/internal/repository/error_passthrough_cache.go
Normal file
128
backend/internal/repository/error_passthrough_cache.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
errorPassthroughCacheKey = "error_passthrough_rules"
|
||||
errorPassthroughPubSubKey = "error_passthrough_rules_updated"
|
||||
errorPassthroughCacheTTL = 24 * time.Hour
|
||||
)
|
||||
|
||||
type errorPassthroughCache struct {
|
||||
rdb *redis.Client
|
||||
localCache []*model.ErrorPassthroughRule
|
||||
localMu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewErrorPassthroughCache 创建错误透传规则缓存
|
||||
func NewErrorPassthroughCache(rdb *redis.Client) service.ErrorPassthroughCache {
|
||||
return &errorPassthroughCache{
|
||||
rdb: rdb,
|
||||
}
|
||||
}
|
||||
|
||||
// Get 从缓存获取规则列表
|
||||
func (c *errorPassthroughCache) Get(ctx context.Context) ([]*model.ErrorPassthroughRule, bool) {
|
||||
// 先检查本地缓存
|
||||
c.localMu.RLock()
|
||||
if c.localCache != nil {
|
||||
rules := c.localCache
|
||||
c.localMu.RUnlock()
|
||||
return rules, true
|
||||
}
|
||||
c.localMu.RUnlock()
|
||||
|
||||
// 从 Redis 获取
|
||||
data, err := c.rdb.Get(ctx, errorPassthroughCacheKey).Bytes()
|
||||
if err != nil {
|
||||
if err != redis.Nil {
|
||||
log.Printf("[ErrorPassthroughCache] Failed to get from Redis: %v", err)
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
var rules []*model.ErrorPassthroughRule
|
||||
if err := json.Unmarshal(data, &rules); err != nil {
|
||||
log.Printf("[ErrorPassthroughCache] Failed to unmarshal rules: %v", err)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// 更新本地缓存
|
||||
c.localMu.Lock()
|
||||
c.localCache = rules
|
||||
c.localMu.Unlock()
|
||||
|
||||
return rules, true
|
||||
}
|
||||
|
||||
// Set 设置缓存
|
||||
func (c *errorPassthroughCache) Set(ctx context.Context, rules []*model.ErrorPassthroughRule) error {
|
||||
data, err := json.Marshal(rules)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.rdb.Set(ctx, errorPassthroughCacheKey, data, errorPassthroughCacheTTL).Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新本地缓存
|
||||
c.localMu.Lock()
|
||||
c.localCache = rules
|
||||
c.localMu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Invalidate 使缓存失效
|
||||
func (c *errorPassthroughCache) Invalidate(ctx context.Context) error {
|
||||
// 清除本地缓存
|
||||
c.localMu.Lock()
|
||||
c.localCache = nil
|
||||
c.localMu.Unlock()
|
||||
|
||||
// 清除 Redis 缓存
|
||||
return c.rdb.Del(ctx, errorPassthroughCacheKey).Err()
|
||||
}
|
||||
|
||||
// NotifyUpdate 通知其他实例刷新缓存
|
||||
func (c *errorPassthroughCache) NotifyUpdate(ctx context.Context) error {
|
||||
return c.rdb.Publish(ctx, errorPassthroughPubSubKey, "refresh").Err()
|
||||
}
|
||||
|
||||
// SubscribeUpdates 订阅缓存更新通知
|
||||
func (c *errorPassthroughCache) SubscribeUpdates(ctx context.Context, handler func()) {
|
||||
go func() {
|
||||
sub := c.rdb.Subscribe(ctx, errorPassthroughPubSubKey)
|
||||
defer func() { _ = sub.Close() }()
|
||||
|
||||
ch := sub.Channel()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case msg := <-ch:
|
||||
if msg == nil {
|
||||
return
|
||||
}
|
||||
// 清除本地缓存,下次访问时会从 Redis 或数据库重新加载
|
||||
c.localMu.Lock()
|
||||
c.localCache = nil
|
||||
c.localMu.Unlock()
|
||||
|
||||
// 调用处理函数
|
||||
handler()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
178
backend/internal/repository/error_passthrough_repo.go
Normal file
178
backend/internal/repository/error_passthrough_repo.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type errorPassthroughRepository struct {
|
||||
client *ent.Client
|
||||
}
|
||||
|
||||
// NewErrorPassthroughRepository 创建错误透传规则仓库
|
||||
func NewErrorPassthroughRepository(client *ent.Client) service.ErrorPassthroughRepository {
|
||||
return &errorPassthroughRepository{client: client}
|
||||
}
|
||||
|
||||
// List 获取所有规则
|
||||
func (r *errorPassthroughRepository) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) {
|
||||
rules, err := r.client.ErrorPassthroughRule.Query().
|
||||
Order(ent.Asc(errorpassthroughrule.FieldPriority)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]*model.ErrorPassthroughRule, len(rules))
|
||||
for i, rule := range rules {
|
||||
result[i] = r.toModel(rule)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetByID 根据 ID 获取规则
|
||||
func (r *errorPassthroughRepository) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) {
|
||||
rule, err := r.client.ErrorPassthroughRule.Get(ctx, id)
|
||||
if err != nil {
|
||||
if ent.IsNotFound(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return r.toModel(rule), nil
|
||||
}
|
||||
|
||||
// Create 创建规则
|
||||
func (r *errorPassthroughRepository) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
|
||||
builder := r.client.ErrorPassthroughRule.Create().
|
||||
SetName(rule.Name).
|
||||
SetEnabled(rule.Enabled).
|
||||
SetPriority(rule.Priority).
|
||||
SetMatchMode(rule.MatchMode).
|
||||
SetPassthroughCode(rule.PassthroughCode).
|
||||
SetPassthroughBody(rule.PassthroughBody)
|
||||
|
||||
if len(rule.ErrorCodes) > 0 {
|
||||
builder.SetErrorCodes(rule.ErrorCodes)
|
||||
}
|
||||
if len(rule.Keywords) > 0 {
|
||||
builder.SetKeywords(rule.Keywords)
|
||||
}
|
||||
if len(rule.Platforms) > 0 {
|
||||
builder.SetPlatforms(rule.Platforms)
|
||||
}
|
||||
if rule.ResponseCode != nil {
|
||||
builder.SetResponseCode(*rule.ResponseCode)
|
||||
}
|
||||
if rule.CustomMessage != nil {
|
||||
builder.SetCustomMessage(*rule.CustomMessage)
|
||||
}
|
||||
if rule.Description != nil {
|
||||
builder.SetDescription(*rule.Description)
|
||||
}
|
||||
|
||||
created, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.toModel(created), nil
|
||||
}
|
||||
|
||||
// Update 更新规则
|
||||
func (r *errorPassthroughRepository) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
|
||||
builder := r.client.ErrorPassthroughRule.UpdateOneID(rule.ID).
|
||||
SetName(rule.Name).
|
||||
SetEnabled(rule.Enabled).
|
||||
SetPriority(rule.Priority).
|
||||
SetMatchMode(rule.MatchMode).
|
||||
SetPassthroughCode(rule.PassthroughCode).
|
||||
SetPassthroughBody(rule.PassthroughBody)
|
||||
|
||||
// 处理可选字段
|
||||
if len(rule.ErrorCodes) > 0 {
|
||||
builder.SetErrorCodes(rule.ErrorCodes)
|
||||
} else {
|
||||
builder.ClearErrorCodes()
|
||||
}
|
||||
if len(rule.Keywords) > 0 {
|
||||
builder.SetKeywords(rule.Keywords)
|
||||
} else {
|
||||
builder.ClearKeywords()
|
||||
}
|
||||
if len(rule.Platforms) > 0 {
|
||||
builder.SetPlatforms(rule.Platforms)
|
||||
} else {
|
||||
builder.ClearPlatforms()
|
||||
}
|
||||
if rule.ResponseCode != nil {
|
||||
builder.SetResponseCode(*rule.ResponseCode)
|
||||
} else {
|
||||
builder.ClearResponseCode()
|
||||
}
|
||||
if rule.CustomMessage != nil {
|
||||
builder.SetCustomMessage(*rule.CustomMessage)
|
||||
} else {
|
||||
builder.ClearCustomMessage()
|
||||
}
|
||||
if rule.Description != nil {
|
||||
builder.SetDescription(*rule.Description)
|
||||
} else {
|
||||
builder.ClearDescription()
|
||||
}
|
||||
|
||||
updated, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.toModel(updated), nil
|
||||
}
|
||||
|
||||
// Delete 删除规则
|
||||
func (r *errorPassthroughRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.client.ErrorPassthroughRule.DeleteOneID(id).Exec(ctx)
|
||||
}
|
||||
|
||||
// toModel 将 Ent 实体转换为服务模型
|
||||
func (r *errorPassthroughRepository) toModel(e *ent.ErrorPassthroughRule) *model.ErrorPassthroughRule {
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
ID: int64(e.ID),
|
||||
Name: e.Name,
|
||||
Enabled: e.Enabled,
|
||||
Priority: e.Priority,
|
||||
ErrorCodes: e.ErrorCodes,
|
||||
Keywords: e.Keywords,
|
||||
MatchMode: e.MatchMode,
|
||||
Platforms: e.Platforms,
|
||||
PassthroughCode: e.PassthroughCode,
|
||||
PassthroughBody: e.PassthroughBody,
|
||||
CreatedAt: e.CreatedAt,
|
||||
UpdatedAt: e.UpdatedAt,
|
||||
}
|
||||
|
||||
if e.ResponseCode != nil {
|
||||
rule.ResponseCode = e.ResponseCode
|
||||
}
|
||||
if e.CustomMessage != nil {
|
||||
rule.CustomMessage = e.CustomMessage
|
||||
}
|
||||
if e.Description != nil {
|
||||
rule.Description = e.Description
|
||||
}
|
||||
|
||||
// 确保切片不为 nil
|
||||
if rule.ErrorCodes == nil {
|
||||
rule.ErrorCodes = []int{}
|
||||
}
|
||||
if rule.Keywords == nil {
|
||||
rule.Keywords = []string{}
|
||||
}
|
||||
if rule.Platforms == nil {
|
||||
rule.Platforms = []string{}
|
||||
}
|
||||
|
||||
return rule
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
@@ -38,9 +39,20 @@ func (c *geminiCliCodeAssistClient) LoadCodeAssist(ctx context.Context, accessTo
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
if !resp.IsSuccessState() {
|
||||
body := geminicli.SanitizeBodyForLogs(resp.String())
|
||||
fmt.Printf("[CodeAssist] LoadCodeAssist failed: status %d, body: %s\n", resp.StatusCode, body)
|
||||
return nil, fmt.Errorf("loadCodeAssist failed: status %d, body: %s", resp.StatusCode, body)
|
||||
body := resp.String()
|
||||
sanitizedBody := geminicli.SanitizeBodyForLogs(body)
|
||||
fmt.Printf("[CodeAssist] LoadCodeAssist failed: status %d, body: %s\n", resp.StatusCode, sanitizedBody)
|
||||
|
||||
// Check if this is a SERVICE_DISABLED error and extract activation URL
|
||||
if googleapi.IsServiceDisabledError(body) {
|
||||
activationURL := googleapi.ExtractActivationURL(body)
|
||||
if activationURL != "" {
|
||||
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it by visiting: %s\n\nAfter enabling the API, wait a few minutes for the changes to propagate, then try again", activationURL)
|
||||
}
|
||||
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it in the Google Cloud Console at: https://console.cloud.google.com/apis/library/cloudaicompanion.googleapis.com")
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("loadCodeAssist failed: status %d, body: %s", resp.StatusCode, sanitizedBody)
|
||||
}
|
||||
fmt.Printf("[CodeAssist] LoadCodeAssist success: status %d, response: %+v\n", resp.StatusCode, out)
|
||||
return &out, nil
|
||||
@@ -67,9 +79,20 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
if !resp.IsSuccessState() {
|
||||
body := geminicli.SanitizeBodyForLogs(resp.String())
|
||||
fmt.Printf("[CodeAssist] OnboardUser failed: status %d, body: %s\n", resp.StatusCode, body)
|
||||
return nil, fmt.Errorf("onboardUser failed: status %d, body: %s", resp.StatusCode, body)
|
||||
body := resp.String()
|
||||
sanitizedBody := geminicli.SanitizeBodyForLogs(body)
|
||||
fmt.Printf("[CodeAssist] OnboardUser failed: status %d, body: %s\n", resp.StatusCode, sanitizedBody)
|
||||
|
||||
// Check if this is a SERVICE_DISABLED error and extract activation URL
|
||||
if googleapi.IsServiceDisabledError(body) {
|
||||
activationURL := googleapi.ExtractActivationURL(body)
|
||||
if activationURL != "" {
|
||||
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it by visiting: %s\n\nAfter enabling the API, wait a few minutes for the changes to propagate, then try again", activationURL)
|
||||
}
|
||||
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it in the Google Cloud Console at: https://console.cloud.google.com/apis/library/cloudaicompanion.googleapis.com")
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("onboardUser failed: status %d, body: %s", resp.StatusCode, sanitizedBody)
|
||||
}
|
||||
fmt.Printf("[CodeAssist] OnboardUser success: status %d, response: %+v\n", resp.StatusCode, out)
|
||||
return &out, nil
|
||||
|
||||
@@ -60,6 +60,25 @@ func (r *proxyRepository) GetByID(ctx context.Context, id int64) (*service.Proxy
|
||||
return proxyEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *proxyRepository) ListByIDs(ctx context.Context, ids []int64) ([]service.Proxy, error) {
|
||||
if len(ids) == 0 {
|
||||
return []service.Proxy{}, nil
|
||||
}
|
||||
|
||||
proxies, err := r.client.Proxy.Query().
|
||||
Where(proxy.IDIn(ids...)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := make([]service.Proxy, 0, len(proxies))
|
||||
for i := range proxies {
|
||||
out = append(out, *proxyEntityToService(proxies[i]))
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *proxyRepository) Update(ctx context.Context, proxyIn *service.Proxy) error {
|
||||
builder := r.client.Proxy.UpdateOneID(proxyIn.ID).
|
||||
SetName(proxyIn.Name).
|
||||
|
||||
158
backend/internal/repository/refresh_token_cache.go
Normal file
158
backend/internal/repository/refresh_token_cache.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
refreshTokenKeyPrefix = "refresh_token:"
|
||||
userRefreshTokensPrefix = "user_refresh_tokens:"
|
||||
tokenFamilyPrefix = "token_family:"
|
||||
)
|
||||
|
||||
// refreshTokenKey generates the Redis key for a refresh token.
|
||||
func refreshTokenKey(tokenHash string) string {
|
||||
return refreshTokenKeyPrefix + tokenHash
|
||||
}
|
||||
|
||||
// userRefreshTokensKey generates the Redis key for user's token set.
|
||||
func userRefreshTokensKey(userID int64) string {
|
||||
return fmt.Sprintf("%s%d", userRefreshTokensPrefix, userID)
|
||||
}
|
||||
|
||||
// tokenFamilyKey generates the Redis key for token family set.
|
||||
func tokenFamilyKey(familyID string) string {
|
||||
return tokenFamilyPrefix + familyID
|
||||
}
|
||||
|
||||
type refreshTokenCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
// NewRefreshTokenCache creates a new RefreshTokenCache implementation.
|
||||
func NewRefreshTokenCache(rdb *redis.Client) service.RefreshTokenCache {
|
||||
return &refreshTokenCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *refreshTokenCache) StoreRefreshToken(ctx context.Context, tokenHash string, data *service.RefreshTokenData, ttl time.Duration) error {
|
||||
key := refreshTokenKey(tokenHash)
|
||||
val, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal refresh token data: %w", err)
|
||||
}
|
||||
return c.rdb.Set(ctx, key, val, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *refreshTokenCache) GetRefreshToken(ctx context.Context, tokenHash string) (*service.RefreshTokenData, error) {
|
||||
key := refreshTokenKey(tokenHash)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
return nil, service.ErrRefreshTokenNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
var data service.RefreshTokenData
|
||||
if err := json.Unmarshal([]byte(val), &data); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal refresh token data: %w", err)
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
func (c *refreshTokenCache) DeleteRefreshToken(ctx context.Context, tokenHash string) error {
|
||||
key := refreshTokenKey(tokenHash)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func (c *refreshTokenCache) DeleteUserRefreshTokens(ctx context.Context, userID int64) error {
|
||||
// Get all token hashes for this user
|
||||
tokenHashes, err := c.GetUserTokenHashes(ctx, userID)
|
||||
if err != nil && err != redis.Nil {
|
||||
return fmt.Errorf("get user token hashes: %w", err)
|
||||
}
|
||||
|
||||
if len(tokenHashes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build keys to delete
|
||||
keys := make([]string, 0, len(tokenHashes)+1)
|
||||
for _, hash := range tokenHashes {
|
||||
keys = append(keys, refreshTokenKey(hash))
|
||||
}
|
||||
keys = append(keys, userRefreshTokensKey(userID))
|
||||
|
||||
// Delete all keys in a pipeline
|
||||
pipe := c.rdb.Pipeline()
|
||||
for _, key := range keys {
|
||||
pipe.Del(ctx, key)
|
||||
}
|
||||
_, err = pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *refreshTokenCache) DeleteTokenFamily(ctx context.Context, familyID string) error {
|
||||
// Get all token hashes in this family
|
||||
tokenHashes, err := c.GetFamilyTokenHashes(ctx, familyID)
|
||||
if err != nil && err != redis.Nil {
|
||||
return fmt.Errorf("get family token hashes: %w", err)
|
||||
}
|
||||
|
||||
if len(tokenHashes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build keys to delete
|
||||
keys := make([]string, 0, len(tokenHashes)+1)
|
||||
for _, hash := range tokenHashes {
|
||||
keys = append(keys, refreshTokenKey(hash))
|
||||
}
|
||||
keys = append(keys, tokenFamilyKey(familyID))
|
||||
|
||||
// Delete all keys in a pipeline
|
||||
pipe := c.rdb.Pipeline()
|
||||
for _, key := range keys {
|
||||
pipe.Del(ctx, key)
|
||||
}
|
||||
_, err = pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *refreshTokenCache) AddToUserTokenSet(ctx context.Context, userID int64, tokenHash string, ttl time.Duration) error {
|
||||
key := userRefreshTokensKey(userID)
|
||||
pipe := c.rdb.Pipeline()
|
||||
pipe.SAdd(ctx, key, tokenHash)
|
||||
pipe.Expire(ctx, key, ttl)
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *refreshTokenCache) AddToFamilyTokenSet(ctx context.Context, familyID string, tokenHash string, ttl time.Duration) error {
|
||||
key := tokenFamilyKey(familyID)
|
||||
pipe := c.rdb.Pipeline()
|
||||
pipe.SAdd(ctx, key, tokenHash)
|
||||
pipe.Expire(ctx, key, ttl)
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *refreshTokenCache) GetUserTokenHashes(ctx context.Context, userID int64) ([]string, error) {
|
||||
key := userRefreshTokensKey(userID)
|
||||
return c.rdb.SMembers(ctx, key).Result()
|
||||
}
|
||||
|
||||
func (c *refreshTokenCache) GetFamilyTokenHashes(ctx context.Context, familyID string) ([]string, error) {
|
||||
key := tokenFamilyKey(familyID)
|
||||
return c.rdb.SMembers(ctx, key).Result()
|
||||
}
|
||||
|
||||
func (c *refreshTokenCache) IsTokenInFamily(ctx context.Context, familyID string, tokenHash string) (bool, error) {
|
||||
key := tokenFamilyKey(familyID)
|
||||
return c.rdb.SIsMember(ctx, key, tokenHash).Result()
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -153,6 +154,21 @@ func NewSessionLimitCache(rdb *redis.Client, defaultIdleTimeoutMinutes int) serv
|
||||
if defaultIdleTimeoutMinutes <= 0 {
|
||||
defaultIdleTimeoutMinutes = 5 // 默认 5 分钟
|
||||
}
|
||||
|
||||
// 预加载 Lua 脚本到 Redis,避免 Pipeline 中出现 NOSCRIPT 错误
|
||||
ctx := context.Background()
|
||||
scripts := []*redis.Script{
|
||||
registerSessionScript,
|
||||
refreshSessionScript,
|
||||
getActiveSessionCountScript,
|
||||
isSessionActiveScript,
|
||||
}
|
||||
for _, script := range scripts {
|
||||
if err := script.Load(ctx, rdb).Err(); err != nil {
|
||||
log.Printf("[SessionLimitCache] Failed to preload Lua script: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &sessionLimitCache{
|
||||
rdb: rdb,
|
||||
defaultIdleTimeout: time.Duration(defaultIdleTimeoutMinutes) * time.Minute,
|
||||
|
||||
@@ -1125,6 +1125,107 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// getPerformanceStatsByAPIKey 获取指定 API Key 的 RPM 和 TPM(近5分钟平均值)
|
||||
func (r *usageLogRepository) getPerformanceStatsByAPIKey(ctx context.Context, apiKeyID int64) (rpm, tpm int64, err error) {
|
||||
fiveMinutesAgo := time.Now().Add(-5 * time.Minute)
|
||||
query := `
|
||||
SELECT
|
||||
COUNT(*) as request_count,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as token_count
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND api_key_id = $2`
|
||||
args := []any{fiveMinutesAgo, apiKeyID}
|
||||
|
||||
var requestCount int64
|
||||
var tokenCount int64
|
||||
if err := scanSingleRow(ctx, r.sql, query, args, &requestCount, &tokenCount); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
return requestCount / 5, tokenCount / 5, nil
|
||||
}
|
||||
|
||||
// GetAPIKeyDashboardStats 获取指定 API Key 的仪表盘统计(按 api_key_id 过滤)
|
||||
func (r *usageLogRepository) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*UserDashboardStats, error) {
|
||||
stats := &UserDashboardStats{}
|
||||
today := timezone.Today()
|
||||
|
||||
// API Key 维度不需要统计 key 数量,设为 1
|
||||
stats.TotalAPIKeys = 1
|
||||
stats.ActiveAPIKeys = 1
|
||||
|
||||
// 累计 Token 统计
|
||||
totalStatsQuery := `
|
||||
SELECT
|
||||
COUNT(*) as total_requests,
|
||||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
||||
FROM usage_logs
|
||||
WHERE api_key_id = $1
|
||||
`
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
totalStatsQuery,
|
||||
[]any{apiKeyID},
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalInputTokens,
|
||||
&stats.TotalOutputTokens,
|
||||
&stats.TotalCacheCreationTokens,
|
||||
&stats.TotalCacheReadTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&stats.AverageDurationMs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
|
||||
|
||||
// 今日 Token 统计
|
||||
todayStatsQuery := `
|
||||
SELECT
|
||||
COUNT(*) as today_requests,
|
||||
COALESCE(SUM(input_tokens), 0) as today_input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as today_output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) as today_cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as today_cache_read_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as today_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as today_actual_cost
|
||||
FROM usage_logs
|
||||
WHERE api_key_id = $1 AND created_at >= $2
|
||||
`
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
todayStatsQuery,
|
||||
[]any{apiKeyID, today},
|
||||
&stats.TodayRequests,
|
||||
&stats.TodayInputTokens,
|
||||
&stats.TodayOutputTokens,
|
||||
&stats.TodayCacheCreationTokens,
|
||||
&stats.TodayCacheReadTokens,
|
||||
&stats.TodayCost,
|
||||
&stats.TodayActualCost,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
|
||||
|
||||
// 性能指标:RPM 和 TPM(最近5分钟,按 API Key 过滤)
|
||||
rpm, tpm, err := r.getPerformanceStatsByAPIKey(ctx, apiKeyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.Rpm = rpm
|
||||
stats.Tpm = tpm
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// GetUserUsageTrendByUserID 获取指定用户的使用趋势
|
||||
func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) {
|
||||
dateFormat := "YYYY-MM-DD"
|
||||
|
||||
113
backend/internal/repository/user_group_rate_repo.go
Normal file
113
backend/internal/repository/user_group_rate_repo.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type userGroupRateRepository struct {
|
||||
sql sqlExecutor
|
||||
}
|
||||
|
||||
// NewUserGroupRateRepository 创建用户专属分组倍率仓储
|
||||
func NewUserGroupRateRepository(sqlDB *sql.DB) service.UserGroupRateRepository {
|
||||
return &userGroupRateRepository{sql: sqlDB}
|
||||
}
|
||||
|
||||
// GetByUserID 获取用户的所有专属分组倍率
|
||||
func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) {
|
||||
query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1`
|
||||
rows, err := r.sql.QueryContext(ctx, query, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
result := make(map[int64]float64)
|
||||
for rows.Next() {
|
||||
var groupID int64
|
||||
var rate float64
|
||||
if err := rows.Scan(&groupID, &rate); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[groupID] = rate
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetByUserAndGroup 获取用户在特定分组的专属倍率
|
||||
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
|
||||
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
|
||||
var rate float64
|
||||
err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rate)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &rate, nil
|
||||
}
|
||||
|
||||
// SyncUserGroupRates 同步用户的分组专属倍率
|
||||
func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error {
|
||||
if len(rates) == 0 {
|
||||
// 如果传入空 map,删除该用户的所有专属倍率
|
||||
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// 分离需要删除和需要 upsert 的记录
|
||||
var toDelete []int64
|
||||
toUpsert := make(map[int64]float64)
|
||||
for groupID, rate := range rates {
|
||||
if rate == nil {
|
||||
toDelete = append(toDelete, groupID)
|
||||
} else {
|
||||
toUpsert[groupID] = *rate
|
||||
}
|
||||
}
|
||||
|
||||
// 删除指定的记录
|
||||
for _, groupID := range toDelete {
|
||||
_, err := r.sql.ExecContext(ctx,
|
||||
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`,
|
||||
userID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Upsert 记录
|
||||
now := time.Now()
|
||||
for groupID, rate := range toUpsert {
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $4)
|
||||
ON CONFLICT (user_id, group_id) DO UPDATE SET rate_multiplier = $3, updated_at = $4
|
||||
`, userID, groupID, rate, now)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteByGroupID 删除指定分组的所有用户专属倍率
|
||||
func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error {
|
||||
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteByUserID 删除指定用户的所有专属倍率
|
||||
func (r *userGroupRateRepository) DeleteByUserID(ctx context.Context, userID int64) error {
|
||||
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
|
||||
return err
|
||||
}
|
||||
@@ -66,6 +66,8 @@ var ProviderSet = wire.NewSet(
|
||||
NewUserSubscriptionRepository,
|
||||
NewUserAttributeDefinitionRepository,
|
||||
NewUserAttributeValueRepository,
|
||||
NewUserGroupRateRepository,
|
||||
NewErrorPassthroughRepository,
|
||||
|
||||
// Cache implementations
|
||||
NewGatewayCache,
|
||||
@@ -85,6 +87,8 @@ var ProviderSet = wire.NewSet(
|
||||
NewSchedulerOutboxRepository,
|
||||
NewProxyLatencyCache,
|
||||
NewTotpCache,
|
||||
NewRefreshTokenCache,
|
||||
NewErrorPassthroughCache,
|
||||
|
||||
// Encryptors
|
||||
NewAESEncryptor,
|
||||
|
||||
@@ -593,7 +593,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
}
|
||||
|
||||
userService := service.NewUserService(userRepo, nil)
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg)
|
||||
|
||||
usageRepo := newStubUsageLogRepo()
|
||||
usageService := service.NewUsageService(usageRepo, userRepo, nil, nil)
|
||||
@@ -607,7 +607,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
settingRepo := newStubSettingRepo()
|
||||
settingService := service.NewSettingService(settingRepo, cfg)
|
||||
|
||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil)
|
||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil)
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
@@ -1059,6 +1059,10 @@ func (stubProxyRepo) GetByID(ctx context.Context, id int64) (*service.Proxy, err
|
||||
return nil, service.ErrProxyNotFound
|
||||
}
|
||||
|
||||
func (stubProxyRepo) ListByIDs(ctx context.Context, ids []int64) ([]service.Proxy, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (stubProxyRepo) Update(ctx context.Context, proxy *service.Proxy) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
@@ -1610,6 +1614,10 @@ func (r *stubUsageLogRepo) GetUserDashboardStats(ctx context.Context, userID int
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*usagestats.UserDashboardStats, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -14,6 +14,8 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/wire"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/h2c"
|
||||
)
|
||||
|
||||
// ProviderSet 提供服务器层的依赖
|
||||
@@ -56,9 +58,39 @@ func ProvideRouter(
|
||||
|
||||
// ProvideHTTPServer 提供 HTTP 服务器
|
||||
func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server {
|
||||
httpHandler := http.Handler(router)
|
||||
|
||||
globalMaxSize := cfg.Server.MaxRequestBodySize
|
||||
if globalMaxSize <= 0 {
|
||||
globalMaxSize = cfg.Gateway.MaxBodySize
|
||||
}
|
||||
if globalMaxSize > 0 {
|
||||
httpHandler = http.MaxBytesHandler(httpHandler, globalMaxSize)
|
||||
log.Printf("Global max request body size: %d bytes (%.2f MB)", globalMaxSize, float64(globalMaxSize)/(1<<20))
|
||||
}
|
||||
|
||||
// 根据配置决定是否启用 H2C
|
||||
if cfg.Server.H2C.Enabled {
|
||||
h2cConfig := cfg.Server.H2C
|
||||
httpHandler = h2c.NewHandler(router, &http2.Server{
|
||||
MaxConcurrentStreams: h2cConfig.MaxConcurrentStreams,
|
||||
IdleTimeout: time.Duration(h2cConfig.IdleTimeout) * time.Second,
|
||||
MaxReadFrameSize: uint32(h2cConfig.MaxReadFrameSize),
|
||||
MaxUploadBufferPerConnection: int32(h2cConfig.MaxUploadBufferPerConnection),
|
||||
MaxUploadBufferPerStream: int32(h2cConfig.MaxUploadBufferPerStream),
|
||||
})
|
||||
log.Printf("HTTP/2 Cleartext (h2c) enabled: max_concurrent_streams=%d, idle_timeout=%ds, max_read_frame_size=%d, max_upload_buffer_per_connection=%d, max_upload_buffer_per_stream=%d",
|
||||
h2cConfig.MaxConcurrentStreams,
|
||||
h2cConfig.IdleTimeout,
|
||||
h2cConfig.MaxReadFrameSize,
|
||||
h2cConfig.MaxUploadBufferPerConnection,
|
||||
h2cConfig.MaxUploadBufferPerStream,
|
||||
)
|
||||
}
|
||||
|
||||
return &http.Server{
|
||||
Addr: cfg.Server.Address(),
|
||||
Handler: router,
|
||||
Handler: httpHandler,
|
||||
// ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击
|
||||
ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second,
|
||||
// IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源
|
||||
|
||||
@@ -93,6 +93,7 @@ func newTestAPIKeyService(repo service.APIKeyRepository) *service.APIKeyService
|
||||
nil, // userRepo (unused in GetByKey)
|
||||
nil, // groupRepo
|
||||
nil, // userSubRepo
|
||||
nil, // userGroupRateRepo
|
||||
nil, // cache
|
||||
&config.Config{},
|
||||
)
|
||||
@@ -187,6 +188,7 @@ func TestApiKeyAuthWithSubscriptionGoogleSetsGroupContext(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
&config.Config{RunMode: config.RunModeSimple},
|
||||
)
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
|
||||
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil)
|
||||
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
|
||||
|
||||
@@ -73,7 +73,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
|
||||
t.Run("standard_mode_enforces_quota_check", func(t *testing.T) {
|
||||
cfg := &config.Config{RunMode: config.RunModeStandard}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
|
||||
now := time.Now()
|
||||
sub := &service.UserSubscription{
|
||||
@@ -150,7 +150,7 @@ func TestAPIKeyAuthSetsGroupContext(t *testing.T) {
|
||||
}
|
||||
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
router := gin.New()
|
||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
|
||||
router.GET("/t", func(c *gin.Context) {
|
||||
@@ -208,7 +208,7 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) {
|
||||
}
|
||||
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
router := gin.New()
|
||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
|
||||
|
||||
|
||||
@@ -34,12 +34,16 @@ func Logger() gin.HandlerFunc {
|
||||
// 客户端IP
|
||||
clientIP := c.ClientIP()
|
||||
|
||||
// 日志格式: [时间] 状态码 | 延迟 | IP | 方法 路径
|
||||
log.Printf("[GIN] %v | %3d | %13v | %15s | %-7s %s",
|
||||
// 协议版本
|
||||
protocol := c.Request.Proto
|
||||
|
||||
// 日志格式: [时间] 状态码 | 延迟 | IP | 协议 | 方法 路径
|
||||
log.Printf("[GIN] %v | %3d | %13v | %15s | %-6s | %-7s %s",
|
||||
endTime.Format("2006/01/02 - 15:04:05"),
|
||||
statusCode,
|
||||
latency,
|
||||
clientIP,
|
||||
protocol,
|
||||
method,
|
||||
path,
|
||||
)
|
||||
|
||||
@@ -67,6 +67,9 @@ func RegisterAdminRoutes(
|
||||
|
||||
// 用户属性管理
|
||||
registerUserAttributeRoutes(admin, h)
|
||||
|
||||
// 错误透传规则管理
|
||||
registerErrorPassthroughRoutes(admin, h)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -219,6 +222,8 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
|
||||
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
|
||||
accounts.POST("/batch", h.Admin.Account.BatchCreate)
|
||||
accounts.GET("/data", h.Admin.Account.ExportData)
|
||||
accounts.POST("/data", h.Admin.Account.ImportData)
|
||||
accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials)
|
||||
accounts.POST("/batch-refresh-tier", h.Admin.Account.BatchRefreshTier)
|
||||
accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate)
|
||||
@@ -278,6 +283,8 @@ func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
{
|
||||
proxies.GET("", h.Admin.Proxy.List)
|
||||
proxies.GET("/all", h.Admin.Proxy.GetAll)
|
||||
proxies.GET("/data", h.Admin.Proxy.ExportData)
|
||||
proxies.POST("/data", h.Admin.Proxy.ImportData)
|
||||
proxies.GET("/:id", h.Admin.Proxy.GetByID)
|
||||
proxies.POST("", h.Admin.Proxy.Create)
|
||||
proxies.PUT("/:id", h.Admin.Proxy.Update)
|
||||
@@ -387,3 +394,14 @@ func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
attrs.DELETE("/:id", h.Admin.UserAttribute.DeleteDefinition)
|
||||
}
|
||||
}
|
||||
|
||||
func registerErrorPassthroughRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
rules := admin.Group("/error-passthrough-rules")
|
||||
{
|
||||
rules.GET("", h.Admin.ErrorPassthrough.List)
|
||||
rules.GET("/:id", h.Admin.ErrorPassthrough.GetByID)
|
||||
rules.POST("", h.Admin.ErrorPassthrough.Create)
|
||||
rules.PUT("/:id", h.Admin.ErrorPassthrough.Update)
|
||||
rules.DELETE("/:id", h.Admin.ErrorPassthrough.Delete)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,6 +28,12 @@ func RegisterAuthRoutes(
|
||||
auth.POST("/login", h.Auth.Login)
|
||||
auth.POST("/login/2fa", h.Auth.Login2FA)
|
||||
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
|
||||
// Token刷新接口添加速率限制:每分钟最多 30 次(Redis 故障时 fail-close)
|
||||
auth.POST("/refresh", rateLimiter.LimitWithOptions("refresh-token", 30, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}), h.Auth.RefreshToken)
|
||||
// 登出接口(公开,允许未认证用户调用以撤销Refresh Token)
|
||||
auth.POST("/logout", h.Auth.Logout)
|
||||
// 优惠码验证接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
|
||||
auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
@@ -59,5 +65,7 @@ func RegisterAuthRoutes(
|
||||
authenticated.Use(gin.HandlerFunc(jwtAuth))
|
||||
{
|
||||
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
|
||||
// 撤销所有会话(需要认证)
|
||||
authenticated.POST("/auth/revoke-all-sessions", h.Auth.RevokeAllSessions)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,6 +49,7 @@ func RegisterUserRoutes(
|
||||
groups := authenticated.Group("/groups")
|
||||
{
|
||||
groups.GET("/available", h.APIKey.GetAvailableGroups)
|
||||
groups.GET("/rates", h.APIKey.GetUserGroupRates)
|
||||
}
|
||||
|
||||
// 使用记录
|
||||
|
||||
@@ -41,6 +41,7 @@ type UsageLogRepository interface {
|
||||
|
||||
// User dashboard stats
|
||||
GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error)
|
||||
GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*usagestats.UserDashboardStats, error)
|
||||
GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error)
|
||||
GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error)
|
||||
|
||||
|
||||
@@ -56,6 +56,7 @@ type AdminService interface {
|
||||
GetAllProxies(ctx context.Context) ([]Proxy, error)
|
||||
GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error)
|
||||
GetProxy(ctx context.Context, id int64) (*Proxy, error)
|
||||
GetProxiesByIDs(ctx context.Context, ids []int64) ([]Proxy, error)
|
||||
CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error)
|
||||
UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error)
|
||||
DeleteProxy(ctx context.Context, id int64) error
|
||||
@@ -93,6 +94,9 @@ type UpdateUserInput struct {
|
||||
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
||||
Status string
|
||||
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
|
||||
// GroupRates 用户专属分组倍率配置
|
||||
// map[groupID]*rate,nil 表示删除该分组的专属倍率
|
||||
GroupRates map[int64]*float64
|
||||
}
|
||||
|
||||
type CreateGroupInput struct {
|
||||
@@ -166,6 +170,8 @@ type CreateAccountInput struct {
|
||||
GroupIDs []int64
|
||||
ExpiresAt *int64
|
||||
AutoPauseOnExpired *bool
|
||||
// SkipDefaultGroupBind prevents auto-binding to platform default group when GroupIDs is empty.
|
||||
SkipDefaultGroupBind bool
|
||||
// SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
|
||||
// This should only be set when the caller has explicitly confirmed the risk.
|
||||
SkipMixedChannelCheck bool
|
||||
@@ -293,6 +299,7 @@ type adminServiceImpl struct {
|
||||
proxyRepo ProxyRepository
|
||||
apiKeyRepo APIKeyRepository
|
||||
redeemCodeRepo RedeemCodeRepository
|
||||
userGroupRateRepo UserGroupRateRepository
|
||||
billingCacheService *BillingCacheService
|
||||
proxyProber ProxyExitInfoProber
|
||||
proxyLatencyCache ProxyLatencyCache
|
||||
@@ -307,6 +314,7 @@ func NewAdminService(
|
||||
proxyRepo ProxyRepository,
|
||||
apiKeyRepo APIKeyRepository,
|
||||
redeemCodeRepo RedeemCodeRepository,
|
||||
userGroupRateRepo UserGroupRateRepository,
|
||||
billingCacheService *BillingCacheService,
|
||||
proxyProber ProxyExitInfoProber,
|
||||
proxyLatencyCache ProxyLatencyCache,
|
||||
@@ -319,6 +327,7 @@ func NewAdminService(
|
||||
proxyRepo: proxyRepo,
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
redeemCodeRepo: redeemCodeRepo,
|
||||
userGroupRateRepo: userGroupRateRepo,
|
||||
billingCacheService: billingCacheService,
|
||||
proxyProber: proxyProber,
|
||||
proxyLatencyCache: proxyLatencyCache,
|
||||
@@ -333,11 +342,35 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
// 批量加载用户专属分组倍率
|
||||
if s.userGroupRateRepo != nil && len(users) > 0 {
|
||||
for i := range users {
|
||||
rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID)
|
||||
if err != nil {
|
||||
log.Printf("failed to load user group rates: user_id=%d err=%v", users[i].ID, err)
|
||||
continue
|
||||
}
|
||||
users[i].GroupRates = rates
|
||||
}
|
||||
}
|
||||
return users, result.Total, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) {
|
||||
return s.userRepo.GetByID(ctx, id)
|
||||
user, err := s.userRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 加载用户专属分组倍率
|
||||
if s.userGroupRateRepo != nil {
|
||||
rates, err := s.userGroupRateRepo.GetByUserID(ctx, id)
|
||||
if err != nil {
|
||||
log.Printf("failed to load user group rates: user_id=%d err=%v", id, err)
|
||||
} else {
|
||||
user.GroupRates = rates
|
||||
}
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) {
|
||||
@@ -406,6 +439,14 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 同步用户专属分组倍率
|
||||
if input.GroupRates != nil && s.userGroupRateRepo != nil {
|
||||
if err := s.userGroupRateRepo.SyncUserGroupRates(ctx, user.ID, input.GroupRates); err != nil {
|
||||
log.Printf("failed to sync user group rates: user_id=%d err=%v", user.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
if s.authCacheInvalidator != nil {
|
||||
if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID)
|
||||
@@ -941,6 +982,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 注意:user_group_rate_multipliers 表通过外键 ON DELETE CASCADE 自动清理
|
||||
|
||||
// 事务成功后,异步失效受影响用户的订阅缓存
|
||||
if len(affectedUserIDs) > 0 && s.billingCacheService != nil {
|
||||
@@ -1004,7 +1046,7 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
|
||||
// 绑定分组
|
||||
groupIDs := input.GroupIDs
|
||||
// 如果没有指定分组,自动绑定对应平台的默认分组
|
||||
if len(groupIDs) == 0 {
|
||||
if len(groupIDs) == 0 && !input.SkipDefaultGroupBind {
|
||||
defaultGroupName := input.Platform + "-default"
|
||||
groups, err := s.groupRepo.ListActiveByPlatform(ctx, input.Platform)
|
||||
if err == nil {
|
||||
@@ -1344,6 +1386,10 @@ func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*Proxy, erro
|
||||
return s.proxyRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetProxiesByIDs(ctx context.Context, ids []int64) ([]Proxy, error) {
|
||||
return s.proxyRepo.ListByIDs(ctx, ids)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error) {
|
||||
proxy := &Proxy{
|
||||
Name: input.Name,
|
||||
|
||||
@@ -187,6 +187,10 @@ func (s *proxyRepoStub) GetByID(ctx context.Context, id int64) (*Proxy, error) {
|
||||
panic("unexpected GetByID call")
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) ListByIDs(ctx context.Context, ids []int64) ([]Proxy, error) {
|
||||
panic("unexpected ListByIDs call")
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) Update(ctx context.Context, proxy *Proxy) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
|
||||
@@ -1106,7 +1106,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
|
||||
return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody)
|
||||
@@ -1779,6 +1779,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
// 处理错误响应
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
// 尽早关闭原始响应体,释放连接;后续逻辑仍可能需要读取 body,因此用内存副本重新包装。
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
@@ -1849,10 +1850,8 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: unwrappedForOps}
|
||||
}
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
}
|
||||
|
||||
@@ -115,15 +115,16 @@ type UpdateAPIKeyRequest struct {
|
||||
|
||||
// APIKeyService API Key服务
|
||||
type APIKeyService struct {
|
||||
apiKeyRepo APIKeyRepository
|
||||
userRepo UserRepository
|
||||
groupRepo GroupRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
cache APIKeyCache
|
||||
cfg *config.Config
|
||||
authCacheL1 *ristretto.Cache
|
||||
authCfg apiKeyAuthCacheConfig
|
||||
authGroup singleflight.Group
|
||||
apiKeyRepo APIKeyRepository
|
||||
userRepo UserRepository
|
||||
groupRepo GroupRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
userGroupRateRepo UserGroupRateRepository
|
||||
cache APIKeyCache
|
||||
cfg *config.Config
|
||||
authCacheL1 *ristretto.Cache
|
||||
authCfg apiKeyAuthCacheConfig
|
||||
authGroup singleflight.Group
|
||||
}
|
||||
|
||||
// NewAPIKeyService 创建API Key服务实例
|
||||
@@ -132,16 +133,18 @@ func NewAPIKeyService(
|
||||
userRepo UserRepository,
|
||||
groupRepo GroupRepository,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
userGroupRateRepo UserGroupRateRepository,
|
||||
cache APIKeyCache,
|
||||
cfg *config.Config,
|
||||
) *APIKeyService {
|
||||
svc := &APIKeyService{
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
userRepo: userRepo,
|
||||
groupRepo: groupRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
userRepo: userRepo,
|
||||
groupRepo: groupRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
userGroupRateRepo: userGroupRateRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
}
|
||||
svc.initAuthCache(cfg)
|
||||
return svc
|
||||
@@ -627,6 +630,19 @@ func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// GetUserGroupRates 获取用户的专属分组倍率配置
|
||||
// 返回 map[groupID]rateMultiplier
|
||||
func (s *APIKeyService) GetUserGroupRates(ctx context.Context, userID int64) (map[int64]float64, error) {
|
||||
if s.userGroupRateRepo == nil {
|
||||
return nil, nil
|
||||
}
|
||||
rates, err := s.userGroupRateRepo.GetByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user group rates: %w", err)
|
||||
}
|
||||
return rates, nil
|
||||
}
|
||||
|
||||
// CheckAPIKeyQuotaAndExpiry checks if the API key is valid for use (not expired, quota not exhausted)
|
||||
// Returns nil if valid, error if invalid
|
||||
func (s *APIKeyService) CheckAPIKeyQuotaAndExpiry(apiKey *APIKey) error {
|
||||
|
||||
@@ -167,7 +167,7 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
|
||||
NegativeTTLSeconds: 30,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
|
||||
groupID := int64(9)
|
||||
cacheEntry := &APIKeyAuthCacheEntry{
|
||||
@@ -223,7 +223,7 @@ func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) {
|
||||
NegativeTTLSeconds: 30,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||
return &APIKeyAuthCacheEntry{NotFound: true}, nil
|
||||
}
|
||||
@@ -256,7 +256,7 @@ func TestAPIKeyService_GetByKey_CacheMissStoresL2(t *testing.T) {
|
||||
NegativeTTLSeconds: 30,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||
return nil, redis.Nil
|
||||
}
|
||||
@@ -293,7 +293,7 @@ func TestAPIKeyService_GetByKey_UsesL1Cache(t *testing.T) {
|
||||
L1TTLSeconds: 60,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
require.NotNil(t, svc.authCacheL1)
|
||||
|
||||
_, err := svc.GetByKey(context.Background(), "k-l1")
|
||||
@@ -320,7 +320,7 @@ func TestAPIKeyService_InvalidateAuthCacheByUserID(t *testing.T) {
|
||||
NegativeTTLSeconds: 30,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
|
||||
svc.InvalidateAuthCacheByUserID(context.Background(), 7)
|
||||
require.Len(t, cache.deleteAuthKeys, 2)
|
||||
@@ -338,7 +338,7 @@ func TestAPIKeyService_InvalidateAuthCacheByGroupID(t *testing.T) {
|
||||
L2TTLSeconds: 60,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
|
||||
svc.InvalidateAuthCacheByGroupID(context.Background(), 9)
|
||||
require.Len(t, cache.deleteAuthKeys, 2)
|
||||
@@ -356,7 +356,7 @@ func TestAPIKeyService_InvalidateAuthCacheByKey(t *testing.T) {
|
||||
L2TTLSeconds: 60,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
|
||||
svc.InvalidateAuthCacheByKey(context.Background(), "k1")
|
||||
require.Len(t, cache.deleteAuthKeys, 1)
|
||||
@@ -375,7 +375,7 @@ func TestAPIKeyService_GetByKey_CachesNegativeOnRepoMiss(t *testing.T) {
|
||||
NegativeTTLSeconds: 30,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||
return nil, redis.Nil
|
||||
}
|
||||
@@ -411,7 +411,7 @@ func TestAPIKeyService_GetByKey_SingleflightCollapses(t *testing.T) {
|
||||
Singleflight: true,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
|
||||
start := make(chan struct{})
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -25,8 +26,12 @@ var (
|
||||
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
|
||||
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
|
||||
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
|
||||
ErrAccessTokenExpired = infraerrors.Unauthorized("ACCESS_TOKEN_EXPIRED", "access token has expired")
|
||||
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
|
||||
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
|
||||
ErrRefreshTokenInvalid = infraerrors.Unauthorized("REFRESH_TOKEN_INVALID", "invalid refresh token")
|
||||
ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired")
|
||||
ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused")
|
||||
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
|
||||
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
||||
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
|
||||
@@ -37,6 +42,9 @@ var (
|
||||
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
|
||||
const maxTokenLength = 8192
|
||||
|
||||
// refreshTokenPrefix is the prefix for refresh tokens to distinguish them from access tokens.
|
||||
const refreshTokenPrefix = "rt_"
|
||||
|
||||
// JWTClaims JWT载荷数据
|
||||
type JWTClaims struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
@@ -50,6 +58,7 @@ type JWTClaims struct {
|
||||
type AuthService struct {
|
||||
userRepo UserRepository
|
||||
redeemRepo RedeemCodeRepository
|
||||
refreshTokenCache RefreshTokenCache
|
||||
cfg *config.Config
|
||||
settingService *SettingService
|
||||
emailService *EmailService
|
||||
@@ -62,6 +71,7 @@ type AuthService struct {
|
||||
func NewAuthService(
|
||||
userRepo UserRepository,
|
||||
redeemRepo RedeemCodeRepository,
|
||||
refreshTokenCache RefreshTokenCache,
|
||||
cfg *config.Config,
|
||||
settingService *SettingService,
|
||||
emailService *EmailService,
|
||||
@@ -72,6 +82,7 @@ func NewAuthService(
|
||||
return &AuthService{
|
||||
userRepo: userRepo,
|
||||
redeemRepo: redeemRepo,
|
||||
refreshTokenCache: refreshTokenCache,
|
||||
cfg: cfg,
|
||||
settingService: settingService,
|
||||
emailService: emailService,
|
||||
@@ -481,6 +492,100 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
||||
return token, user, nil
|
||||
}
|
||||
|
||||
// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair
|
||||
// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token
|
||||
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username string) (*TokenPair, *User, error) {
|
||||
// 检查 refreshTokenCache 是否可用
|
||||
if s.refreshTokenCache == nil {
|
||||
return nil, nil, errors.New("refresh token cache not configured")
|
||||
}
|
||||
|
||||
email = strings.TrimSpace(email)
|
||||
if email == "" || len(email) > 255 {
|
||||
return nil, nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
|
||||
}
|
||||
if _, err := mail.ParseAddress(email); err != nil {
|
||||
return nil, nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
|
||||
}
|
||||
|
||||
username = strings.TrimSpace(username)
|
||||
if len([]rune(username)) > 100 {
|
||||
username = string([]rune(username)[:100])
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
// OAuth 首次登录视为注册
|
||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
return nil, nil, ErrRegDisabled
|
||||
}
|
||||
|
||||
randomPassword, err := randomHexString(32)
|
||||
if err != nil {
|
||||
log.Printf("[Auth] Failed to generate random password for oauth signup: %v", err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
hashedPassword, err := s.HashPassword(randomPassword)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
|
||||
defaultBalance := s.cfg.Default.UserBalance
|
||||
defaultConcurrency := s.cfg.Default.UserConcurrency
|
||||
if s.settingService != nil {
|
||||
defaultBalance = s.settingService.GetDefaultBalance(ctx)
|
||||
defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
|
||||
}
|
||||
|
||||
newUser := &User{
|
||||
Email: email,
|
||||
Username: username,
|
||||
PasswordHash: hashedPassword,
|
||||
Role: RoleUser,
|
||||
Balance: defaultBalance,
|
||||
Concurrency: defaultConcurrency,
|
||||
Status: StatusActive,
|
||||
}
|
||||
|
||||
if err := s.userRepo.Create(ctx, newUser); err != nil {
|
||||
if errors.Is(err, ErrEmailExists) {
|
||||
user, err = s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
log.Printf("[Auth] Database error getting user after conflict: %v", err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
} else {
|
||||
log.Printf("[Auth] Database error creating oauth user: %v", err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
} else {
|
||||
user = newUser
|
||||
}
|
||||
} else {
|
||||
log.Printf("[Auth] Database error during oauth login: %v", err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
}
|
||||
|
||||
if !user.IsActive() {
|
||||
return nil, nil, ErrUserNotActive
|
||||
}
|
||||
|
||||
if user.Username == "" && username != "" {
|
||||
user.Username = username
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
log.Printf("[Auth] Failed to update username after oauth login: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("generate token pair: %w", err)
|
||||
}
|
||||
return tokenPair, user, nil
|
||||
}
|
||||
|
||||
// ValidateToken 验证JWT token并返回用户声明
|
||||
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
|
||||
// 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。
|
||||
@@ -539,10 +644,17 @@ func isReservedEmail(email string) bool {
|
||||
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain)
|
||||
}
|
||||
|
||||
// GenerateToken 生成JWT token
|
||||
// GenerateToken 生成JWT access token
|
||||
// 使用新的access_token_expire_minutes配置项(如果配置了),否则回退到expire_hour
|
||||
func (s *AuthService) GenerateToken(user *User) (string, error) {
|
||||
now := time.Now()
|
||||
expiresAt := now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour)
|
||||
var expiresAt time.Time
|
||||
if s.cfg.JWT.AccessTokenExpireMinutes > 0 {
|
||||
expiresAt = now.Add(time.Duration(s.cfg.JWT.AccessTokenExpireMinutes) * time.Minute)
|
||||
} else {
|
||||
// 向后兼容:使用旧的expire_hour配置
|
||||
expiresAt = now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour)
|
||||
}
|
||||
|
||||
claims := &JWTClaims{
|
||||
UserID: user.ID,
|
||||
@@ -565,6 +677,15 @@ func (s *AuthService) GenerateToken(user *User) (string, error) {
|
||||
return tokenString, nil
|
||||
}
|
||||
|
||||
// GetAccessTokenExpiresIn 返回Access Token的有效期(秒)
|
||||
// 用于前端设置刷新定时器
|
||||
func (s *AuthService) GetAccessTokenExpiresIn() int {
|
||||
if s.cfg.JWT.AccessTokenExpireMinutes > 0 {
|
||||
return s.cfg.JWT.AccessTokenExpireMinutes * 60
|
||||
}
|
||||
return s.cfg.JWT.ExpireHour * 3600
|
||||
}
|
||||
|
||||
// HashPassword 使用bcrypt加密密码
|
||||
func (s *AuthService) HashPassword(password string) (string, error) {
|
||||
hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
@@ -755,6 +876,198 @@ func (s *AuthService) ResetPassword(ctx context.Context, email, token, newPasswo
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
|
||||
// Also revoke all refresh tokens for this user
|
||||
if err := s.RevokeAllUserSessions(ctx, user.ID); err != nil {
|
||||
log.Printf("[Auth] Failed to revoke refresh tokens for user %d: %v", user.ID, err)
|
||||
// Don't return error - password was already changed successfully
|
||||
}
|
||||
|
||||
log.Printf("[Auth] Password reset successful for user: %s", email)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ==================== Refresh Token Methods ====================
|
||||
|
||||
// TokenPair 包含Access Token和Refresh Token
|
||||
type TokenPair struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"` // Access Token有效期(秒)
|
||||
}
|
||||
|
||||
// GenerateTokenPair 生成Access Token和Refresh Token对
|
||||
// familyID: 可选的Token家族ID,用于Token轮转时保持家族关系
|
||||
func (s *AuthService) GenerateTokenPair(ctx context.Context, user *User, familyID string) (*TokenPair, error) {
|
||||
// 检查 refreshTokenCache 是否可用
|
||||
if s.refreshTokenCache == nil {
|
||||
return nil, errors.New("refresh token cache not configured")
|
||||
}
|
||||
|
||||
// 生成Access Token
|
||||
accessToken, err := s.GenerateToken(user)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate access token: %w", err)
|
||||
}
|
||||
|
||||
// 生成Refresh Token
|
||||
refreshToken, err := s.generateRefreshToken(ctx, user, familyID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate refresh token: %w", err)
|
||||
}
|
||||
|
||||
return &TokenPair{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
ExpiresIn: s.GetAccessTokenExpiresIn(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// generateRefreshToken 生成并存储Refresh Token
|
||||
func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, familyID string) (string, error) {
|
||||
// 生成随机Token
|
||||
tokenBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(tokenBytes); err != nil {
|
||||
return "", fmt.Errorf("generate random bytes: %w", err)
|
||||
}
|
||||
rawToken := refreshTokenPrefix + hex.EncodeToString(tokenBytes)
|
||||
|
||||
// 计算Token哈希(存储哈希而非原始Token)
|
||||
tokenHash := hashToken(rawToken)
|
||||
|
||||
// 如果没有提供familyID,生成新的
|
||||
if familyID == "" {
|
||||
familyBytes := make([]byte, 16)
|
||||
if _, err := rand.Read(familyBytes); err != nil {
|
||||
return "", fmt.Errorf("generate family id: %w", err)
|
||||
}
|
||||
familyID = hex.EncodeToString(familyBytes)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
ttl := time.Duration(s.cfg.JWT.RefreshTokenExpireDays) * 24 * time.Hour
|
||||
|
||||
data := &RefreshTokenData{
|
||||
UserID: user.ID,
|
||||
TokenVersion: user.TokenVersion,
|
||||
FamilyID: familyID,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(ttl),
|
||||
}
|
||||
|
||||
// 存储Token数据
|
||||
if err := s.refreshTokenCache.StoreRefreshToken(ctx, tokenHash, data, ttl); err != nil {
|
||||
return "", fmt.Errorf("store refresh token: %w", err)
|
||||
}
|
||||
|
||||
// 添加到用户Token集合
|
||||
if err := s.refreshTokenCache.AddToUserTokenSet(ctx, user.ID, tokenHash, ttl); err != nil {
|
||||
log.Printf("[Auth] Failed to add token to user set: %v", err)
|
||||
// 不影响主流程
|
||||
}
|
||||
|
||||
// 添加到家族Token集合
|
||||
if err := s.refreshTokenCache.AddToFamilyTokenSet(ctx, familyID, tokenHash, ttl); err != nil {
|
||||
log.Printf("[Auth] Failed to add token to family set: %v", err)
|
||||
// 不影响主流程
|
||||
}
|
||||
|
||||
return rawToken, nil
|
||||
}
|
||||
|
||||
// RefreshTokenPair 使用Refresh Token刷新Token对
|
||||
// 实现Token轮转:每次刷新都会生成新的Refresh Token,旧Token立即失效
|
||||
func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) (*TokenPair, error) {
|
||||
// 检查 refreshTokenCache 是否可用
|
||||
if s.refreshTokenCache == nil {
|
||||
return nil, ErrRefreshTokenInvalid
|
||||
}
|
||||
|
||||
// 验证Token格式
|
||||
if !strings.HasPrefix(refreshToken, refreshTokenPrefix) {
|
||||
return nil, ErrRefreshTokenInvalid
|
||||
}
|
||||
|
||||
tokenHash := hashToken(refreshToken)
|
||||
|
||||
// 获取Token数据
|
||||
data, err := s.refreshTokenCache.GetRefreshToken(ctx, tokenHash)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrRefreshTokenNotFound) {
|
||||
// Token不存在,可能是已被使用(Token轮转)或已过期
|
||||
log.Printf("[Auth] Refresh token not found, possible reuse attack")
|
||||
return nil, ErrRefreshTokenInvalid
|
||||
}
|
||||
log.Printf("[Auth] Error getting refresh token: %v", err)
|
||||
return nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
// 检查Token是否过期
|
||||
if time.Now().After(data.ExpiresAt) {
|
||||
// 删除过期Token
|
||||
_ = s.refreshTokenCache.DeleteRefreshToken(ctx, tokenHash)
|
||||
return nil, ErrRefreshTokenExpired
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
user, err := s.userRepo.GetByID(ctx, data.UserID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
// 用户已删除,撤销整个Token家族
|
||||
_ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID)
|
||||
return nil, ErrRefreshTokenInvalid
|
||||
}
|
||||
log.Printf("[Auth] Database error getting user for token refresh: %v", err)
|
||||
return nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if !user.IsActive() {
|
||||
// 用户被禁用,撤销整个Token家族
|
||||
_ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID)
|
||||
return nil, ErrUserNotActive
|
||||
}
|
||||
|
||||
// 检查TokenVersion(密码更改后所有Token失效)
|
||||
if data.TokenVersion != user.TokenVersion {
|
||||
// TokenVersion不匹配,撤销整个Token家族
|
||||
_ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID)
|
||||
return nil, ErrTokenRevoked
|
||||
}
|
||||
|
||||
// Token轮转:立即使旧Token失效
|
||||
if err := s.refreshTokenCache.DeleteRefreshToken(ctx, tokenHash); err != nil {
|
||||
log.Printf("[Auth] Failed to delete old refresh token: %v", err)
|
||||
// 继续处理,不影响主流程
|
||||
}
|
||||
|
||||
// 生成新的Token对,保持同一个家族ID
|
||||
return s.GenerateTokenPair(ctx, user, data.FamilyID)
|
||||
}
|
||||
|
||||
// RevokeRefreshToken 撤销单个Refresh Token
|
||||
func (s *AuthService) RevokeRefreshToken(ctx context.Context, refreshToken string) error {
|
||||
if s.refreshTokenCache == nil {
|
||||
return nil // No-op if cache not configured
|
||||
}
|
||||
if !strings.HasPrefix(refreshToken, refreshTokenPrefix) {
|
||||
return ErrRefreshTokenInvalid
|
||||
}
|
||||
|
||||
tokenHash := hashToken(refreshToken)
|
||||
return s.refreshTokenCache.DeleteRefreshToken(ctx, tokenHash)
|
||||
}
|
||||
|
||||
// RevokeAllUserSessions 撤销用户的所有会话(所有Refresh Token)
|
||||
// 用于密码更改或用户主动登出所有设备
|
||||
func (s *AuthService) RevokeAllUserSessions(ctx context.Context, userID int64) error {
|
||||
if s.refreshTokenCache == nil {
|
||||
return nil // No-op if cache not configured
|
||||
}
|
||||
return s.refreshTokenCache.DeleteUserRefreshTokens(ctx, userID)
|
||||
}
|
||||
|
||||
// hashToken 计算Token的SHA256哈希
|
||||
func hashToken(token string) string {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
@@ -116,6 +116,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
|
||||
return NewAuthService(
|
||||
repo,
|
||||
nil, // redeemRepo
|
||||
nil, // refreshTokenCache
|
||||
cfg,
|
||||
settingService,
|
||||
emailService,
|
||||
|
||||
300
backend/internal/service/error_passthrough_service.go
Normal file
300
backend/internal/service/error_passthrough_service.go
Normal file
@@ -0,0 +1,300 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
)
|
||||
|
||||
// ErrorPassthroughRepository 定义错误透传规则的数据访问接口
|
||||
type ErrorPassthroughRepository interface {
|
||||
// List 获取所有规则
|
||||
List(ctx context.Context) ([]*model.ErrorPassthroughRule, error)
|
||||
// GetByID 根据 ID 获取规则
|
||||
GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error)
|
||||
// Create 创建规则
|
||||
Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error)
|
||||
// Update 更新规则
|
||||
Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error)
|
||||
// Delete 删除规则
|
||||
Delete(ctx context.Context, id int64) error
|
||||
}
|
||||
|
||||
// ErrorPassthroughCache 定义错误透传规则的缓存接口
|
||||
type ErrorPassthroughCache interface {
|
||||
// Get 从缓存获取规则列表
|
||||
Get(ctx context.Context) ([]*model.ErrorPassthroughRule, bool)
|
||||
// Set 设置缓存
|
||||
Set(ctx context.Context, rules []*model.ErrorPassthroughRule) error
|
||||
// Invalidate 使缓存失效
|
||||
Invalidate(ctx context.Context) error
|
||||
// NotifyUpdate 通知其他实例刷新缓存
|
||||
NotifyUpdate(ctx context.Context) error
|
||||
// SubscribeUpdates 订阅缓存更新通知
|
||||
SubscribeUpdates(ctx context.Context, handler func())
|
||||
}
|
||||
|
||||
// ErrorPassthroughService 错误透传规则服务
|
||||
type ErrorPassthroughService struct {
|
||||
repo ErrorPassthroughRepository
|
||||
cache ErrorPassthroughCache
|
||||
|
||||
// 本地内存缓存,用于快速匹配
|
||||
localCache []*model.ErrorPassthroughRule
|
||||
localCacheMu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewErrorPassthroughService 创建错误透传规则服务
|
||||
func NewErrorPassthroughService(
|
||||
repo ErrorPassthroughRepository,
|
||||
cache ErrorPassthroughCache,
|
||||
) *ErrorPassthroughService {
|
||||
svc := &ErrorPassthroughService{
|
||||
repo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
// 启动时加载规则到本地缓存
|
||||
ctx := context.Background()
|
||||
if err := svc.refreshLocalCache(ctx); err != nil {
|
||||
log.Printf("[ErrorPassthroughService] Failed to load rules on startup: %v", err)
|
||||
}
|
||||
|
||||
// 订阅缓存更新通知
|
||||
if cache != nil {
|
||||
cache.SubscribeUpdates(ctx, func() {
|
||||
if err := svc.refreshLocalCache(context.Background()); err != nil {
|
||||
log.Printf("[ErrorPassthroughService] Failed to refresh cache on notification: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return svc
|
||||
}
|
||||
|
||||
// List 获取所有规则
|
||||
func (s *ErrorPassthroughService) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) {
|
||||
return s.repo.List(ctx)
|
||||
}
|
||||
|
||||
// GetByID 根据 ID 获取规则
|
||||
func (s *ErrorPassthroughService) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) {
|
||||
return s.repo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
// Create 创建规则
|
||||
func (s *ErrorPassthroughService) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
|
||||
if err := rule.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
created, err := s.repo.Create(ctx, rule)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 刷新缓存
|
||||
s.invalidateAndNotify(ctx)
|
||||
|
||||
return created, nil
|
||||
}
|
||||
|
||||
// Update 更新规则
|
||||
func (s *ErrorPassthroughService) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
|
||||
if err := rule.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updated, err := s.repo.Update(ctx, rule)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 刷新缓存
|
||||
s.invalidateAndNotify(ctx)
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
// Delete 删除规则
|
||||
func (s *ErrorPassthroughService) Delete(ctx context.Context, id int64) error {
|
||||
if err := s.repo.Delete(ctx, id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 刷新缓存
|
||||
s.invalidateAndNotify(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MatchRule 匹配透传规则
|
||||
// 返回第一个匹配的规则,如果没有匹配则返回 nil
|
||||
func (s *ErrorPassthroughService) MatchRule(platform string, statusCode int, body []byte) *model.ErrorPassthroughRule {
|
||||
rules := s.getCachedRules()
|
||||
if len(rules) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
bodyStr := strings.ToLower(string(body))
|
||||
|
||||
for _, rule := range rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
if !s.platformMatches(rule, platform) {
|
||||
continue
|
||||
}
|
||||
if s.ruleMatches(rule, statusCode, bodyStr) {
|
||||
return rule
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getCachedRules 获取缓存的规则列表(按优先级排序)
|
||||
func (s *ErrorPassthroughService) getCachedRules() []*model.ErrorPassthroughRule {
|
||||
s.localCacheMu.RLock()
|
||||
rules := s.localCache
|
||||
s.localCacheMu.RUnlock()
|
||||
|
||||
if rules != nil {
|
||||
return rules
|
||||
}
|
||||
|
||||
// 如果本地缓存为空,尝试刷新
|
||||
ctx := context.Background()
|
||||
if err := s.refreshLocalCache(ctx); err != nil {
|
||||
log.Printf("[ErrorPassthroughService] Failed to refresh cache: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
s.localCacheMu.RLock()
|
||||
defer s.localCacheMu.RUnlock()
|
||||
return s.localCache
|
||||
}
|
||||
|
||||
// refreshLocalCache 刷新本地缓存
|
||||
func (s *ErrorPassthroughService) refreshLocalCache(ctx context.Context) error {
|
||||
// 先尝试从 Redis 缓存获取
|
||||
if s.cache != nil {
|
||||
if rules, ok := s.cache.Get(ctx); ok {
|
||||
s.setLocalCache(rules)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// 从数据库加载(repo.List 已按 priority 排序)
|
||||
rules, err := s.repo.List(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新 Redis 缓存
|
||||
if s.cache != nil {
|
||||
if err := s.cache.Set(ctx, rules); err != nil {
|
||||
log.Printf("[ErrorPassthroughService] Failed to set cache: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 更新本地缓存(setLocalCache 内部会确保排序)
|
||||
s.setLocalCache(rules)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setLocalCache 设置本地缓存
|
||||
func (s *ErrorPassthroughService) setLocalCache(rules []*model.ErrorPassthroughRule) {
|
||||
// 按优先级排序
|
||||
sorted := make([]*model.ErrorPassthroughRule, len(rules))
|
||||
copy(sorted, rules)
|
||||
sort.Slice(sorted, func(i, j int) bool {
|
||||
return sorted[i].Priority < sorted[j].Priority
|
||||
})
|
||||
|
||||
s.localCacheMu.Lock()
|
||||
s.localCache = sorted
|
||||
s.localCacheMu.Unlock()
|
||||
}
|
||||
|
||||
// invalidateAndNotify 使缓存失效并通知其他实例
|
||||
func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) {
|
||||
// 刷新本地缓存
|
||||
if err := s.refreshLocalCache(ctx); err != nil {
|
||||
log.Printf("[ErrorPassthroughService] Failed to refresh local cache: %v", err)
|
||||
}
|
||||
|
||||
// 通知其他实例
|
||||
if s.cache != nil {
|
||||
if err := s.cache.NotifyUpdate(ctx); err != nil {
|
||||
log.Printf("[ErrorPassthroughService] Failed to notify cache update: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// platformMatches 检查平台是否匹配
|
||||
func (s *ErrorPassthroughService) platformMatches(rule *model.ErrorPassthroughRule, platform string) bool {
|
||||
// 如果没有配置平台限制,则匹配所有平台
|
||||
if len(rule.Platforms) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
platform = strings.ToLower(platform)
|
||||
for _, p := range rule.Platforms {
|
||||
if strings.ToLower(p) == platform {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ruleMatches 检查规则是否匹配
|
||||
func (s *ErrorPassthroughService) ruleMatches(rule *model.ErrorPassthroughRule, statusCode int, bodyLower string) bool {
|
||||
hasErrorCodes := len(rule.ErrorCodes) > 0
|
||||
hasKeywords := len(rule.Keywords) > 0
|
||||
|
||||
// 如果没有配置任何条件,不匹配
|
||||
if !hasErrorCodes && !hasKeywords {
|
||||
return false
|
||||
}
|
||||
|
||||
codeMatch := !hasErrorCodes || s.containsInt(rule.ErrorCodes, statusCode)
|
||||
keywordMatch := !hasKeywords || s.containsAnyKeyword(bodyLower, rule.Keywords)
|
||||
|
||||
if rule.MatchMode == model.MatchModeAll {
|
||||
// "all" 模式:所有配置的条件都必须满足
|
||||
return codeMatch && keywordMatch
|
||||
}
|
||||
|
||||
// "any" 模式:任一条件满足即可
|
||||
if hasErrorCodes && hasKeywords {
|
||||
return codeMatch || keywordMatch
|
||||
}
|
||||
return codeMatch && keywordMatch
|
||||
}
|
||||
|
||||
// containsInt 检查切片是否包含指定整数
|
||||
func (s *ErrorPassthroughService) containsInt(slice []int, val int) bool {
|
||||
for _, v := range slice {
|
||||
if v == val {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// containsAnyKeyword 检查字符串是否包含任一关键词(不区分大小写)
|
||||
func (s *ErrorPassthroughService) containsAnyKeyword(bodyLower string, keywords []string) bool {
|
||||
for _, kw := range keywords {
|
||||
if strings.Contains(bodyLower, strings.ToLower(kw)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
755
backend/internal/service/error_passthrough_service_test.go
Normal file
755
backend/internal/service/error_passthrough_service_test.go
Normal file
@@ -0,0 +1,755 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// mockErrorPassthroughRepo 用于测试的 mock repository
|
||||
type mockErrorPassthroughRepo struct {
|
||||
rules []*model.ErrorPassthroughRule
|
||||
}
|
||||
|
||||
func (m *mockErrorPassthroughRepo) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) {
|
||||
return m.rules, nil
|
||||
}
|
||||
|
||||
func (m *mockErrorPassthroughRepo) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) {
|
||||
for _, r := range m.rules {
|
||||
if r.ID == id {
|
||||
return r, nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockErrorPassthroughRepo) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
|
||||
rule.ID = int64(len(m.rules) + 1)
|
||||
m.rules = append(m.rules, rule)
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (m *mockErrorPassthroughRepo) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
|
||||
for i, r := range m.rules {
|
||||
if r.ID == rule.ID {
|
||||
m.rules[i] = rule
|
||||
return rule, nil
|
||||
}
|
||||
}
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (m *mockErrorPassthroughRepo) Delete(ctx context.Context, id int64) error {
|
||||
for i, r := range m.rules {
|
||||
if r.ID == id {
|
||||
m.rules = append(m.rules[:i], m.rules[i+1:]...)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// newTestService 创建测试用的服务实例
|
||||
func newTestService(rules []*model.ErrorPassthroughRule) *ErrorPassthroughService {
|
||||
repo := &mockErrorPassthroughRepo{rules: rules}
|
||||
svc := &ErrorPassthroughService{
|
||||
repo: repo,
|
||||
cache: nil, // 不使用缓存
|
||||
}
|
||||
// 直接设置本地缓存,避免调用 refreshLocalCache
|
||||
svc.setLocalCache(rules)
|
||||
return svc
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// 测试 ruleMatches 核心匹配逻辑
|
||||
// =============================================================================
|
||||
|
||||
func TestRuleMatches_NoConditions(t *testing.T) {
|
||||
// 没有配置任何条件时,不应该匹配
|
||||
svc := newTestService(nil)
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
Enabled: true,
|
||||
ErrorCodes: []int{},
|
||||
Keywords: []string{},
|
||||
MatchMode: model.MatchModeAny,
|
||||
}
|
||||
|
||||
assert.False(t, svc.ruleMatches(rule, 422, "some error message"),
|
||||
"没有配置条件时不应该匹配")
|
||||
}
|
||||
|
||||
func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) {
|
||||
svc := newTestService(nil)
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
Enabled: true,
|
||||
ErrorCodes: []int{422, 400},
|
||||
Keywords: []string{},
|
||||
MatchMode: model.MatchModeAny,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
body string
|
||||
expected bool
|
||||
}{
|
||||
{"状态码匹配 422", 422, "any message", true},
|
||||
{"状态码匹配 400", 400, "any message", true},
|
||||
{"状态码不匹配 500", 500, "any message", false},
|
||||
{"状态码不匹配 429", 429, "any message", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) {
|
||||
svc := newTestService(nil)
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
Enabled: true,
|
||||
ErrorCodes: []int{},
|
||||
Keywords: []string{"context limit", "model not supported"},
|
||||
MatchMode: model.MatchModeAny,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
body string
|
||||
expected bool
|
||||
}{
|
||||
{"关键词匹配 context limit", 500, "error: context limit reached", true},
|
||||
{"关键词匹配 model not supported", 400, "the model not supported here", true},
|
||||
{"关键词不匹配", 422, "some other error", false},
|
||||
// 注意:ruleMatches 接收的 body 参数应该是已经转换为小写的
|
||||
// 实际使用时,MatchRule 会先将 body 转换为小写再传给 ruleMatches
|
||||
{"关键词大小写 - 输入已小写", 500, "context limit exceeded", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 模拟 MatchRule 的行为:先转换为小写
|
||||
bodyLower := strings.ToLower(tt.body)
|
||||
result := svc.ruleMatches(rule, tt.statusCode, bodyLower)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleMatches_BothConditions_AnyMode(t *testing.T) {
|
||||
// any 模式:错误码 OR 关键词
|
||||
svc := newTestService(nil)
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
Enabled: true,
|
||||
ErrorCodes: []int{422, 400},
|
||||
Keywords: []string{"context limit"},
|
||||
MatchMode: model.MatchModeAny,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
body string
|
||||
expected bool
|
||||
reason string
|
||||
}{
|
||||
{
|
||||
name: "状态码和关键词都匹配",
|
||||
statusCode: 422,
|
||||
body: "context limit reached",
|
||||
expected: true,
|
||||
reason: "both match",
|
||||
},
|
||||
{
|
||||
name: "只有状态码匹配",
|
||||
statusCode: 422,
|
||||
body: "some other error",
|
||||
expected: true,
|
||||
reason: "code matches, keyword doesn't - OR mode should match",
|
||||
},
|
||||
{
|
||||
name: "只有关键词匹配",
|
||||
statusCode: 500,
|
||||
body: "context limit exceeded",
|
||||
expected: true,
|
||||
reason: "keyword matches, code doesn't - OR mode should match",
|
||||
},
|
||||
{
|
||||
name: "都不匹配",
|
||||
statusCode: 500,
|
||||
body: "some other error",
|
||||
expected: false,
|
||||
reason: "neither matches",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
|
||||
assert.Equal(t, tt.expected, result, tt.reason)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleMatches_BothConditions_AllMode(t *testing.T) {
|
||||
// all 模式:错误码 AND 关键词
|
||||
svc := newTestService(nil)
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
Enabled: true,
|
||||
ErrorCodes: []int{422, 400},
|
||||
Keywords: []string{"context limit"},
|
||||
MatchMode: model.MatchModeAll,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
body string
|
||||
expected bool
|
||||
reason string
|
||||
}{
|
||||
{
|
||||
name: "状态码和关键词都匹配",
|
||||
statusCode: 422,
|
||||
body: "context limit reached",
|
||||
expected: true,
|
||||
reason: "both match - AND mode should match",
|
||||
},
|
||||
{
|
||||
name: "只有状态码匹配",
|
||||
statusCode: 422,
|
||||
body: "some other error",
|
||||
expected: false,
|
||||
reason: "code matches but keyword doesn't - AND mode should NOT match",
|
||||
},
|
||||
{
|
||||
name: "只有关键词匹配",
|
||||
statusCode: 500,
|
||||
body: "context limit exceeded",
|
||||
expected: false,
|
||||
reason: "keyword matches but code doesn't - AND mode should NOT match",
|
||||
},
|
||||
{
|
||||
name: "都不匹配",
|
||||
statusCode: 500,
|
||||
body: "some other error",
|
||||
expected: false,
|
||||
reason: "neither matches",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
|
||||
assert.Equal(t, tt.expected, result, tt.reason)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// 测试 platformMatches 平台匹配逻辑
|
||||
// =============================================================================
|
||||
|
||||
func TestPlatformMatches(t *testing.T) {
|
||||
svc := newTestService(nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rulePlatforms []string
|
||||
requestPlatform string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "空平台列表匹配所有",
|
||||
rulePlatforms: []string{},
|
||||
requestPlatform: "anthropic",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "nil平台列表匹配所有",
|
||||
rulePlatforms: nil,
|
||||
requestPlatform: "openai",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "精确匹配 anthropic",
|
||||
rulePlatforms: []string{"anthropic", "openai"},
|
||||
requestPlatform: "anthropic",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "精确匹配 openai",
|
||||
rulePlatforms: []string{"anthropic", "openai"},
|
||||
requestPlatform: "openai",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "不匹配 gemini",
|
||||
rulePlatforms: []string{"anthropic", "openai"},
|
||||
requestPlatform: "gemini",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "大小写不敏感",
|
||||
rulePlatforms: []string{"Anthropic", "OpenAI"},
|
||||
requestPlatform: "anthropic",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "匹配 antigravity",
|
||||
rulePlatforms: []string{"antigravity"},
|
||||
requestPlatform: "antigravity",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
Platforms: tt.rulePlatforms,
|
||||
}
|
||||
result := svc.platformMatches(rule, tt.requestPlatform)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// 测试 MatchRule 完整匹配流程
|
||||
// =============================================================================
|
||||
|
||||
func TestMatchRule_Priority(t *testing.T) {
|
||||
// 测试规则按优先级排序,优先级小的先匹配
|
||||
rules := []*model.ErrorPassthroughRule{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "Low Priority",
|
||||
Enabled: true,
|
||||
Priority: 10,
|
||||
ErrorCodes: []int{422},
|
||||
MatchMode: model.MatchModeAny,
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Name: "High Priority",
|
||||
Enabled: true,
|
||||
Priority: 1,
|
||||
ErrorCodes: []int{422},
|
||||
MatchMode: model.MatchModeAny,
|
||||
},
|
||||
}
|
||||
|
||||
svc := newTestService(rules)
|
||||
matched := svc.MatchRule("anthropic", 422, []byte("error"))
|
||||
|
||||
require.NotNil(t, matched)
|
||||
assert.Equal(t, int64(2), matched.ID, "应该匹配优先级更高(数值更小)的规则")
|
||||
assert.Equal(t, "High Priority", matched.Name)
|
||||
}
|
||||
|
||||
func TestMatchRule_DisabledRule(t *testing.T) {
|
||||
rules := []*model.ErrorPassthroughRule{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "Disabled Rule",
|
||||
Enabled: false,
|
||||
Priority: 1,
|
||||
ErrorCodes: []int{422},
|
||||
MatchMode: model.MatchModeAny,
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Name: "Enabled Rule",
|
||||
Enabled: true,
|
||||
Priority: 10,
|
||||
ErrorCodes: []int{422},
|
||||
MatchMode: model.MatchModeAny,
|
||||
},
|
||||
}
|
||||
|
||||
svc := newTestService(rules)
|
||||
matched := svc.MatchRule("anthropic", 422, []byte("error"))
|
||||
|
||||
require.NotNil(t, matched)
|
||||
assert.Equal(t, int64(2), matched.ID, "应该跳过禁用的规则")
|
||||
}
|
||||
|
||||
func TestMatchRule_PlatformFilter(t *testing.T) {
|
||||
rules := []*model.ErrorPassthroughRule{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "Anthropic Only",
|
||||
Enabled: true,
|
||||
Priority: 1,
|
||||
ErrorCodes: []int{422},
|
||||
Platforms: []string{"anthropic"},
|
||||
MatchMode: model.MatchModeAny,
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Name: "OpenAI Only",
|
||||
Enabled: true,
|
||||
Priority: 2,
|
||||
ErrorCodes: []int{422},
|
||||
Platforms: []string{"openai"},
|
||||
MatchMode: model.MatchModeAny,
|
||||
},
|
||||
{
|
||||
ID: 3,
|
||||
Name: "All Platforms",
|
||||
Enabled: true,
|
||||
Priority: 3,
|
||||
ErrorCodes: []int{422},
|
||||
Platforms: []string{},
|
||||
MatchMode: model.MatchModeAny,
|
||||
},
|
||||
}
|
||||
|
||||
svc := newTestService(rules)
|
||||
|
||||
t.Run("Anthropic 请求匹配 Anthropic 规则", func(t *testing.T) {
|
||||
matched := svc.MatchRule("anthropic", 422, []byte("error"))
|
||||
require.NotNil(t, matched)
|
||||
assert.Equal(t, int64(1), matched.ID)
|
||||
})
|
||||
|
||||
t.Run("OpenAI 请求匹配 OpenAI 规则", func(t *testing.T) {
|
||||
matched := svc.MatchRule("openai", 422, []byte("error"))
|
||||
require.NotNil(t, matched)
|
||||
assert.Equal(t, int64(2), matched.ID)
|
||||
})
|
||||
|
||||
t.Run("Gemini 请求匹配全平台规则", func(t *testing.T) {
|
||||
matched := svc.MatchRule("gemini", 422, []byte("error"))
|
||||
require.NotNil(t, matched)
|
||||
assert.Equal(t, int64(3), matched.ID)
|
||||
})
|
||||
|
||||
t.Run("Antigravity 请求匹配全平台规则", func(t *testing.T) {
|
||||
matched := svc.MatchRule("antigravity", 422, []byte("error"))
|
||||
require.NotNil(t, matched)
|
||||
assert.Equal(t, int64(3), matched.ID)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMatchRule_NoMatch(t *testing.T) {
|
||||
rules := []*model.ErrorPassthroughRule{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "Rule for 422",
|
||||
Enabled: true,
|
||||
Priority: 1,
|
||||
ErrorCodes: []int{422},
|
||||
MatchMode: model.MatchModeAny,
|
||||
},
|
||||
}
|
||||
|
||||
svc := newTestService(rules)
|
||||
matched := svc.MatchRule("anthropic", 500, []byte("error"))
|
||||
|
||||
assert.Nil(t, matched, "不匹配任何规则时应返回 nil")
|
||||
}
|
||||
|
||||
func TestMatchRule_EmptyRules(t *testing.T) {
|
||||
svc := newTestService([]*model.ErrorPassthroughRule{})
|
||||
matched := svc.MatchRule("anthropic", 422, []byte("error"))
|
||||
|
||||
assert.Nil(t, matched, "没有规则时应返回 nil")
|
||||
}
|
||||
|
||||
func TestMatchRule_CaseInsensitiveKeyword(t *testing.T) {
|
||||
rules := []*model.ErrorPassthroughRule{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "Context Limit",
|
||||
Enabled: true,
|
||||
Priority: 1,
|
||||
Keywords: []string{"Context Limit"},
|
||||
MatchMode: model.MatchModeAny,
|
||||
},
|
||||
}
|
||||
|
||||
svc := newTestService(rules)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
expected bool
|
||||
}{
|
||||
{"完全匹配", "Context Limit reached", true},
|
||||
{"小写匹配", "context limit reached", true},
|
||||
{"大写匹配", "CONTEXT LIMIT REACHED", true},
|
||||
{"混合大小写", "ConTeXt LiMiT error", true},
|
||||
{"不匹配", "some other error", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
matched := svc.MatchRule("anthropic", 500, []byte(tt.body))
|
||||
if tt.expected {
|
||||
assert.NotNil(t, matched)
|
||||
} else {
|
||||
assert.Nil(t, matched)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// 测试真实场景
|
||||
// =============================================================================
|
||||
|
||||
func TestMatchRule_RealWorldScenario_ContextLimitPassthrough(t *testing.T) {
|
||||
// 场景:上游返回 422 + "context limit has been reached",需要透传给客户端
|
||||
rules := []*model.ErrorPassthroughRule{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "Context Limit Passthrough",
|
||||
Enabled: true,
|
||||
Priority: 1,
|
||||
ErrorCodes: []int{422},
|
||||
Keywords: []string{"context limit"},
|
||||
MatchMode: model.MatchModeAll, // 必须同时满足
|
||||
Platforms: []string{"anthropic", "antigravity"},
|
||||
PassthroughCode: true,
|
||||
PassthroughBody: true,
|
||||
},
|
||||
}
|
||||
|
||||
svc := newTestService(rules)
|
||||
|
||||
// 测试 Anthropic 平台
|
||||
t.Run("Anthropic 422 with context limit", func(t *testing.T) {
|
||||
body := []byte(`{"type":"error","error":{"type":"invalid_request","message":"The context limit has been reached"}}`)
|
||||
matched := svc.MatchRule("anthropic", 422, body)
|
||||
require.NotNil(t, matched)
|
||||
assert.True(t, matched.PassthroughCode)
|
||||
assert.True(t, matched.PassthroughBody)
|
||||
})
|
||||
|
||||
// 测试 Antigravity 平台
|
||||
t.Run("Antigravity 422 with context limit", func(t *testing.T) {
|
||||
body := []byte(`{"error":"context limit exceeded"}`)
|
||||
matched := svc.MatchRule("antigravity", 422, body)
|
||||
require.NotNil(t, matched)
|
||||
})
|
||||
|
||||
// 测试 OpenAI 平台(不在规则的平台列表中)
|
||||
t.Run("OpenAI should not match", func(t *testing.T) {
|
||||
body := []byte(`{"error":"context limit exceeded"}`)
|
||||
matched := svc.MatchRule("openai", 422, body)
|
||||
assert.Nil(t, matched, "OpenAI 不在规则的平台列表中")
|
||||
})
|
||||
|
||||
// 测试状态码不匹配
|
||||
t.Run("Wrong status code", func(t *testing.T) {
|
||||
body := []byte(`{"error":"context limit exceeded"}`)
|
||||
matched := svc.MatchRule("anthropic", 400, body)
|
||||
assert.Nil(t, matched, "状态码不匹配")
|
||||
})
|
||||
|
||||
// 测试关键词不匹配
|
||||
t.Run("Wrong keyword", func(t *testing.T) {
|
||||
body := []byte(`{"error":"rate limit exceeded"}`)
|
||||
matched := svc.MatchRule("anthropic", 422, body)
|
||||
assert.Nil(t, matched, "关键词不匹配")
|
||||
})
|
||||
}
|
||||
|
||||
func TestMatchRule_RealWorldScenario_CustomErrorMessage(t *testing.T) {
|
||||
// 场景:某些错误需要返回自定义消息,隐藏上游详细信息
|
||||
customMsg := "Service temporarily unavailable, please try again later"
|
||||
responseCode := 503
|
||||
rules := []*model.ErrorPassthroughRule{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "Hide Internal Errors",
|
||||
Enabled: true,
|
||||
Priority: 1,
|
||||
ErrorCodes: []int{500, 502, 503},
|
||||
MatchMode: model.MatchModeAny,
|
||||
PassthroughCode: false,
|
||||
ResponseCode: &responseCode,
|
||||
PassthroughBody: false,
|
||||
CustomMessage: &customMsg,
|
||||
},
|
||||
}
|
||||
|
||||
svc := newTestService(rules)
|
||||
|
||||
matched := svc.MatchRule("anthropic", 500, []byte("internal server error"))
|
||||
require.NotNil(t, matched)
|
||||
assert.False(t, matched.PassthroughCode)
|
||||
assert.Equal(t, 503, *matched.ResponseCode)
|
||||
assert.False(t, matched.PassthroughBody)
|
||||
assert.Equal(t, customMsg, *matched.CustomMessage)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// 测试 model.Validate
|
||||
// =============================================================================
|
||||
|
||||
func TestErrorPassthroughRule_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rule *model.ErrorPassthroughRule
|
||||
expectError bool
|
||||
errorField string
|
||||
}{
|
||||
{
|
||||
name: "有效规则 - 透传模式(含错误码)",
|
||||
rule: &model.ErrorPassthroughRule{
|
||||
Name: "Valid Rule",
|
||||
MatchMode: model.MatchModeAny,
|
||||
ErrorCodes: []int{422},
|
||||
PassthroughCode: true,
|
||||
PassthroughBody: true,
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "有效规则 - 透传模式(含关键词)",
|
||||
rule: &model.ErrorPassthroughRule{
|
||||
Name: "Valid Rule",
|
||||
MatchMode: model.MatchModeAny,
|
||||
Keywords: []string{"context limit"},
|
||||
PassthroughCode: true,
|
||||
PassthroughBody: true,
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "有效规则 - 自定义响应",
|
||||
rule: &model.ErrorPassthroughRule{
|
||||
Name: "Valid Rule",
|
||||
MatchMode: model.MatchModeAll,
|
||||
ErrorCodes: []int{500},
|
||||
Keywords: []string{"internal error"},
|
||||
PassthroughCode: false,
|
||||
ResponseCode: testIntPtr(503),
|
||||
PassthroughBody: false,
|
||||
CustomMessage: testStrPtr("Custom error"),
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "缺少名称",
|
||||
rule: &model.ErrorPassthroughRule{
|
||||
Name: "",
|
||||
MatchMode: model.MatchModeAny,
|
||||
ErrorCodes: []int{422},
|
||||
PassthroughCode: true,
|
||||
PassthroughBody: true,
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "name",
|
||||
},
|
||||
{
|
||||
name: "无效的匹配模式",
|
||||
rule: &model.ErrorPassthroughRule{
|
||||
Name: "Invalid Mode",
|
||||
MatchMode: "invalid",
|
||||
ErrorCodes: []int{422},
|
||||
PassthroughCode: true,
|
||||
PassthroughBody: true,
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "match_mode",
|
||||
},
|
||||
{
|
||||
name: "缺少匹配条件(错误码和关键词都为空)",
|
||||
rule: &model.ErrorPassthroughRule{
|
||||
Name: "No Conditions",
|
||||
MatchMode: model.MatchModeAny,
|
||||
ErrorCodes: []int{},
|
||||
Keywords: []string{},
|
||||
PassthroughCode: true,
|
||||
PassthroughBody: true,
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "conditions",
|
||||
},
|
||||
{
|
||||
name: "缺少匹配条件(nil切片)",
|
||||
rule: &model.ErrorPassthroughRule{
|
||||
Name: "Nil Conditions",
|
||||
MatchMode: model.MatchModeAny,
|
||||
ErrorCodes: nil,
|
||||
Keywords: nil,
|
||||
PassthroughCode: true,
|
||||
PassthroughBody: true,
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "conditions",
|
||||
},
|
||||
{
|
||||
name: "自定义状态码但未提供值",
|
||||
rule: &model.ErrorPassthroughRule{
|
||||
Name: "Missing Code",
|
||||
MatchMode: model.MatchModeAny,
|
||||
ErrorCodes: []int{422},
|
||||
PassthroughCode: false,
|
||||
ResponseCode: nil,
|
||||
PassthroughBody: true,
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "response_code",
|
||||
},
|
||||
{
|
||||
name: "自定义消息但未提供值",
|
||||
rule: &model.ErrorPassthroughRule{
|
||||
Name: "Missing Message",
|
||||
MatchMode: model.MatchModeAny,
|
||||
ErrorCodes: []int{422},
|
||||
PassthroughCode: true,
|
||||
PassthroughBody: false,
|
||||
CustomMessage: nil,
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "custom_message",
|
||||
},
|
||||
{
|
||||
name: "自定义消息为空字符串",
|
||||
rule: &model.ErrorPassthroughRule{
|
||||
Name: "Empty Message",
|
||||
MatchMode: model.MatchModeAny,
|
||||
ErrorCodes: []int{422},
|
||||
PassthroughCode: true,
|
||||
PassthroughBody: false,
|
||||
CustomMessage: testStrPtr(""),
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "custom_message",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.rule.Validate()
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
validationErr, ok := err.(*model.ValidationError)
|
||||
require.True(t, ok, "应该返回 ValidationError")
|
||||
assert.Equal(t, tt.errorField, validationErr.Field)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func testIntPtr(i int) *int { return &i }
|
||||
func testStrPtr(s string) *string { return &s }
|
||||
288
backend/internal/service/gateway_cached_tokens_test.go
Normal file
288
backend/internal/service/gateway_cached_tokens_test.go
Normal file
@@ -0,0 +1,288 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// ---------- reconcileCachedTokens 单元测试 ----------
|
||||
|
||||
func TestReconcileCachedTokens_NilUsage(t *testing.T) {
|
||||
assert.False(t, reconcileCachedTokens(nil))
|
||||
}
|
||||
|
||||
func TestReconcileCachedTokens_AlreadyHasCacheRead(t *testing.T) {
|
||||
// 已有标准字段,不应覆盖
|
||||
usage := map[string]any{
|
||||
"cache_read_input_tokens": float64(100),
|
||||
"cached_tokens": float64(50),
|
||||
}
|
||||
assert.False(t, reconcileCachedTokens(usage))
|
||||
assert.Equal(t, float64(100), usage["cache_read_input_tokens"])
|
||||
}
|
||||
|
||||
func TestReconcileCachedTokens_KimiStyle(t *testing.T) {
|
||||
// Kimi 风格:cache_read_input_tokens=0,cached_tokens>0
|
||||
usage := map[string]any{
|
||||
"input_tokens": float64(23),
|
||||
"cache_creation_input_tokens": float64(0),
|
||||
"cache_read_input_tokens": float64(0),
|
||||
"cached_tokens": float64(23),
|
||||
}
|
||||
assert.True(t, reconcileCachedTokens(usage))
|
||||
assert.Equal(t, float64(23), usage["cache_read_input_tokens"])
|
||||
}
|
||||
|
||||
func TestReconcileCachedTokens_NoCachedTokens(t *testing.T) {
|
||||
// 无 cached_tokens 字段(原生 Claude)
|
||||
usage := map[string]any{
|
||||
"input_tokens": float64(100),
|
||||
"cache_read_input_tokens": float64(0),
|
||||
"cache_creation_input_tokens": float64(0),
|
||||
}
|
||||
assert.False(t, reconcileCachedTokens(usage))
|
||||
assert.Equal(t, float64(0), usage["cache_read_input_tokens"])
|
||||
}
|
||||
|
||||
func TestReconcileCachedTokens_CachedTokensZero(t *testing.T) {
|
||||
// cached_tokens 为 0,不应覆盖
|
||||
usage := map[string]any{
|
||||
"cache_read_input_tokens": float64(0),
|
||||
"cached_tokens": float64(0),
|
||||
}
|
||||
assert.False(t, reconcileCachedTokens(usage))
|
||||
assert.Equal(t, float64(0), usage["cache_read_input_tokens"])
|
||||
}
|
||||
|
||||
func TestReconcileCachedTokens_MissingCacheReadField(t *testing.T) {
|
||||
// cache_read_input_tokens 字段完全不存在,cached_tokens > 0
|
||||
usage := map[string]any{
|
||||
"cached_tokens": float64(42),
|
||||
}
|
||||
assert.True(t, reconcileCachedTokens(usage))
|
||||
assert.Equal(t, float64(42), usage["cache_read_input_tokens"])
|
||||
}
|
||||
|
||||
// ---------- 流式 message_start 事件 reconcile 测试 ----------
|
||||
|
||||
func TestStreamingReconcile_MessageStart(t *testing.T) {
|
||||
// 模拟 Kimi 返回的 message_start SSE 事件
|
||||
eventJSON := `{
|
||||
"type": "message_start",
|
||||
"message": {
|
||||
"id": "msg_123",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": "kimi",
|
||||
"usage": {
|
||||
"input_tokens": 23,
|
||||
"cache_creation_input_tokens": 0,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cached_tokens": 23
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
var event map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(eventJSON), &event))
|
||||
|
||||
eventType, _ := event["type"].(string)
|
||||
require.Equal(t, "message_start", eventType)
|
||||
|
||||
// 模拟 processSSEEvent 中的 reconcile 逻辑
|
||||
if msg, ok := event["message"].(map[string]any); ok {
|
||||
if u, ok := msg["usage"].(map[string]any); ok {
|
||||
reconcileCachedTokens(u)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证 cache_read_input_tokens 已被填充
|
||||
msg, ok := event["message"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
usage, ok := msg["usage"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, float64(23), usage["cache_read_input_tokens"])
|
||||
|
||||
// 验证重新序列化后 JSON 也包含正确值
|
||||
data, err := json.Marshal(event)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(23), gjson.GetBytes(data, "message.usage.cache_read_input_tokens").Int())
|
||||
}
|
||||
|
||||
func TestStreamingReconcile_MessageStart_NativeClaude(t *testing.T) {
|
||||
// 原生 Claude 不返回 cached_tokens,reconcile 不应改变任何值
|
||||
eventJSON := `{
|
||||
"type": "message_start",
|
||||
"message": {
|
||||
"usage": {
|
||||
"input_tokens": 100,
|
||||
"cache_creation_input_tokens": 50,
|
||||
"cache_read_input_tokens": 30
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
var event map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(eventJSON), &event))
|
||||
|
||||
if msg, ok := event["message"].(map[string]any); ok {
|
||||
if u, ok := msg["usage"].(map[string]any); ok {
|
||||
reconcileCachedTokens(u)
|
||||
}
|
||||
}
|
||||
|
||||
msg, ok := event["message"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
usage, ok := msg["usage"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, float64(30), usage["cache_read_input_tokens"])
|
||||
}
|
||||
|
||||
// ---------- 流式 message_delta 事件 reconcile 测试 ----------
|
||||
|
||||
func TestStreamingReconcile_MessageDelta(t *testing.T) {
|
||||
// 模拟 Kimi 返回的 message_delta SSE 事件
|
||||
eventJSON := `{
|
||||
"type": "message_delta",
|
||||
"usage": {
|
||||
"output_tokens": 7,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cached_tokens": 15
|
||||
}
|
||||
}`
|
||||
|
||||
var event map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(eventJSON), &event))
|
||||
|
||||
eventType, _ := event["type"].(string)
|
||||
require.Equal(t, "message_delta", eventType)
|
||||
|
||||
// 模拟 processSSEEvent 中的 reconcile 逻辑
|
||||
usage, ok := event["usage"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
reconcileCachedTokens(usage)
|
||||
assert.Equal(t, float64(15), usage["cache_read_input_tokens"])
|
||||
}
|
||||
|
||||
func TestStreamingReconcile_MessageDelta_NativeClaude(t *testing.T) {
|
||||
// 原生 Claude 的 message_delta 通常没有 cached_tokens
|
||||
eventJSON := `{
|
||||
"type": "message_delta",
|
||||
"usage": {
|
||||
"output_tokens": 50
|
||||
}
|
||||
}`
|
||||
|
||||
var event map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(eventJSON), &event))
|
||||
|
||||
usage, ok := event["usage"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
reconcileCachedTokens(usage)
|
||||
_, hasCacheRead := usage["cache_read_input_tokens"]
|
||||
assert.False(t, hasCacheRead, "不应为原生 Claude 响应注入 cache_read_input_tokens")
|
||||
}
|
||||
|
||||
// ---------- 非流式响应 reconcile 测试 ----------
|
||||
|
||||
func TestNonStreamingReconcile_KimiResponse(t *testing.T) {
|
||||
// 模拟 Kimi 非流式响应
|
||||
body := []byte(`{
|
||||
"id": "msg_123",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "hello"}],
|
||||
"model": "kimi",
|
||||
"usage": {
|
||||
"input_tokens": 23,
|
||||
"output_tokens": 7,
|
||||
"cache_creation_input_tokens": 0,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cached_tokens": 23,
|
||||
"prompt_tokens": 23,
|
||||
"completion_tokens": 7
|
||||
}
|
||||
}`)
|
||||
|
||||
// 模拟 handleNonStreamingResponse 中的逻辑
|
||||
var response struct {
|
||||
Usage ClaudeUsage `json:"usage"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(body, &response))
|
||||
|
||||
// reconcile
|
||||
if response.Usage.CacheReadInputTokens == 0 {
|
||||
cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int()
|
||||
if cachedTokens > 0 {
|
||||
response.Usage.CacheReadInputTokens = int(cachedTokens)
|
||||
if newBody, err := sjson.SetBytes(body, "usage.cache_read_input_tokens", cachedTokens); err == nil {
|
||||
body = newBody
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 验证内部 usage(计费用)
|
||||
assert.Equal(t, 23, response.Usage.CacheReadInputTokens)
|
||||
assert.Equal(t, 23, response.Usage.InputTokens)
|
||||
assert.Equal(t, 7, response.Usage.OutputTokens)
|
||||
|
||||
// 验证返回给客户端的 JSON body
|
||||
assert.Equal(t, int64(23), gjson.GetBytes(body, "usage.cache_read_input_tokens").Int())
|
||||
}
|
||||
|
||||
func TestNonStreamingReconcile_NativeClaude(t *testing.T) {
|
||||
// 原生 Claude 响应:cache_read_input_tokens 已有值
|
||||
body := []byte(`{
|
||||
"usage": {
|
||||
"input_tokens": 100,
|
||||
"output_tokens": 50,
|
||||
"cache_creation_input_tokens": 20,
|
||||
"cache_read_input_tokens": 30
|
||||
}
|
||||
}`)
|
||||
|
||||
var response struct {
|
||||
Usage ClaudeUsage `json:"usage"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(body, &response))
|
||||
|
||||
// CacheReadInputTokens == 30,条件不成立,整个 reconcile 分支不会执行
|
||||
assert.NotZero(t, response.Usage.CacheReadInputTokens)
|
||||
assert.Equal(t, 30, response.Usage.CacheReadInputTokens)
|
||||
}
|
||||
|
||||
func TestNonStreamingReconcile_NoCachedTokens(t *testing.T) {
|
||||
// 没有 cached_tokens 字段
|
||||
body := []byte(`{
|
||||
"usage": {
|
||||
"input_tokens": 100,
|
||||
"output_tokens": 50,
|
||||
"cache_creation_input_tokens": 0,
|
||||
"cache_read_input_tokens": 0
|
||||
}
|
||||
}`)
|
||||
|
||||
var response struct {
|
||||
Usage ClaudeUsage `json:"usage"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(body, &response))
|
||||
|
||||
if response.Usage.CacheReadInputTokens == 0 {
|
||||
cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int()
|
||||
if cachedTokens > 0 {
|
||||
response.Usage.CacheReadInputTokens = int(cachedTokens)
|
||||
if newBody, err := sjson.SetBytes(body, "usage.cache_read_input_tokens", cachedTokens); err == nil {
|
||||
body = newBody
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cache_read_input_tokens 应保持为 0
|
||||
assert.Equal(t, 0, response.Usage.CacheReadInputTokens)
|
||||
assert.Equal(t, int64(0), gjson.GetBytes(body, "usage.cache_read_input_tokens").Int())
|
||||
}
|
||||
@@ -12,10 +12,3 @@ func TestSanitizeOpenCodeText_RewritesCanonicalSentence(t *testing.T) {
|
||||
got := sanitizeSystemText(in)
|
||||
require.Equal(t, strings.TrimSpace(claudeCodeSystemPrompt), got)
|
||||
}
|
||||
|
||||
func TestSanitizeToolDescription_DoesNotRewriteKeywords(t *testing.T) {
|
||||
in := "OpenCode and opencode are mentioned."
|
||||
got := sanitizeToolDescription(in)
|
||||
// We no longer rewrite tool descriptions; only redact obvious path leaks.
|
||||
require.Equal(t, in, got)
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
@@ -208,40 +207,6 @@ var (
|
||||
sseDataRe = regexp.MustCompile(`^data:\s*`)
|
||||
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
|
||||
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
|
||||
toolPrefixRe = regexp.MustCompile(`(?i)^(?:oc_|mcp_)`)
|
||||
toolNameBoundaryRe = regexp.MustCompile(`[^a-zA-Z0-9]+`)
|
||||
toolNameCamelRe = regexp.MustCompile(`([a-z0-9])([A-Z])`)
|
||||
toolNameFieldRe = regexp.MustCompile(`"name"\s*:\s*"([^"]+)"`)
|
||||
modelFieldRe = regexp.MustCompile(`"model"\s*:\s*"([^"]+)"`)
|
||||
toolDescAbsPathRe = regexp.MustCompile(`/\/?(?:home|Users|tmp|var|opt|usr|etc)\/[^\s,\)"'\]]+`)
|
||||
toolDescWinPathRe = regexp.MustCompile(`(?i)[A-Z]:\\[^\s,\)"'\]]+`)
|
||||
|
||||
claudeToolNameOverrides = map[string]string{
|
||||
"bash": "Bash",
|
||||
"read": "Read",
|
||||
"edit": "Edit",
|
||||
"write": "Write",
|
||||
"task": "Task",
|
||||
"glob": "Glob",
|
||||
"grep": "Grep",
|
||||
"webfetch": "WebFetch",
|
||||
"websearch": "WebSearch",
|
||||
"todowrite": "TodoWrite",
|
||||
"question": "AskUserQuestion",
|
||||
}
|
||||
openCodeToolOverrides = map[string]string{
|
||||
"Bash": "bash",
|
||||
"Read": "read",
|
||||
"Edit": "edit",
|
||||
"Write": "write",
|
||||
"Task": "task",
|
||||
"Glob": "glob",
|
||||
"Grep": "grep",
|
||||
"WebFetch": "webfetch",
|
||||
"WebSearch": "websearch",
|
||||
"TodoWrite": "todowrite",
|
||||
"AskUserQuestion": "question",
|
||||
}
|
||||
|
||||
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
|
||||
// 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等
|
||||
@@ -371,7 +336,8 @@ type ForwardResult struct {
|
||||
|
||||
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
|
||||
type UpstreamFailoverError struct {
|
||||
StatusCode int
|
||||
StatusCode int
|
||||
ResponseBody []byte // 上游响应体,用于错误透传规则匹配
|
||||
}
|
||||
|
||||
func (e *UpstreamFailoverError) Error() string {
|
||||
@@ -385,6 +351,7 @@ type GatewayService struct {
|
||||
usageLogRepo UsageLogRepository
|
||||
userRepo UserRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
userGroupRateRepo UserGroupRateRepository
|
||||
cache GatewayCache
|
||||
cfg *config.Config
|
||||
schedulerSnapshot *SchedulerSnapshotService
|
||||
@@ -406,6 +373,7 @@ func NewGatewayService(
|
||||
usageLogRepo UsageLogRepository,
|
||||
userRepo UserRepository,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
userGroupRateRepo UserGroupRateRepository,
|
||||
cache GatewayCache,
|
||||
cfg *config.Config,
|
||||
schedulerSnapshot *SchedulerSnapshotService,
|
||||
@@ -425,6 +393,7 @@ func NewGatewayService(
|
||||
usageLogRepo: usageLogRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
userGroupRateRepo: userGroupRateRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
schedulerSnapshot: schedulerSnapshot,
|
||||
@@ -613,98 +582,6 @@ type claudeOAuthNormalizeOptions struct {
|
||||
stripSystemCacheControl bool
|
||||
}
|
||||
|
||||
func stripToolPrefix(value string) string {
|
||||
if value == "" {
|
||||
return value
|
||||
}
|
||||
return toolPrefixRe.ReplaceAllString(value, "")
|
||||
}
|
||||
|
||||
func toPascalCase(value string) string {
|
||||
if value == "" {
|
||||
return value
|
||||
}
|
||||
normalized := toolNameBoundaryRe.ReplaceAllString(value, " ")
|
||||
tokens := make([]string, 0)
|
||||
for _, token := range strings.Fields(normalized) {
|
||||
expanded := toolNameCamelRe.ReplaceAllString(token, "$1 $2")
|
||||
parts := strings.Fields(expanded)
|
||||
if len(parts) > 0 {
|
||||
tokens = append(tokens, parts...)
|
||||
}
|
||||
}
|
||||
if len(tokens) == 0 {
|
||||
return value
|
||||
}
|
||||
var builder strings.Builder
|
||||
for _, token := range tokens {
|
||||
lower := strings.ToLower(token)
|
||||
if lower == "" {
|
||||
continue
|
||||
}
|
||||
runes := []rune(lower)
|
||||
runes[0] = unicode.ToUpper(runes[0])
|
||||
_, _ = builder.WriteString(string(runes))
|
||||
}
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func toSnakeCase(value string) string {
|
||||
if value == "" {
|
||||
return value
|
||||
}
|
||||
output := toolNameCamelRe.ReplaceAllString(value, "$1_$2")
|
||||
output = toolNameBoundaryRe.ReplaceAllString(output, "_")
|
||||
output = strings.Trim(output, "_")
|
||||
return strings.ToLower(output)
|
||||
}
|
||||
|
||||
func normalizeToolNameForClaude(name string, cache map[string]string) string {
|
||||
if name == "" {
|
||||
return name
|
||||
}
|
||||
stripped := stripToolPrefix(name)
|
||||
mapped, ok := claudeToolNameOverrides[strings.ToLower(stripped)]
|
||||
if !ok {
|
||||
mapped = toPascalCase(stripped)
|
||||
}
|
||||
if mapped != "" && cache != nil && mapped != stripped {
|
||||
cache[mapped] = stripped
|
||||
}
|
||||
if mapped == "" {
|
||||
return stripped
|
||||
}
|
||||
return mapped
|
||||
}
|
||||
|
||||
func normalizeToolNameForOpenCode(name string, cache map[string]string) string {
|
||||
if name == "" {
|
||||
return name
|
||||
}
|
||||
stripped := stripToolPrefix(name)
|
||||
if cache != nil {
|
||||
if mapped, ok := cache[stripped]; ok {
|
||||
return mapped
|
||||
}
|
||||
}
|
||||
if mapped, ok := openCodeToolOverrides[stripped]; ok {
|
||||
return mapped
|
||||
}
|
||||
return toSnakeCase(stripped)
|
||||
}
|
||||
|
||||
func normalizeParamNameForOpenCode(name string, cache map[string]string) string {
|
||||
if name == "" {
|
||||
return name
|
||||
}
|
||||
if cache != nil {
|
||||
if mapped, ok := cache[name]; ok {
|
||||
return mapped
|
||||
}
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
// sanitizeSystemText rewrites only the fixed OpenCode identity sentence (if present).
|
||||
// We intentionally avoid broad keyword replacement in system prompts to prevent
|
||||
// accidentally changing user-provided instructions.
|
||||
@@ -723,55 +600,6 @@ func sanitizeSystemText(text string) string {
|
||||
return text
|
||||
}
|
||||
|
||||
func sanitizeToolDescription(description string) string {
|
||||
if description == "" {
|
||||
return description
|
||||
}
|
||||
description = toolDescAbsPathRe.ReplaceAllString(description, "[path]")
|
||||
description = toolDescWinPathRe.ReplaceAllString(description, "[path]")
|
||||
// Intentionally do NOT rewrite tool descriptions (OpenCode/Claude strings).
|
||||
// Tool names/skill names may rely on exact wording, and rewriting can be misleading.
|
||||
return description
|
||||
}
|
||||
|
||||
func normalizeToolInputSchema(inputSchema any, cache map[string]string) {
|
||||
schema, ok := inputSchema.(map[string]any)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
properties, ok := schema["properties"].(map[string]any)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
newProperties := make(map[string]any, len(properties))
|
||||
for key, value := range properties {
|
||||
snakeKey := toSnakeCase(key)
|
||||
newProperties[snakeKey] = value
|
||||
if snakeKey != key && cache != nil {
|
||||
cache[snakeKey] = key
|
||||
}
|
||||
}
|
||||
schema["properties"] = newProperties
|
||||
|
||||
if required, ok := schema["required"].([]any); ok {
|
||||
newRequired := make([]any, 0, len(required))
|
||||
for _, item := range required {
|
||||
name, ok := item.(string)
|
||||
if !ok {
|
||||
newRequired = append(newRequired, item)
|
||||
continue
|
||||
}
|
||||
snakeName := toSnakeCase(name)
|
||||
newRequired = append(newRequired, snakeName)
|
||||
if snakeName != name && cache != nil {
|
||||
cache[snakeName] = name
|
||||
}
|
||||
}
|
||||
schema["required"] = newRequired
|
||||
}
|
||||
}
|
||||
|
||||
func stripCacheControlFromSystemBlocks(system any) bool {
|
||||
blocks, ok := system.([]any)
|
||||
if !ok {
|
||||
@@ -792,24 +620,17 @@ func stripCacheControlFromSystemBlocks(system any) bool {
|
||||
return changed
|
||||
}
|
||||
|
||||
func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string, map[string]string) {
|
||||
func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string) {
|
||||
if len(body) == 0 {
|
||||
return body, modelID, nil
|
||||
return body, modelID
|
||||
}
|
||||
|
||||
// 使用 json.RawMessage 保留 messages 的原始字节,避免 thinking 块被修改
|
||||
var reqRaw map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &reqRaw); err != nil {
|
||||
return body, modelID, nil
|
||||
}
|
||||
|
||||
// 同时解析为 map[string]any 用于修改非 messages 字段
|
||||
// 解析为 map[string]any 用于修改字段
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return body, modelID, nil
|
||||
return body, modelID
|
||||
}
|
||||
|
||||
toolNameMap := make(map[string]string)
|
||||
modified := false
|
||||
|
||||
if system, ok := req["system"]; ok {
|
||||
@@ -851,115 +672,12 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
|
||||
}
|
||||
}
|
||||
|
||||
if rawTools, exists := req["tools"]; exists {
|
||||
switch tools := rawTools.(type) {
|
||||
case []any:
|
||||
for idx, tool := range tools {
|
||||
toolMap, ok := tool.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if name, ok := toolMap["name"].(string); ok {
|
||||
normalized := normalizeToolNameForClaude(name, toolNameMap)
|
||||
if normalized != "" && normalized != name {
|
||||
toolMap["name"] = normalized
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
if desc, ok := toolMap["description"].(string); ok {
|
||||
sanitized := sanitizeToolDescription(desc)
|
||||
if sanitized != desc {
|
||||
toolMap["description"] = sanitized
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
if schema, ok := toolMap["input_schema"]; ok {
|
||||
normalizeToolInputSchema(schema, toolNameMap)
|
||||
modified = true
|
||||
}
|
||||
tools[idx] = toolMap
|
||||
}
|
||||
req["tools"] = tools
|
||||
case map[string]any:
|
||||
normalizedTools := make(map[string]any, len(tools))
|
||||
for name, value := range tools {
|
||||
normalized := normalizeToolNameForClaude(name, toolNameMap)
|
||||
if normalized == "" {
|
||||
normalized = name
|
||||
}
|
||||
if toolMap, ok := value.(map[string]any); ok {
|
||||
toolMap["name"] = normalized
|
||||
if desc, ok := toolMap["description"].(string); ok {
|
||||
sanitized := sanitizeToolDescription(desc)
|
||||
if sanitized != desc {
|
||||
toolMap["description"] = sanitized
|
||||
}
|
||||
}
|
||||
if schema, ok := toolMap["input_schema"]; ok {
|
||||
normalizeToolInputSchema(schema, toolNameMap)
|
||||
}
|
||||
normalizedTools[normalized] = toolMap
|
||||
continue
|
||||
}
|
||||
normalizedTools[normalized] = value
|
||||
}
|
||||
req["tools"] = normalizedTools
|
||||
modified = true
|
||||
}
|
||||
} else {
|
||||
// 确保 tools 字段存在(即使为空数组)
|
||||
if _, exists := req["tools"]; !exists {
|
||||
req["tools"] = []any{}
|
||||
modified = true
|
||||
}
|
||||
|
||||
// 处理 messages 中的 tool_use 块,但保留包含 thinking 块的消息的原始字节
|
||||
messagesModified := false
|
||||
if messages, ok := req["messages"].([]any); ok {
|
||||
for _, msg := range messages {
|
||||
msgMap, ok := msg.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
content, ok := msgMap["content"].([]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
// 检查此消息是否包含 thinking 块
|
||||
hasThinking := false
|
||||
for _, block := range content {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
if blockType == "thinking" || blockType == "redacted_thinking" {
|
||||
hasThinking = true
|
||||
break
|
||||
}
|
||||
}
|
||||
// 如果包含 thinking 块,跳过此消息的修改
|
||||
if hasThinking {
|
||||
continue
|
||||
}
|
||||
// 只修改不包含 thinking 块的消息中的 tool_use
|
||||
for _, block := range content {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if blockType, _ := blockMap["type"].(string); blockType != "tool_use" {
|
||||
continue
|
||||
}
|
||||
if name, ok := blockMap["name"].(string); ok {
|
||||
normalized := normalizeToolNameForClaude(name, toolNameMap)
|
||||
if normalized != "" && normalized != name {
|
||||
blockMap["name"] = normalized
|
||||
messagesModified = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if opts.stripSystemCacheControl {
|
||||
if system, ok := req["system"]; ok {
|
||||
_ = stripCacheControlFromSystemBlocks(system)
|
||||
@@ -988,38 +706,15 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
|
||||
modified = true
|
||||
}
|
||||
|
||||
if !modified && !messagesModified {
|
||||
return body, modelID, toolNameMap
|
||||
if !modified {
|
||||
return body, modelID
|
||||
}
|
||||
|
||||
// 如果 messages 没有被修改,保留原始 messages 字节
|
||||
if !messagesModified {
|
||||
// 序列化非 messages 字段
|
||||
newBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return body, modelID, toolNameMap
|
||||
}
|
||||
// 替换回原始的 messages
|
||||
var newReq map[string]json.RawMessage
|
||||
if err := json.Unmarshal(newBody, &newReq); err != nil {
|
||||
return newBody, modelID, toolNameMap
|
||||
}
|
||||
if origMessages, ok := reqRaw["messages"]; ok {
|
||||
newReq["messages"] = origMessages
|
||||
}
|
||||
finalBody, err := json.Marshal(newReq)
|
||||
if err != nil {
|
||||
return newBody, modelID, toolNameMap
|
||||
}
|
||||
return finalBody, modelID, toolNameMap
|
||||
}
|
||||
|
||||
// messages 被修改了,需要完整序列化
|
||||
newBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return body, modelID, toolNameMap
|
||||
return body, modelID
|
||||
}
|
||||
return newBody, modelID, toolNameMap
|
||||
return newBody, modelID
|
||||
}
|
||||
|
||||
func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account *Account, fp *Fingerprint) string {
|
||||
@@ -2984,7 +2679,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
reqModel := parsed.Model
|
||||
reqStream := parsed.Stream
|
||||
originalModel := reqModel
|
||||
var toolNameMap map[string]string
|
||||
|
||||
isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
|
||||
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
|
||||
@@ -3008,7 +2702,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
}
|
||||
}
|
||||
|
||||
body, reqModel, toolNameMap = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
|
||||
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
|
||||
}
|
||||
|
||||
// 强制执行 cache_control 块数量限制(最多 4 个)
|
||||
@@ -3309,7 +3003,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
return ""
|
||||
}(),
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
return s.handleRetryExhaustedError(ctx, resp, c, account)
|
||||
}
|
||||
@@ -3339,10 +3033,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
return ""
|
||||
}(),
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
|
||||
// 处理错误响应(不可重试的错误)
|
||||
if resp.StatusCode >= 400 {
|
||||
// 可选:对部分 400 触发 failover(默认关闭以保持语义)
|
||||
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
|
||||
@@ -3386,7 +3078,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
log.Printf("Account %d: 400 error, attempting failover", account.ID)
|
||||
}
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
}
|
||||
return s.handleErrorResponse(ctx, resp, c, account)
|
||||
@@ -3397,7 +3089,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
var firstTokenMs *int
|
||||
var clientDisconnect bool
|
||||
if reqStream {
|
||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, toolNameMap, shouldMimicClaudeCode)
|
||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, shouldMimicClaudeCode)
|
||||
if err != nil {
|
||||
if err.Error() == "have error in stream" {
|
||||
return nil, &UpstreamFailoverError{
|
||||
@@ -3410,7 +3102,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
firstTokenMs = streamResult.firstTokenMs
|
||||
clientDisconnect = streamResult.clientDisconnect
|
||||
} else {
|
||||
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel, toolNameMap, shouldMimicClaudeCode)
|
||||
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -3783,6 +3475,12 @@ func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// ExtractUpstreamErrorMessage 从上游响应体中提取错误消息
|
||||
// 支持 Claude 风格的错误格式:{"type":"error","error":{"type":"...","message":"..."}}
|
||||
func ExtractUpstreamErrorMessage(body []byte) string {
|
||||
return extractUpstreamErrorMessage(body)
|
||||
}
|
||||
|
||||
func extractUpstreamErrorMessage(body []byte) string {
|
||||
// Claude 风格:{"type":"error","error":{"type":"...","message":"..."}}
|
||||
if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" {
|
||||
@@ -3850,7 +3548,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
|
||||
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
}
|
||||
if shouldDisable {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body}
|
||||
}
|
||||
|
||||
// 记录上游错误响应体摘要便于排障(可选:由配置控制;不回显到客户端)
|
||||
@@ -4018,7 +3716,7 @@ type streamingResult struct {
|
||||
clientDisconnect bool // 客户端是否在流式传输过程中断开
|
||||
}
|
||||
|
||||
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, toolNameMap map[string]string, mimicClaudeCode bool) (*streamingResult, error) {
|
||||
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, mimicClaudeCode bool) (*streamingResult, error) {
|
||||
// 更新5h窗口状态
|
||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||
|
||||
@@ -4114,33 +3812,6 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage
|
||||
|
||||
pendingEventLines := make([]string, 0, 4)
|
||||
var toolInputBuffers map[int]string
|
||||
if mimicClaudeCode {
|
||||
toolInputBuffers = make(map[int]string)
|
||||
}
|
||||
|
||||
transformToolInputJSON := func(raw string) string {
|
||||
if !mimicClaudeCode {
|
||||
return raw
|
||||
}
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return raw
|
||||
}
|
||||
|
||||
var parsed any
|
||||
if err := json.Unmarshal([]byte(raw), &parsed); err != nil {
|
||||
return replaceToolNamesInText(raw, toolNameMap)
|
||||
}
|
||||
|
||||
rewritten, changed := rewriteParamKeysInValue(parsed, toolNameMap)
|
||||
if changed {
|
||||
if bytes, err := json.Marshal(rewritten); err == nil {
|
||||
return string(bytes)
|
||||
}
|
||||
}
|
||||
return raw
|
||||
}
|
||||
|
||||
processSSEEvent := func(lines []string) ([]string, string, error) {
|
||||
if len(lines) == 0 {
|
||||
@@ -4179,16 +3850,13 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
|
||||
var event map[string]any
|
||||
if err := json.Unmarshal([]byte(dataLine), &event); err != nil {
|
||||
replaced := dataLine
|
||||
if mimicClaudeCode {
|
||||
replaced = replaceToolNamesInText(dataLine, toolNameMap)
|
||||
}
|
||||
// JSON 解析失败,直接透传原始数据
|
||||
block := ""
|
||||
if eventName != "" {
|
||||
block = "event: " + eventName + "\n"
|
||||
}
|
||||
block += "data: " + replaced + "\n\n"
|
||||
return []string{block}, replaced, nil
|
||||
block += "data: " + dataLine + "\n\n"
|
||||
return []string{block}, dataLine, nil
|
||||
}
|
||||
|
||||
eventType, _ := event["type"].(string)
|
||||
@@ -4196,6 +3864,20 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
eventName = eventType
|
||||
}
|
||||
|
||||
// 兼容 Kimi cached_tokens → cache_read_input_tokens
|
||||
if eventType == "message_start" {
|
||||
if msg, ok := event["message"].(map[string]any); ok {
|
||||
if u, ok := msg["usage"].(map[string]any); ok {
|
||||
reconcileCachedTokens(u)
|
||||
}
|
||||
}
|
||||
}
|
||||
if eventType == "message_delta" {
|
||||
if u, ok := event["usage"].(map[string]any); ok {
|
||||
reconcileCachedTokens(u)
|
||||
}
|
||||
}
|
||||
|
||||
if needModelReplace {
|
||||
if msg, ok := event["message"].(map[string]any); ok {
|
||||
if model, ok := msg["model"].(string); ok && model == mappedModel {
|
||||
@@ -4204,70 +3886,15 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
}
|
||||
}
|
||||
|
||||
if mimicClaudeCode && eventType == "content_block_delta" {
|
||||
if delta, ok := event["delta"].(map[string]any); ok {
|
||||
if deltaType, _ := delta["type"].(string); deltaType == "input_json_delta" {
|
||||
if indexVal, ok := event["index"].(float64); ok {
|
||||
index := int(indexVal)
|
||||
if partial, ok := delta["partial_json"].(string); ok {
|
||||
toolInputBuffers[index] += partial
|
||||
}
|
||||
}
|
||||
return nil, dataLine, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if mimicClaudeCode && eventType == "content_block_stop" {
|
||||
if indexVal, ok := event["index"].(float64); ok {
|
||||
index := int(indexVal)
|
||||
if buffered := toolInputBuffers[index]; buffered != "" {
|
||||
delete(toolInputBuffers, index)
|
||||
|
||||
transformed := transformToolInputJSON(buffered)
|
||||
synthetic := map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": index,
|
||||
"delta": map[string]any{
|
||||
"type": "input_json_delta",
|
||||
"partial_json": transformed,
|
||||
},
|
||||
}
|
||||
|
||||
synthBytes, synthErr := json.Marshal(synthetic)
|
||||
if synthErr == nil {
|
||||
synthBlock := "event: content_block_delta\n" + "data: " + string(synthBytes) + "\n\n"
|
||||
|
||||
rewriteToolNamesInValue(event, toolNameMap)
|
||||
stopBytes, stopErr := json.Marshal(event)
|
||||
if stopErr == nil {
|
||||
stopBlock := ""
|
||||
if eventName != "" {
|
||||
stopBlock = "event: " + eventName + "\n"
|
||||
}
|
||||
stopBlock += "data: " + string(stopBytes) + "\n\n"
|
||||
return []string{synthBlock, stopBlock}, string(stopBytes), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if mimicClaudeCode {
|
||||
rewriteToolNamesInValue(event, toolNameMap)
|
||||
}
|
||||
newData, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
replaced := dataLine
|
||||
if mimicClaudeCode {
|
||||
replaced = replaceToolNamesInText(dataLine, toolNameMap)
|
||||
}
|
||||
// 序列化失败,直接透传原始数据
|
||||
block := ""
|
||||
if eventName != "" {
|
||||
block = "event: " + eventName + "\n"
|
||||
}
|
||||
block += "data: " + replaced + "\n\n"
|
||||
return []string{block}, replaced, nil
|
||||
block += "data: " + dataLine + "\n\n"
|
||||
return []string{block}, dataLine, nil
|
||||
}
|
||||
|
||||
block := ""
|
||||
@@ -4366,126 +3993,6 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
|
||||
}
|
||||
|
||||
func rewriteParamKeysInValue(value any, cache map[string]string) (any, bool) {
|
||||
switch v := value.(type) {
|
||||
case map[string]any:
|
||||
changed := false
|
||||
rewritten := make(map[string]any, len(v))
|
||||
for key, item := range v {
|
||||
newKey := normalizeParamNameForOpenCode(key, cache)
|
||||
newItem, childChanged := rewriteParamKeysInValue(item, cache)
|
||||
if childChanged {
|
||||
changed = true
|
||||
}
|
||||
if newKey != key {
|
||||
changed = true
|
||||
}
|
||||
rewritten[newKey] = newItem
|
||||
}
|
||||
if !changed {
|
||||
return value, false
|
||||
}
|
||||
return rewritten, true
|
||||
case []any:
|
||||
changed := false
|
||||
rewritten := make([]any, len(v))
|
||||
for idx, item := range v {
|
||||
newItem, childChanged := rewriteParamKeysInValue(item, cache)
|
||||
if childChanged {
|
||||
changed = true
|
||||
}
|
||||
rewritten[idx] = newItem
|
||||
}
|
||||
if !changed {
|
||||
return value, false
|
||||
}
|
||||
return rewritten, true
|
||||
default:
|
||||
return value, false
|
||||
}
|
||||
}
|
||||
|
||||
func rewriteToolNamesInValue(value any, toolNameMap map[string]string) bool {
|
||||
switch v := value.(type) {
|
||||
case map[string]any:
|
||||
changed := false
|
||||
if blockType, _ := v["type"].(string); blockType == "tool_use" {
|
||||
if name, ok := v["name"].(string); ok {
|
||||
mapped := normalizeToolNameForOpenCode(name, toolNameMap)
|
||||
if mapped != name {
|
||||
v["name"] = mapped
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
if input, ok := v["input"].(map[string]any); ok {
|
||||
rewrittenInput, inputChanged := rewriteParamKeysInValue(input, toolNameMap)
|
||||
if inputChanged {
|
||||
if m, ok := rewrittenInput.(map[string]any); ok {
|
||||
v["input"] = m
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, item := range v {
|
||||
if rewriteToolNamesInValue(item, toolNameMap) {
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
return changed
|
||||
case []any:
|
||||
changed := false
|
||||
for _, item := range v {
|
||||
if rewriteToolNamesInValue(item, toolNameMap) {
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
return changed
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func replaceToolNamesInText(text string, toolNameMap map[string]string) string {
|
||||
if text == "" {
|
||||
return text
|
||||
}
|
||||
output := toolNameFieldRe.ReplaceAllStringFunc(text, func(match string) string {
|
||||
submatches := toolNameFieldRe.FindStringSubmatch(match)
|
||||
if len(submatches) < 2 {
|
||||
return match
|
||||
}
|
||||
name := submatches[1]
|
||||
mapped := normalizeToolNameForOpenCode(name, toolNameMap)
|
||||
if mapped == name {
|
||||
return match
|
||||
}
|
||||
return strings.Replace(match, name, mapped, 1)
|
||||
})
|
||||
output = modelFieldRe.ReplaceAllStringFunc(output, func(match string) string {
|
||||
submatches := modelFieldRe.FindStringSubmatch(match)
|
||||
if len(submatches) < 2 {
|
||||
return match
|
||||
}
|
||||
model := submatches[1]
|
||||
mapped := claude.DenormalizeModelID(model)
|
||||
if mapped == model {
|
||||
return match
|
||||
}
|
||||
return strings.Replace(match, model, mapped, 1)
|
||||
})
|
||||
|
||||
for mapped, original := range toolNameMap {
|
||||
if mapped == "" || original == "" || mapped == original {
|
||||
continue
|
||||
}
|
||||
output = strings.ReplaceAll(output, "\""+mapped+"\":", "\""+original+"\":")
|
||||
output = strings.ReplaceAll(output, "\\\""+mapped+"\\\":", "\\\""+original+"\\\":")
|
||||
}
|
||||
|
||||
return output
|
||||
}
|
||||
|
||||
func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
|
||||
// 解析message_start获取input tokens(标准Claude API格式)
|
||||
var msgStart struct {
|
||||
@@ -4529,7 +4036,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string, toolNameMap map[string]string, mimicClaudeCode bool) (*ClaudeUsage, error) {
|
||||
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) {
|
||||
// 更新5h窗口状态
|
||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||
|
||||
@@ -4546,13 +4053,21 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
||||
return nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
// 兼容 Kimi cached_tokens → cache_read_input_tokens
|
||||
if response.Usage.CacheReadInputTokens == 0 {
|
||||
cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int()
|
||||
if cachedTokens > 0 {
|
||||
response.Usage.CacheReadInputTokens = int(cachedTokens)
|
||||
if newBody, err := sjson.SetBytes(body, "usage.cache_read_input_tokens", cachedTokens); err == nil {
|
||||
body = newBody
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果有模型映射,替换响应中的model字段
|
||||
if originalModel != mappedModel {
|
||||
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
||||
}
|
||||
if mimicClaudeCode {
|
||||
body = s.replaceToolNamesInResponseBody(body, toolNameMap)
|
||||
}
|
||||
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
|
||||
@@ -4590,28 +4105,6 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
|
||||
return newBody
|
||||
}
|
||||
|
||||
func (s *GatewayService) replaceToolNamesInResponseBody(body []byte, toolNameMap map[string]string) []byte {
|
||||
if len(body) == 0 {
|
||||
return body
|
||||
}
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
replaced := replaceToolNamesInText(string(body), toolNameMap)
|
||||
if replaced == string(body) {
|
||||
return body
|
||||
}
|
||||
return []byte(replaced)
|
||||
}
|
||||
if !rewriteToolNamesInValue(resp, toolNameMap) {
|
||||
return body
|
||||
}
|
||||
newBody, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return newBody
|
||||
}
|
||||
|
||||
// RecordUsageInput 记录使用量的输入参数
|
||||
type RecordUsageInput struct {
|
||||
Result *ForwardResult
|
||||
@@ -4637,10 +4130,17 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
account := input.Account
|
||||
subscription := input.Subscription
|
||||
|
||||
// 获取费率倍数
|
||||
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
|
||||
multiplier := s.cfg.Default.RateMultiplier
|
||||
if apiKey.GroupID != nil && apiKey.Group != nil {
|
||||
multiplier = apiKey.Group.RateMultiplier
|
||||
|
||||
// 检查用户专属倍率
|
||||
if s.userGroupRateRepo != nil {
|
||||
if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil {
|
||||
multiplier = *userRate
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var cost *CostBreakdown
|
||||
@@ -4801,10 +4301,17 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
account := input.Account
|
||||
subscription := input.Subscription
|
||||
|
||||
// 获取费率倍数
|
||||
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
|
||||
multiplier := s.cfg.Default.RateMultiplier
|
||||
if apiKey.GroupID != nil && apiKey.Group != nil {
|
||||
multiplier = apiKey.Group.RateMultiplier
|
||||
|
||||
// 检查用户专属倍率
|
||||
if s.userGroupRateRepo != nil {
|
||||
if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil {
|
||||
multiplier = *userRate
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var cost *CostBreakdown
|
||||
@@ -4958,7 +4465,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
|
||||
if shouldMimicClaudeCode {
|
||||
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
|
||||
body, reqModel, _ = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
|
||||
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
|
||||
}
|
||||
|
||||
// Antigravity 账户不支持 count_tokens 转发,直接返回空值
|
||||
@@ -5317,3 +4824,21 @@ func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64,
|
||||
|
||||
return models
|
||||
}
|
||||
|
||||
// reconcileCachedTokens 兼容 Kimi 等上游:
|
||||
// 将 OpenAI 风格的 cached_tokens 映射到 Claude 标准的 cache_read_input_tokens
|
||||
func reconcileCachedTokens(usage map[string]any) bool {
|
||||
if usage == nil {
|
||||
return false
|
||||
}
|
||||
cacheRead, _ := usage["cache_read_input_tokens"].(float64)
|
||||
if cacheRead > 0 {
|
||||
return false // 已有标准字段,无需处理
|
||||
}
|
||||
cached, _ := usage["cached_tokens"].(float64)
|
||||
if cached <= 0 {
|
||||
return false
|
||||
}
|
||||
usage["cache_read_input_tokens"] = cached
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -864,7 +864,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||||
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||
@@ -891,7 +891,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||
if upstreamReqID == "" {
|
||||
@@ -1301,7 +1301,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||||
evBody := unwrapIfNeeded(isOAuth, respBody)
|
||||
@@ -1325,7 +1325,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: evBody}
|
||||
}
|
||||
|
||||
respBody = unwrapIfNeeded(isOAuth, respBody)
|
||||
|
||||
@@ -944,6 +944,32 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
|
||||
return strings.TrimSpace(loadResp.CloudAICompanionProject), tierID, nil
|
||||
}
|
||||
|
||||
// 关键逻辑:对齐 Gemini CLI 对“已注册用户”的处理方式。
|
||||
// 当 LoadCodeAssist 返回了 currentTier / paidTier(表示账号已注册)但没有返回 cloudaicompanionProject 时:
|
||||
// - 不要再调用 onboardUser(通常不会再分配 project_id,且可能触发 INVALID_ARGUMENT)
|
||||
// - 先尝试从 Cloud Resource Manager 获取可用项目;仍失败则提示用户手动填写 project_id
|
||||
if loadResp != nil {
|
||||
registeredTierID := strings.TrimSpace(loadResp.GetTier())
|
||||
if registeredTierID != "" {
|
||||
// 已注册但未返回 cloudaicompanionProject,这在 Google One 用户中较常见:需要用户自行提供 project_id。
|
||||
log.Printf("[GeminiOAuth] User has tier (%s) but no cloudaicompanionProject, trying Cloud Resource Manager...", registeredTierID)
|
||||
|
||||
// Try to get project from Cloud Resource Manager
|
||||
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
|
||||
if fbErr == nil && strings.TrimSpace(fallback) != "" {
|
||||
log.Printf("[GeminiOAuth] Found project from Cloud Resource Manager: %s", fallback)
|
||||
return strings.TrimSpace(fallback), tierID, nil
|
||||
}
|
||||
|
||||
// No project found - user must provide project_id manually
|
||||
log.Printf("[GeminiOAuth] No project found from Cloud Resource Manager, user must provide project_id manually")
|
||||
return "", tierID, fmt.Errorf("user is registered (tier: %s) but no project_id available. Please provide Project ID manually in the authorization form, or create a project at https://console.cloud.google.com", registeredTierID)
|
||||
}
|
||||
}
|
||||
|
||||
// 未检测到 currentTier/paidTier,视为新用户,继续调用 onboardUser
|
||||
log.Printf("[GeminiOAuth] No currentTier/paidTier found, proceeding with onboardUser (tierID: %s)", tierID)
|
||||
|
||||
req := &geminicli.OnboardUserRequest{
|
||||
TierID: tierID,
|
||||
Metadata: geminicli.LoadCodeAssistMetadata{
|
||||
|
||||
@@ -21,6 +21,17 @@ const (
|
||||
var codexCLIInstructions string
|
||||
|
||||
var codexModelMap = map[string]string{
|
||||
"gpt-5.3": "gpt-5.3",
|
||||
"gpt-5.3-none": "gpt-5.3",
|
||||
"gpt-5.3-low": "gpt-5.3",
|
||||
"gpt-5.3-medium": "gpt-5.3",
|
||||
"gpt-5.3-high": "gpt-5.3",
|
||||
"gpt-5.3-xhigh": "gpt-5.3",
|
||||
"gpt-5.3-codex": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-low": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-medium": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-high": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-xhigh": "gpt-5.3-codex",
|
||||
"gpt-5.1-codex": "gpt-5.1-codex",
|
||||
"gpt-5.1-codex-low": "gpt-5.1-codex",
|
||||
"gpt-5.1-codex-medium": "gpt-5.1-codex",
|
||||
@@ -156,6 +167,12 @@ func normalizeCodexModel(model string) string {
|
||||
if strings.Contains(normalized, "gpt-5.2") || strings.Contains(normalized, "gpt 5.2") {
|
||||
return "gpt-5.2"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.3-codex") || strings.Contains(normalized, "gpt 5.3 codex") {
|
||||
return "gpt-5.3-codex"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.3") || strings.Contains(normalized, "gpt 5.3") {
|
||||
return "gpt-5.3"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.1-codex-max") || strings.Contains(normalized, "gpt 5.1 codex max") {
|
||||
return "gpt-5.1-codex-max"
|
||||
}
|
||||
|
||||
@@ -176,6 +176,19 @@ func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
|
||||
require.Len(t, input, 0)
|
||||
}
|
||||
|
||||
func TestNormalizeCodexModel_Gpt53(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"gpt-5.3": "gpt-5.3",
|
||||
"gpt-5.3-codex": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-xhigh": "gpt-5.3-codex",
|
||||
"gpt 5.3 codex": "gpt-5.3-codex",
|
||||
}
|
||||
|
||||
for input, expected := range cases {
|
||||
require.Equal(t, expected, normalizeCodexModel(input))
|
||||
}
|
||||
}
|
||||
|
||||
func setupCodexCache(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -846,10 +846,12 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
}
|
||||
}
|
||||
|
||||
// Remove prompt_cache_retention (not supported by upstream OpenAI API)
|
||||
if _, has := reqBody["prompt_cache_retention"]; has {
|
||||
delete(reqBody, "prompt_cache_retention")
|
||||
bodyModified = true
|
||||
// Remove unsupported fields (not supported by upstream OpenAI API)
|
||||
for _, unsupportedField := range []string{"prompt_cache_retention", "safety_identifier", "previous_response_id"} {
|
||||
if _, has := reqBody[unsupportedField]; has {
|
||||
delete(reqBody, unsupportedField)
|
||||
bodyModified = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -938,7 +940,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
})
|
||||
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
return s.handleErrorResponse(ctx, resp, c, account)
|
||||
}
|
||||
@@ -1129,7 +1131,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
if shouldDisable {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body}
|
||||
}
|
||||
|
||||
// Return appropriate error response
|
||||
|
||||
@@ -424,6 +424,16 @@ func isSensitiveKey(key string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Whitelist: known non-sensitive fields that contain sensitive substrings
|
||||
// (e.g., "max_tokens" contains "token" but is just an API parameter).
|
||||
switch k {
|
||||
case "max_tokens", "max_completion_tokens", "max_output_tokens",
|
||||
"completion_tokens", "prompt_tokens", "total_tokens",
|
||||
"input_tokens", "output_tokens",
|
||||
"cache_creation_input_tokens", "cache_read_input_tokens":
|
||||
return false
|
||||
}
|
||||
|
||||
// Exact matches (common credential fields).
|
||||
switch k {
|
||||
case "authorization",
|
||||
|
||||
@@ -579,6 +579,7 @@ func (s *PricingService) extractBaseName(model string) string {
|
||||
func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
|
||||
// Claude模型系列匹配规则
|
||||
familyPatterns := map[string][]string{
|
||||
"opus-4.6": {"claude-opus-4.6", "claude-opus-4-6"},
|
||||
"opus-4.5": {"claude-opus-4.5", "claude-opus-4-5"},
|
||||
"opus-4": {"claude-opus-4", "claude-3-opus"},
|
||||
"sonnet-4.5": {"claude-sonnet-4.5", "claude-sonnet-4-5"},
|
||||
@@ -651,7 +652,8 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
|
||||
// 回退顺序:
|
||||
// 1. gpt-5.2-codex -> gpt-5.2(去掉后缀如 -codex, -mini, -max 等)
|
||||
// 2. gpt-5.2-20251222 -> gpt-5.2(去掉日期版本号)
|
||||
// 3. 最终回退到 DefaultTestModel (gpt-5.1-codex)
|
||||
// 3. gpt-5.3-codex -> gpt-5.2-codex
|
||||
// 4. 最终回退到 DefaultTestModel (gpt-5.1-codex)
|
||||
func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing {
|
||||
// 尝试的回退变体
|
||||
variants := s.generateOpenAIModelVariants(model, openAIModelDatePattern)
|
||||
@@ -663,6 +665,13 @@ func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing {
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(model, "gpt-5.3-codex") {
|
||||
if pricing, ok := s.pricingData["gpt-5.2-codex"]; ok {
|
||||
log.Printf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.2-codex")
|
||||
return pricing
|
||||
}
|
||||
}
|
||||
|
||||
// 最终回退到 DefaultTestModel
|
||||
defaultModel := strings.ToLower(openai.DefaultTestModel)
|
||||
if pricing, ok := s.pricingData[defaultModel]; ok {
|
||||
|
||||
@@ -16,6 +16,7 @@ var (
|
||||
type ProxyRepository interface {
|
||||
Create(ctx context.Context, proxy *Proxy) error
|
||||
GetByID(ctx context.Context, id int64) (*Proxy, error)
|
||||
ListByIDs(ctx context.Context, ids []int64) ([]Proxy, error)
|
||||
Update(ctx context.Context, proxy *Proxy) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
|
||||
73
backend/internal/service/refresh_token_cache.go
Normal file
73
backend/internal/service/refresh_token_cache.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrRefreshTokenNotFound is returned when a refresh token is not found in cache.
|
||||
// This is used to abstract away the underlying cache implementation (e.g., redis.Nil).
|
||||
var ErrRefreshTokenNotFound = errors.New("refresh token not found")
|
||||
|
||||
// RefreshTokenData 存储在Redis中的Refresh Token数据
|
||||
type RefreshTokenData struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
TokenVersion int64 `json:"token_version"` // 用于检测密码更改后的Token失效
|
||||
FamilyID string `json:"family_id"` // Token家族ID,用于防重放攻击
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
// RefreshTokenCache 管理Refresh Token的Redis缓存
|
||||
// 用于JWT Token刷新机制,支持Token轮转和防重放攻击
|
||||
//
|
||||
// Key 格式:
|
||||
// - refresh_token:{token_hash} -> RefreshTokenData (JSON)
|
||||
// - user_refresh_tokens:{user_id} -> Set<token_hash>
|
||||
// - token_family:{family_id} -> Set<token_hash>
|
||||
type RefreshTokenCache interface {
|
||||
// StoreRefreshToken 存储Refresh Token
|
||||
// tokenHash: Token的SHA256哈希值(不存储原始Token)
|
||||
// data: Token关联的数据
|
||||
// ttl: Token过期时间
|
||||
StoreRefreshToken(ctx context.Context, tokenHash string, data *RefreshTokenData, ttl time.Duration) error
|
||||
|
||||
// GetRefreshToken 获取Refresh Token数据
|
||||
// 返回 (data, nil) 如果Token存在
|
||||
// 返回 (nil, ErrRefreshTokenNotFound) 如果Token不存在
|
||||
// 返回 (nil, err) 如果发生其他错误
|
||||
GetRefreshToken(ctx context.Context, tokenHash string) (*RefreshTokenData, error)
|
||||
|
||||
// DeleteRefreshToken 删除单个Refresh Token
|
||||
// 用于Token轮转时使旧Token失效
|
||||
DeleteRefreshToken(ctx context.Context, tokenHash string) error
|
||||
|
||||
// DeleteUserRefreshTokens 删除用户的所有Refresh Token
|
||||
// 用于密码更改或用户主动登出所有设备
|
||||
DeleteUserRefreshTokens(ctx context.Context, userID int64) error
|
||||
|
||||
// DeleteTokenFamily 删除整个Token家族
|
||||
// 用于检测到Token重放攻击时,撤销整个会话链
|
||||
DeleteTokenFamily(ctx context.Context, familyID string) error
|
||||
|
||||
// AddToUserTokenSet 将Token添加到用户的Token集合
|
||||
// 用于跟踪用户的所有活跃Refresh Token
|
||||
AddToUserTokenSet(ctx context.Context, userID int64, tokenHash string, ttl time.Duration) error
|
||||
|
||||
// AddToFamilyTokenSet 将Token添加到家族Token集合
|
||||
// 用于跟踪同一登录会话的所有Token
|
||||
AddToFamilyTokenSet(ctx context.Context, familyID string, tokenHash string, ttl time.Duration) error
|
||||
|
||||
// GetUserTokenHashes 获取用户的所有Token哈希
|
||||
// 用于批量删除用户Token
|
||||
GetUserTokenHashes(ctx context.Context, userID int64) ([]string, error)
|
||||
|
||||
// GetFamilyTokenHashes 获取家族的所有Token哈希
|
||||
// 用于批量删除家族Token
|
||||
GetFamilyTokenHashes(ctx context.Context, familyID string) ([]string, error)
|
||||
|
||||
// IsTokenInFamily 检查Token是否属于指定家族
|
||||
// 用于验证Token家族关系
|
||||
IsTokenInFamily(ctx context.Context, familyID string, tokenHash string) (bool, error)
|
||||
}
|
||||
@@ -288,6 +288,15 @@ func (s *UsageService) GetUserDashboardStats(ctx context.Context, userID int64)
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// GetAPIKeyDashboardStats returns dashboard summary stats filtered by API Key.
|
||||
func (s *UsageService) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*usagestats.UserDashboardStats, error) {
|
||||
stats, err := s.usageRepo.GetAPIKeyDashboardStats(ctx, apiKeyID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key dashboard stats: %w", err)
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// GetUserUsageTrendByUserID returns per-user usage trend.
|
||||
func (s *UsageService) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) {
|
||||
trend, err := s.usageRepo.GetUserUsageTrendByUserID(ctx, userID, startTime, endTime, granularity)
|
||||
|
||||
@@ -21,6 +21,10 @@ type User struct {
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
// GroupRates 用户专属分组倍率配置
|
||||
// map[groupID]rateMultiplier
|
||||
GroupRates map[int64]float64
|
||||
|
||||
// TOTP 双因素认证字段
|
||||
TotpSecretEncrypted *string // AES-256-GCM 加密的 TOTP 密钥
|
||||
TotpEnabled bool // 是否启用 TOTP
|
||||
@@ -40,18 +44,20 @@ func (u *User) IsActive() bool {
|
||||
|
||||
// CanBindGroup checks whether a user can bind to a given group.
|
||||
// For standard groups:
|
||||
// - If AllowedGroups is non-empty, only allow binding to IDs in that list.
|
||||
// - If AllowedGroups is empty (nil or length 0), allow binding to any non-exclusive group.
|
||||
// - Public groups (non-exclusive): all users can bind
|
||||
// - Exclusive groups: only users with the group in AllowedGroups can bind
|
||||
func (u *User) CanBindGroup(groupID int64, isExclusive bool) bool {
|
||||
if len(u.AllowedGroups) > 0 {
|
||||
for _, id := range u.AllowedGroups {
|
||||
if id == groupID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
// 公开分组(非专属):所有用户都可以绑定
|
||||
if !isExclusive {
|
||||
return true
|
||||
}
|
||||
return !isExclusive
|
||||
// 专属分组:需要在 AllowedGroups 中
|
||||
for _, id := range u.AllowedGroups {
|
||||
if id == groupID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (u *User) SetPassword(password string) error {
|
||||
|
||||
25
backend/internal/service/user_group_rate.go
Normal file
25
backend/internal/service/user_group_rate.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package service
|
||||
|
||||
import "context"
|
||||
|
||||
// UserGroupRateRepository 用户专属分组倍率仓储接口
|
||||
// 允许管理员为特定用户设置分组的专属计费倍率,覆盖分组默认倍率
|
||||
type UserGroupRateRepository interface {
|
||||
// GetByUserID 获取用户的所有专属分组倍率
|
||||
// 返回 map[groupID]rateMultiplier
|
||||
GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error)
|
||||
|
||||
// GetByUserAndGroup 获取用户在特定分组的专属倍率
|
||||
// 如果未设置专属倍率,返回 nil
|
||||
GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error)
|
||||
|
||||
// SyncUserGroupRates 同步用户的分组专属倍率
|
||||
// rates: map[groupID]*rateMultiplier,nil 表示删除该分组的专属倍率
|
||||
SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error
|
||||
|
||||
// DeleteByGroupID 删除指定分组的所有用户专属倍率(分组删除时调用)
|
||||
DeleteByGroupID(ctx context.Context, groupID int64) error
|
||||
|
||||
// DeleteByUserID 删除指定用户的所有专属倍率(用户删除时调用)
|
||||
DeleteByUserID(ctx context.Context, userID int64) error
|
||||
}
|
||||
@@ -274,4 +274,5 @@ var ProviderSet = wire.NewSet(
|
||||
NewUserAttributeService,
|
||||
NewUsageCache,
|
||||
NewTotpService,
|
||||
NewErrorPassthroughService,
|
||||
)
|
||||
|
||||
19
backend/migrations/047_add_user_group_rate_multipliers.sql
Normal file
19
backend/migrations/047_add_user_group_rate_multipliers.sql
Normal file
@@ -0,0 +1,19 @@
|
||||
-- 用户专属分组倍率表
|
||||
-- 允许管理员为特定用户设置分组的专属计费倍率,覆盖分组默认倍率
|
||||
CREATE TABLE IF NOT EXISTS user_group_rate_multipliers (
|
||||
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
group_id BIGINT NOT NULL REFERENCES groups(id) ON DELETE CASCADE,
|
||||
rate_multiplier DECIMAL(10,4) NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
PRIMARY KEY (user_id, group_id)
|
||||
);
|
||||
|
||||
-- 按 group_id 查询索引(删除分组时清理关联记录)
|
||||
CREATE INDEX IF NOT EXISTS idx_user_group_rate_multipliers_group_id
|
||||
ON user_group_rate_multipliers(group_id);
|
||||
|
||||
COMMENT ON TABLE user_group_rate_multipliers IS '用户专属分组倍率配置';
|
||||
COMMENT ON COLUMN user_group_rate_multipliers.user_id IS '用户ID';
|
||||
COMMENT ON COLUMN user_group_rate_multipliers.group_id IS '分组ID';
|
||||
COMMENT ON COLUMN user_group_rate_multipliers.rate_multiplier IS '专属计费倍率(覆盖分组默认倍率)';
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user