This commit is contained in:
yangjianbo
2026-02-06 06:56:23 +08:00
94 changed files with 10861 additions and 858 deletions

View File

@@ -1,368 +0,0 @@
# Linux DO Connect
OAuthOpen Authorization是一个开放的网络授权标准目前最新版本为 OAuth 2.0。我们日常使用的第三方登录(如 Google 账号登录就采用了该标准。OAuth 允许用户授权第三方应用访问存储在其他服务提供商(如 Google上的信息无需在不同平台上重复填写注册信息。用户授权后平台可以直接访问用户的账户信息进行身份验证而用户无需向第三方应用提供密码。
目前系统已实现完整的 OAuth2 授权码code方式鉴权但界面等配套功能还在持续完善中。让我们一起打造一个更完善的共享方案。
## 基本介绍
这是一套标准的 OAuth2 鉴权系统,可以让开发者共享论坛的用户基本信息。
- 可获取字段:
| 参数 | 说明 |
| ----------------- | ------------------------------- |
| `id` | 用户唯一标识(不可变) |
| `username` | 论坛用户名 |
| `name` | 论坛用户昵称(可变) |
| `avatar_template` | 用户头像模板URL支持多种尺寸 |
| `active` | 账号活跃状态 |
| `trust_level` | 信任等级0-4 |
| `silenced` | 禁言状态 |
| `external_ids` | 外部ID关联信息 |
| `api_key` | API访问密钥 |
通过这些信息,公益网站/接口可以实现:
1. 基于 `id` 的服务频率限制
2. 基于 `trust_level` 的服务额度分配
3. 基于用户信息的滥用举报机制
## 相关端点
- Authorize 端点: `https://connect.linux.do/oauth2/authorize`
- Token 端点:`https://connect.linux.do/oauth2/token`
- 用户信息 端点:`https://connect.linux.do/api/user`
## 申请使用
- 访问 [Connect.Linux.Do](https://connect.linux.do/) 申请接入你的应用。
![linuxdoconnect_1](https://wiki.linux.do/_next/image?url=%2Flinuxdoconnect_1.png&w=1080&q=75)
- 点击 **`我的应用接入`** - **`申请新接入`**,填写相关信息。其中 **`回调地址`** 是你的应用接收用户信息的地址。
![linuxdoconnect_2](https://wiki.linux.do/_next/image?url=%2Flinuxdoconnect_2.png&w=1080&q=75)
- 申请成功后,你将获得 **`Client Id`** 和 **`Client Secret`**,这是你应用的唯一身份凭证。
![linuxdoconnect_3](https://wiki.linux.do/_next/image?url=%2Flinuxdoconnect_3.png&w=1080&q=75)
## 接入 Linux Do
JavaScript
```JavaScript
// 安装第三方请求库(或使用原生的 Fetch API本例中使用 axios
// npm install axios
// 通过 OAuth2 获取 Linux Do 用户信息的参考流程
const axios = require('axios');
const readline = require('readline');
// 配置信息(建议通过环境变量配置,避免使用硬编码)
const CLIENT_ID = '你的 Client ID';
const CLIENT_SECRET = '你的 Client Secret';
const REDIRECT_URI = '你的回调地址';
const AUTH_URL = 'https://connect.linux.do/oauth2/authorize';
const TOKEN_URL = 'https://connect.linux.do/oauth2/token';
const USER_INFO_URL = 'https://connect.linux.do/api/user';
// 第一步:生成授权 URL
function getAuthUrl() {
const params = new URLSearchParams({
client_id: CLIENT_ID,
redirect_uri: REDIRECT_URI,
response_type: 'code',
scope: 'user'
});
return `${AUTH_URL}?${params.toString()}`;
}
// 第二步:获取 code 参数
function getCode() {
return new Promise((resolve) => {
// 本例中使用终端输入来模拟流程,仅供本地测试
// 请在实际应用中替换为真实的处理逻辑
const rl = readline.createInterface({ input: process.stdin, output: process.stdout });
rl.question('从回调 URL 中提取出 code粘贴到此处并按回车', (answer) => {
rl.close();
resolve(answer.trim());
});
});
}
// 第三步:使用 code 参数获取访问令牌
async function getAccessToken(code) {
try {
const form = new URLSearchParams({
client_id: CLIENT_ID,
client_secret: CLIENT_SECRET,
code: code,
redirect_uri: REDIRECT_URI,
grant_type: 'authorization_code'
}).toString();
const response = await axios.post(TOKEN_URL, form, {
// 提醒:需正确配置请求头,否则无法正常获取访问令牌
headers: {
'Content-Type': 'application/x-www-form-urlencoded',
'Accept': 'application/json'
}
});
return response.data;
} catch (error) {
console.error(`获取访问令牌失败:${error.response ? JSON.stringify(error.response.data) : error.message}`);
throw error;
}
}
// 第四步:使用访问令牌获取用户信息
async function getUserInfo(accessToken) {
try {
const response = await axios.get(USER_INFO_URL, {
headers: {
Authorization: `Bearer ${accessToken}`
}
});
return response.data;
} catch (error) {
console.error(`获取用户信息失败:${error.response ? JSON.stringify(error.response.data) : error.message}`);
throw error;
}
}
// 主流程
async function main() {
// 1. 生成授权 URL前端引导用户访问授权页
const authUrl = getAuthUrl();
console.log(`请访问此 URL 授权:${authUrl}
`);
// 2. 用户授权后,从回调 URL 获取 code 参数
const code = await getCode();
try {
// 3. 使用 code 参数获取访问令牌
const tokenData = await getAccessToken(code);
const accessToken = tokenData.access_token;
// 4. 使用访问令牌获取用户信息
if (accessToken) {
const userInfo = await getUserInfo(accessToken);
console.log(`
获取用户信息成功:${JSON.stringify(userInfo, null, 2)}`);
} else {
console.log(`
获取访问令牌失败:${JSON.stringify(tokenData)}`);
}
} catch (error) {
console.error('发生错误:', error);
}
}
```
Python
```python
# 安装第三方请求库,本例中使用 requests
# pip install requests
# 通过 OAuth2 获取 Linux Do 用户信息的参考流程
import requests
import json
# 配置信息(建议通过环境变量配置,避免使用硬编码)
CLIENT_ID = '你的 Client ID'
CLIENT_SECRET = '你的 Client Secret'
REDIRECT_URI = '你的回调地址'
AUTH_URL = 'https://connect.linux.do/oauth2/authorize'
TOKEN_URL = 'https://connect.linux.do/oauth2/token'
USER_INFO_URL = 'https://connect.linux.do/api/user'
# 第一步:生成授权 URL
def get_auth_url():
params = {
'client_id': CLIENT_ID,
'redirect_uri': REDIRECT_URI,
'response_type': 'code',
'scope': 'user'
}
auth_url = f"{AUTH_URL}?{'&'.join(f'{k}={v}' for k, v in params.items())}"
return auth_url
# 第二步:获取 code 参数
def get_code():
# 本例中使用终端输入来模拟流程,仅供本地测试
# 请在实际应用中替换为真实的处理逻辑
return input('从回调 URL 中提取出 code粘贴到此处并按回车').strip()
# 第三步:使用 code 参数获取访问令牌
def get_access_token(code):
try:
data = {
'client_id': CLIENT_ID,
'client_secret': CLIENT_SECRET,
'code': code,
'redirect_uri': REDIRECT_URI,
'grant_type': 'authorization_code'
}
# 提醒:需正确配置请求头,否则无法正常获取访问令牌
headers = {
'Content-Type': 'application/x-www-form-urlencoded',
'Accept': 'application/json'
}
response = requests.post(TOKEN_URL, data=data, headers=headers)
response.raise_for_status()
return response.json()
except requests.exceptions.RequestException as e:
print(f"获取访问令牌失败:{e}")
return None
# 第四步:使用访问令牌获取用户信息
def get_user_info(access_token):
try:
headers = {
'Authorization': f'Bearer {access_token}'
}
response = requests.get(USER_INFO_URL, headers=headers)
response.raise_for_status()
return response.json()
except requests.exceptions.RequestException as e:
print(f"获取用户信息失败:{e}")
return None
# 主流程
if __name__ == '__main__':
# 1. 生成授权 URL前端引导用户访问授权页
auth_url = get_auth_url()
print(f'请访问此 URL 授权:{auth_url}
')
# 2. 用户授权后,从回调 URL 获取 code 参数
code = get_code()
# 3. 使用 code 参数获取访问令牌
token_data = get_access_token(code)
if token_data:
access_token = token_data.get('access_token')
# 4. 使用访问令牌获取用户信息
if access_token:
user_info = get_user_info(access_token)
if user_info:
print(f"
获取用户信息成功{json.dumps(user_info, indent=2)}")
else:
print("
获取用户信息失败")
else:
print(f"
获取访问令牌失败{json.dumps(token_data, indent=2)}")
else:
print("
获取访问令牌失败")
```
PHP
```php
// 通过 OAuth2 获取 Linux Do 用户信息的参考流程
// 配置信息
$CLIENT_ID = '你的 Client ID';
$CLIENT_SECRET = '你的 Client Secret';
$REDIRECT_URI = '你的回调地址';
$AUTH_URL = 'https://connect.linux.do/oauth2/authorize';
$TOKEN_URL = 'https://connect.linux.do/oauth2/token';
$USER_INFO_URL = 'https://connect.linux.do/api/user';
// 生成授权 URL
function getAuthUrl($clientId, $redirectUri) {
global $AUTH_URL;
return $AUTH_URL . '?' . http_build_query([
'client_id' => $clientId,
'redirect_uri' => $redirectUri,
'response_type' => 'code',
'scope' => 'user'
]);
}
// 使用 code 参数获取用户信息(合并获取令牌和获取用户信息的步骤)
function getUserInfoWithCode($code, $clientId, $clientSecret, $redirectUri) {
global $TOKEN_URL, $USER_INFO_URL;
// 1. 获取访问令牌
$ch = curl_init($TOKEN_URL);
curl_setopt($ch, CURLOPT_RETURNTRANSFER, true);
curl_setopt($ch, CURLOPT_POST, true);
curl_setopt($ch, CURLOPT_POSTFIELDS, http_build_query([
'client_id' => $clientId,
'client_secret' => $clientSecret,
'code' => $code,
'redirect_uri' => $redirectUri,
'grant_type' => 'authorization_code'
]));
curl_setopt($ch, CURLOPT_HTTPHEADER, [
'Content-Type: application/x-www-form-urlencoded',
'Accept: application/json'
]);
$tokenResponse = curl_exec($ch);
curl_close($ch);
$tokenData = json_decode($tokenResponse, true);
if (!isset($tokenData['access_token'])) {
return ['error' => '获取访问令牌失败', 'details' => $tokenData];
}
// 2. 获取用户信息
$ch = curl_init($USER_INFO_URL);
curl_setopt($ch, CURLOPT_RETURNTRANSFER, true);
curl_setopt($ch, CURLOPT_HTTPHEADER, [
'Authorization: Bearer ' . $tokenData['access_token']
]);
$userResponse = curl_exec($ch);
curl_close($ch);
return json_decode($userResponse, true);
}
// 主流程
// 1. 生成授权 URL
$authUrl = getAuthUrl($CLIENT_ID, $REDIRECT_URI);
echo "<a href='$authUrl'>使用 Linux Do 登录</a>";
// 2. 处理回调并获取用户信息
if (isset($_GET['code'])) {
$userInfo = getUserInfoWithCode(
$_GET['code'],
$CLIENT_ID,
$CLIENT_SECRET,
$REDIRECT_URI
);
if (isset($userInfo['error'])) {
echo '错误: ' . $userInfo['error'];
} else {
echo '欢迎, ' . $userInfo['name'] . '!';
// 处理用户登录逻辑...
}
}
```
## 使用说明
### 授权流程
1. 用户点击应用中的’使用 Linux Do 登录’按钮
2. 系统将用户重定向至 Linux Do 的授权页面
3. 用户完成授权后,系统自动重定向回应用并携带授权码
4. 应用使用授权码获取访问令牌
5. 使用访问令牌获取用户信息
### 安全建议
- 切勿在前端代码中暴露 Client Secret
- 对所有用户输入数据进行严格验证
- 确保使用 HTTPS 协议传输数据
- 定期更新并妥善保管 Client Secret

View File

@@ -1,164 +0,0 @@
## 概述
全面增强运维监控系统Ops的错误日志管理和告警静默功能优化前端 UI 组件代码质量和用户体验。本次更新重构了核心服务层和数据访问层,提升系统可维护性和运维效率。
## 主要改动
### 1. 错误日志查询优化
**功能特性:**
- 新增 GetErrorLogByID 接口,支持按 ID 精确查询错误详情
- 优化错误日志过滤逻辑,支持多维度筛选(平台、阶段、来源、所有者等)
- 改进查询参数处理,简化代码结构
- 增强错误分类和标准化处理
- 支持错误解决状态追踪resolved 字段)
**技术实现:**
- `ops_handler.go` - 新增单条错误日志查询接口
- `ops_repo.go` - 优化数据查询和过滤条件构建
- `ops_models.go` - 扩展错误日志数据模型
- 前端 API 接口同步更新
### 2. 告警静默功能
**功能特性:**
- 支持按规则、平台、分组、区域等维度静默告警
- 可设置静默时长和原因说明
- 静默记录可追溯,记录创建人和创建时间
- 自动过期机制,避免永久静默
**技术实现:**
- `037_ops_alert_silences.sql` - 新增告警静默表
- `ops_alerts.go` - 告警静默逻辑实现
- `ops_alerts_handler.go` - 告警静默 API 接口
- `OpsAlertEventsCard.vue` - 前端告警静默操作界面
**数据库结构:**
| 字段 | 类型 | 说明 |
|------|------|------|
| rule_id | BIGINT | 告警规则 ID |
| platform | VARCHAR(64) | 平台标识 |
| group_id | BIGINT | 分组 ID可选 |
| region | VARCHAR(64) | 区域(可选) |
| until | TIMESTAMPTZ | 静默截止时间 |
| reason | TEXT | 静默原因 |
| created_by | BIGINT | 创建人 ID |
### 3. 错误分类标准化
**功能特性:**
- 统一错误阶段分类request|auth|routing|upstream|network|internal
- 规范错误归属分类client|provider|platform
- 标准化错误来源分类client_request|upstream_http|gateway
- 自动迁移历史数据到新分类体系
**技术实现:**
- `038_ops_errors_resolution_retry_results_and_standardize_classification.sql` - 分类标准化迁移
- 自动映射历史遗留分类到新标准
- 自动解决已恢复的上游错误(客户端状态码 < 400
### 4. Gateway 服务集成
**功能特性:**
- 完善各 Gateway 服务的 Ops 集成
- 统一错误日志记录接口
- 增强上游错误追踪能力
**涉及服务:**
- `antigravity_gateway_service.go` - Antigravity 网关集成
- `gateway_service.go` - 通用网关集成
- `gemini_messages_compat_service.go` - Gemini 兼容层集成
- `openai_gateway_service.go` - OpenAI 网关集成
### 5. 前端 UI 优化
**代码重构:**
- 大幅简化错误详情模态框代码(从 828 行优化到 450 行)
- 优化错误日志表格组件,提升可读性
- 清理未使用的 i18n 翻译,减少冗余
- 统一组件代码风格和格式
- 优化骨架屏组件,更好匹配实际看板布局
**布局改进:**
- 修复模态框内容溢出和滚动问题
- 优化表格布局,使用 flex 布局确保正确显示
- 改进看板头部布局和交互
- 提升响应式体验
- 骨架屏支持全屏模式适配
**交互优化:**
- 优化告警事件卡片功能和展示
- 改进错误详情展示逻辑
- 增强请求详情模态框
- 完善运行时设置卡片
- 改进加载动画效果
### 6. 国际化完善
**文案补充:**
- 补充错误日志相关的英文翻译
- 添加告警静默功能的中英文文案
- 完善提示文本和错误信息
- 统一术语翻译标准
## 文件变更
**后端26 个文件):**
- `backend/internal/handler/admin/ops_alerts_handler.go` - 告警接口增强
- `backend/internal/handler/admin/ops_handler.go` - 错误日志接口优化
- `backend/internal/handler/ops_error_logger.go` - 错误记录器增强
- `backend/internal/repository/ops_repo.go` - 数据访问层重构
- `backend/internal/repository/ops_repo_alerts.go` - 告警数据访问增强
- `backend/internal/service/ops_*.go` - 核心服务层重构10 个文件)
- `backend/internal/service/*_gateway_service.go` - Gateway 集成4 个文件)
- `backend/internal/server/routes/admin.go` - 路由配置更新
- `backend/migrations/*.sql` - 数据库迁移2 个文件)
- 测试文件更新5 个文件)
**前端13 个文件):**
- `frontend/src/views/admin/ops/OpsDashboard.vue` - 看板主页优化
- `frontend/src/views/admin/ops/components/*.vue` - 组件重构10 个文件)
- `frontend/src/api/admin/ops.ts` - API 接口扩展
- `frontend/src/i18n/locales/*.ts` - 国际化文本2 个文件)
## 代码统计
- 44 个文件修改
- 3733 行新增
- 995 行删除
- 净增加 2738 行
## 核心改进
**可维护性提升:**
- 重构核心服务层,职责更清晰
- 简化前端组件代码,降低复杂度
- 统一代码风格和命名规范
- 清理冗余代码和未使用的翻译
- 标准化错误分类体系
**功能完善:**
- 告警静默功能,减少告警噪音
- 错误日志查询优化,提升运维效率
- Gateway 服务集成完善,统一监控能力
- 错误解决状态追踪,便于问题管理
**用户体验优化:**
- 修复多个 UI 布局问题
- 优化交互流程
- 完善国际化支持
- 提升响应式体验
- 改进加载状态展示
## 测试验证
- ✅ 错误日志查询和过滤功能
- ✅ 告警静默创建和自动过期
- ✅ 错误分类标准化迁移
- ✅ Gateway 服务错误日志记录
- ✅ 前端组件布局和交互
- ✅ 骨架屏全屏模式适配
- ✅ 国际化文本完整性
- ✅ API 接口功能正确性
- ✅ 数据库迁移执行成功

View File

@@ -33,7 +33,7 @@ func main() {
}() }()
userRepo := repository.NewUserRepository(client, sqlDB) userRepo := repository.NewUserRepository(client, sqlDB)
authService := service.NewAuthService(userRepo, nil, cfg, nil, nil, nil, nil, nil) authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()

View File

@@ -44,9 +44,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
} }
userRepository := repository.NewUserRepository(client, db) userRepository := repository.NewUserRepository(client, db)
redeemCodeRepository := repository.NewRedeemCodeRepository(client) redeemCodeRepository := repository.NewRedeemCodeRepository(client)
redisClient := repository.ProvideRedis(configConfig)
refreshTokenCache := repository.NewRefreshTokenCache(redisClient)
settingRepository := repository.NewSettingRepository(client) settingRepository := repository.NewSettingRepository(client)
settingService := service.NewSettingService(settingRepository, configConfig) settingService := service.NewSettingService(settingRepository, configConfig)
redisClient := repository.ProvideRedis(configConfig)
emailCache := repository.NewEmailCache(redisClient) emailCache := repository.NewEmailCache(redisClient)
emailService := service.NewEmailService(settingRepository, emailCache) emailService := service.NewEmailService(settingRepository, emailCache)
turnstileVerifier := repository.NewTurnstileVerifier() turnstileVerifier := repository.NewTurnstileVerifier()
@@ -58,11 +59,12 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig) billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig)
apiKeyRepository := repository.NewAPIKeyRepository(client) apiKeyRepository := repository.NewAPIKeyRepository(client)
groupRepository := repository.NewGroupRepository(client, db) groupRepository := repository.NewGroupRepository(client, db)
userGroupRateRepository := repository.NewUserGroupRateRepository(db)
apiKeyCache := repository.NewAPIKeyCache(redisClient) apiKeyCache := repository.NewAPIKeyCache(redisClient)
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig) apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig)
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
authService := service.NewAuthService(userRepository, redeemCodeRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator) userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
redeemCache := repository.NewRedeemCache(redisClient) redeemCache := repository.NewRedeemCache(redisClient)
@@ -100,7 +102,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
proxyRepository := repository.NewProxyRepository(client, db) proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator) adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator)
adminUserHandler := admin.NewUserHandler(adminService) adminUserHandler := admin.NewUserHandler(adminService)
groupHandler := admin.NewGroupHandler(adminService) groupHandler := admin.NewGroupHandler(adminService)
claudeOAuthClient := repository.NewClaudeOAuthClient() claudeOAuthClient := repository.NewClaudeOAuthClient()
@@ -153,7 +155,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
identityService := service.NewIdentityService(identityCache) identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService) claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache) gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache)
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService) openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
@@ -173,9 +175,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client) userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository) userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService) userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler) errorPassthroughRepository := repository.NewErrorPassthroughRepository(client)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, configConfig) errorPassthroughCache := repository.NewErrorPassthroughCache(redisClient)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, configConfig) errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache)
errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, errorPassthroughService, configConfig)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, errorPassthroughService, configConfig)
soraDirectClient := service.NewSoraDirectClient(configConfig, httpUpstream, openAITokenProvider) soraDirectClient := service.NewSoraDirectClient(configConfig, httpUpstream, openAITokenProvider)
soraMediaStorage := service.ProvideSoraMediaStorage(configConfig) soraMediaStorage := service.ProvideSoraMediaStorage(configConfig)
soraGatewayService := service.NewSoraGatewayService(soraDirectClient, soraMediaStorage, rateLimitService, configConfig) soraGatewayService := service.NewSoraGatewayService(soraDirectClient, soraMediaStorage, rateLimitService, configConfig)

View File

@@ -20,6 +20,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/promocodeusage"
@@ -52,6 +53,8 @@ type Client struct {
Announcement *AnnouncementClient Announcement *AnnouncementClient
// AnnouncementRead is the client for interacting with the AnnouncementRead builders. // AnnouncementRead is the client for interacting with the AnnouncementRead builders.
AnnouncementRead *AnnouncementReadClient AnnouncementRead *AnnouncementReadClient
// ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders.
ErrorPassthroughRule *ErrorPassthroughRuleClient
// Group is the client for interacting with the Group builders. // Group is the client for interacting with the Group builders.
Group *GroupClient Group *GroupClient
// PromoCode is the client for interacting with the PromoCode builders. // PromoCode is the client for interacting with the PromoCode builders.
@@ -94,6 +97,7 @@ func (c *Client) init() {
c.AccountGroup = NewAccountGroupClient(c.config) c.AccountGroup = NewAccountGroupClient(c.config)
c.Announcement = NewAnnouncementClient(c.config) c.Announcement = NewAnnouncementClient(c.config)
c.AnnouncementRead = NewAnnouncementReadClient(c.config) c.AnnouncementRead = NewAnnouncementReadClient(c.config)
c.ErrorPassthroughRule = NewErrorPassthroughRuleClient(c.config)
c.Group = NewGroupClient(c.config) c.Group = NewGroupClient(c.config)
c.PromoCode = NewPromoCodeClient(c.config) c.PromoCode = NewPromoCodeClient(c.config)
c.PromoCodeUsage = NewPromoCodeUsageClient(c.config) c.PromoCodeUsage = NewPromoCodeUsageClient(c.config)
@@ -204,6 +208,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
AccountGroup: NewAccountGroupClient(cfg), AccountGroup: NewAccountGroupClient(cfg),
Announcement: NewAnnouncementClient(cfg), Announcement: NewAnnouncementClient(cfg),
AnnouncementRead: NewAnnouncementReadClient(cfg), AnnouncementRead: NewAnnouncementReadClient(cfg),
ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
Group: NewGroupClient(cfg), Group: NewGroupClient(cfg),
PromoCode: NewPromoCodeClient(cfg), PromoCode: NewPromoCodeClient(cfg),
PromoCodeUsage: NewPromoCodeUsageClient(cfg), PromoCodeUsage: NewPromoCodeUsageClient(cfg),
@@ -241,6 +246,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
AccountGroup: NewAccountGroupClient(cfg), AccountGroup: NewAccountGroupClient(cfg),
Announcement: NewAnnouncementClient(cfg), Announcement: NewAnnouncementClient(cfg),
AnnouncementRead: NewAnnouncementReadClient(cfg), AnnouncementRead: NewAnnouncementReadClient(cfg),
ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
Group: NewGroupClient(cfg), Group: NewGroupClient(cfg),
PromoCode: NewPromoCodeClient(cfg), PromoCode: NewPromoCodeClient(cfg),
PromoCodeUsage: NewPromoCodeUsageClient(cfg), PromoCodeUsage: NewPromoCodeUsageClient(cfg),
@@ -284,9 +290,10 @@ func (c *Client) Close() error {
func (c *Client) Use(hooks ...Hook) { func (c *Client) Use(hooks ...Hook) {
for _, n := range []interface{ Use(...Hook) }{ for _, n := range []interface{ Use(...Hook) }{
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.Setting, c.ErrorPassthroughRule, c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy,
c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup, c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User,
c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
} { } {
n.Use(hooks...) n.Use(hooks...)
} }
@@ -297,9 +304,10 @@ func (c *Client) Use(hooks ...Hook) {
func (c *Client) Intercept(interceptors ...Interceptor) { func (c *Client) Intercept(interceptors ...Interceptor) {
for _, n := range []interface{ Intercept(...Interceptor) }{ for _, n := range []interface{ Intercept(...Interceptor) }{
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.Setting, c.ErrorPassthroughRule, c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy,
c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup, c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User,
c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
} { } {
n.Intercept(interceptors...) n.Intercept(interceptors...)
} }
@@ -318,6 +326,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
return c.Announcement.mutate(ctx, m) return c.Announcement.mutate(ctx, m)
case *AnnouncementReadMutation: case *AnnouncementReadMutation:
return c.AnnouncementRead.mutate(ctx, m) return c.AnnouncementRead.mutate(ctx, m)
case *ErrorPassthroughRuleMutation:
return c.ErrorPassthroughRule.mutate(ctx, m)
case *GroupMutation: case *GroupMutation:
return c.Group.mutate(ctx, m) return c.Group.mutate(ctx, m)
case *PromoCodeMutation: case *PromoCodeMutation:
@@ -1161,6 +1171,139 @@ func (c *AnnouncementReadClient) mutate(ctx context.Context, m *AnnouncementRead
} }
} }
// ErrorPassthroughRuleClient is a client for the ErrorPassthroughRule schema.
type ErrorPassthroughRuleClient struct {
config
}
// NewErrorPassthroughRuleClient returns a client for the ErrorPassthroughRule from the given config.
func NewErrorPassthroughRuleClient(c config) *ErrorPassthroughRuleClient {
return &ErrorPassthroughRuleClient{config: c}
}
// Use adds a list of mutation hooks to the hooks stack.
// A call to `Use(f, g, h)` equals to `errorpassthroughrule.Hooks(f(g(h())))`.
func (c *ErrorPassthroughRuleClient) Use(hooks ...Hook) {
c.hooks.ErrorPassthroughRule = append(c.hooks.ErrorPassthroughRule, hooks...)
}
// Intercept adds a list of query interceptors to the interceptors stack.
// A call to `Intercept(f, g, h)` equals to `errorpassthroughrule.Intercept(f(g(h())))`.
func (c *ErrorPassthroughRuleClient) Intercept(interceptors ...Interceptor) {
c.inters.ErrorPassthroughRule = append(c.inters.ErrorPassthroughRule, interceptors...)
}
// Create returns a builder for creating a ErrorPassthroughRule entity.
func (c *ErrorPassthroughRuleClient) Create() *ErrorPassthroughRuleCreate {
mutation := newErrorPassthroughRuleMutation(c.config, OpCreate)
return &ErrorPassthroughRuleCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// CreateBulk returns a builder for creating a bulk of ErrorPassthroughRule entities.
func (c *ErrorPassthroughRuleClient) CreateBulk(builders ...*ErrorPassthroughRuleCreate) *ErrorPassthroughRuleCreateBulk {
return &ErrorPassthroughRuleCreateBulk{config: c.config, builders: builders}
}
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
// a builder and applies setFunc on it.
func (c *ErrorPassthroughRuleClient) MapCreateBulk(slice any, setFunc func(*ErrorPassthroughRuleCreate, int)) *ErrorPassthroughRuleCreateBulk {
rv := reflect.ValueOf(slice)
if rv.Kind() != reflect.Slice {
return &ErrorPassthroughRuleCreateBulk{err: fmt.Errorf("calling to ErrorPassthroughRuleClient.MapCreateBulk with wrong type %T, need slice", slice)}
}
builders := make([]*ErrorPassthroughRuleCreate, rv.Len())
for i := 0; i < rv.Len(); i++ {
builders[i] = c.Create()
setFunc(builders[i], i)
}
return &ErrorPassthroughRuleCreateBulk{config: c.config, builders: builders}
}
// Update returns an update builder for ErrorPassthroughRule.
func (c *ErrorPassthroughRuleClient) Update() *ErrorPassthroughRuleUpdate {
mutation := newErrorPassthroughRuleMutation(c.config, OpUpdate)
return &ErrorPassthroughRuleUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOne returns an update builder for the given entity.
func (c *ErrorPassthroughRuleClient) UpdateOne(_m *ErrorPassthroughRule) *ErrorPassthroughRuleUpdateOne {
mutation := newErrorPassthroughRuleMutation(c.config, OpUpdateOne, withErrorPassthroughRule(_m))
return &ErrorPassthroughRuleUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOneID returns an update builder for the given id.
func (c *ErrorPassthroughRuleClient) UpdateOneID(id int64) *ErrorPassthroughRuleUpdateOne {
mutation := newErrorPassthroughRuleMutation(c.config, OpUpdateOne, withErrorPassthroughRuleID(id))
return &ErrorPassthroughRuleUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// Delete returns a delete builder for ErrorPassthroughRule.
func (c *ErrorPassthroughRuleClient) Delete() *ErrorPassthroughRuleDelete {
mutation := newErrorPassthroughRuleMutation(c.config, OpDelete)
return &ErrorPassthroughRuleDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// DeleteOne returns a builder for deleting the given entity.
func (c *ErrorPassthroughRuleClient) DeleteOne(_m *ErrorPassthroughRule) *ErrorPassthroughRuleDeleteOne {
return c.DeleteOneID(_m.ID)
}
// DeleteOneID returns a builder for deleting the given entity by its id.
func (c *ErrorPassthroughRuleClient) DeleteOneID(id int64) *ErrorPassthroughRuleDeleteOne {
builder := c.Delete().Where(errorpassthroughrule.ID(id))
builder.mutation.id = &id
builder.mutation.op = OpDeleteOne
return &ErrorPassthroughRuleDeleteOne{builder}
}
// Query returns a query builder for ErrorPassthroughRule.
func (c *ErrorPassthroughRuleClient) Query() *ErrorPassthroughRuleQuery {
return &ErrorPassthroughRuleQuery{
config: c.config,
ctx: &QueryContext{Type: TypeErrorPassthroughRule},
inters: c.Interceptors(),
}
}
// Get returns a ErrorPassthroughRule entity by its id.
func (c *ErrorPassthroughRuleClient) Get(ctx context.Context, id int64) (*ErrorPassthroughRule, error) {
return c.Query().Where(errorpassthroughrule.ID(id)).Only(ctx)
}
// GetX is like Get, but panics if an error occurs.
func (c *ErrorPassthroughRuleClient) GetX(ctx context.Context, id int64) *ErrorPassthroughRule {
obj, err := c.Get(ctx, id)
if err != nil {
panic(err)
}
return obj
}
// Hooks returns the client hooks.
func (c *ErrorPassthroughRuleClient) Hooks() []Hook {
return c.hooks.ErrorPassthroughRule
}
// Interceptors returns the client interceptors.
func (c *ErrorPassthroughRuleClient) Interceptors() []Interceptor {
return c.inters.ErrorPassthroughRule
}
func (c *ErrorPassthroughRuleClient) mutate(ctx context.Context, m *ErrorPassthroughRuleMutation) (Value, error) {
switch m.Op() {
case OpCreate:
return (&ErrorPassthroughRuleCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdate:
return (&ErrorPassthroughRuleUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdateOne:
return (&ErrorPassthroughRuleUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpDelete, OpDeleteOne:
return (&ErrorPassthroughRuleDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
default:
return nil, fmt.Errorf("ent: unknown ErrorPassthroughRule mutation op: %q", m.Op())
}
}
// GroupClient is a client for the Group schema. // GroupClient is a client for the Group schema.
type GroupClient struct { type GroupClient struct {
config config
@@ -3462,16 +3605,16 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription
// hooks and interceptors per client, for fast access. // hooks and interceptors per client, for fast access.
type ( type (
hooks struct { hooks struct {
APIKey, Account, AccountGroup, Announcement, AnnouncementRead, Group, PromoCode, APIKey, Account, AccountGroup, Announcement, AnnouncementRead,
PromoCodeUsage, Proxy, RedeemCode, Setting, UsageCleanupTask, UsageLog, User, ErrorPassthroughRule, Group, PromoCode, PromoCodeUsage, Proxy, RedeemCode,
UserAllowedGroup, UserAttributeDefinition, UserAttributeValue, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
UserSubscription []ent.Hook UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook
} }
inters struct { inters struct {
APIKey, Account, AccountGroup, Announcement, AnnouncementRead, Group, PromoCode, APIKey, Account, AccountGroup, Announcement, AnnouncementRead,
PromoCodeUsage, Proxy, RedeemCode, Setting, UsageCleanupTask, UsageLog, User, ErrorPassthroughRule, Group, PromoCode, PromoCodeUsage, Proxy, RedeemCode,
UserAllowedGroup, UserAttributeDefinition, UserAttributeValue, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
UserSubscription []ent.Interceptor UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor
} }
) )

View File

@@ -17,6 +17,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/promocodeusage"
@@ -95,6 +96,7 @@ func checkColumn(t, c string) error {
accountgroup.Table: accountgroup.ValidColumn, accountgroup.Table: accountgroup.ValidColumn,
announcement.Table: announcement.ValidColumn, announcement.Table: announcement.ValidColumn,
announcementread.Table: announcementread.ValidColumn, announcementread.Table: announcementread.ValidColumn,
errorpassthroughrule.Table: errorpassthroughrule.ValidColumn,
group.Table: group.ValidColumn, group.Table: group.ValidColumn,
promocode.Table: promocode.ValidColumn, promocode.Table: promocode.ValidColumn,
promocodeusage.Table: promocodeusage.ValidColumn, promocodeusage.Table: promocodeusage.ValidColumn,

View File

@@ -0,0 +1,269 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"encoding/json"
"fmt"
"strings"
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
)
// ErrorPassthroughRule is the model entity for the ErrorPassthroughRule schema.
type ErrorPassthroughRule struct {
config `json:"-"`
// ID of the ent.
ID int64 `json:"id,omitempty"`
// CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"`
// UpdatedAt holds the value of the "updated_at" field.
UpdatedAt time.Time `json:"updated_at,omitempty"`
// Name holds the value of the "name" field.
Name string `json:"name,omitempty"`
// Enabled holds the value of the "enabled" field.
Enabled bool `json:"enabled,omitempty"`
// Priority holds the value of the "priority" field.
Priority int `json:"priority,omitempty"`
// ErrorCodes holds the value of the "error_codes" field.
ErrorCodes []int `json:"error_codes,omitempty"`
// Keywords holds the value of the "keywords" field.
Keywords []string `json:"keywords,omitempty"`
// MatchMode holds the value of the "match_mode" field.
MatchMode string `json:"match_mode,omitempty"`
// Platforms holds the value of the "platforms" field.
Platforms []string `json:"platforms,omitempty"`
// PassthroughCode holds the value of the "passthrough_code" field.
PassthroughCode bool `json:"passthrough_code,omitempty"`
// ResponseCode holds the value of the "response_code" field.
ResponseCode *int `json:"response_code,omitempty"`
// PassthroughBody holds the value of the "passthrough_body" field.
PassthroughBody bool `json:"passthrough_body,omitempty"`
// CustomMessage holds the value of the "custom_message" field.
CustomMessage *string `json:"custom_message,omitempty"`
// Description holds the value of the "description" field.
Description *string `json:"description,omitempty"`
selectValues sql.SelectValues
}
// scanValues returns the types for scanning values from sql.Rows.
func (*ErrorPassthroughRule) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case errorpassthroughrule.FieldErrorCodes, errorpassthroughrule.FieldKeywords, errorpassthroughrule.FieldPlatforms:
values[i] = new([]byte)
case errorpassthroughrule.FieldEnabled, errorpassthroughrule.FieldPassthroughCode, errorpassthroughrule.FieldPassthroughBody:
values[i] = new(sql.NullBool)
case errorpassthroughrule.FieldID, errorpassthroughrule.FieldPriority, errorpassthroughrule.FieldResponseCode:
values[i] = new(sql.NullInt64)
case errorpassthroughrule.FieldName, errorpassthroughrule.FieldMatchMode, errorpassthroughrule.FieldCustomMessage, errorpassthroughrule.FieldDescription:
values[i] = new(sql.NullString)
case errorpassthroughrule.FieldCreatedAt, errorpassthroughrule.FieldUpdatedAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
}
}
return values, nil
}
// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the ErrorPassthroughRule fields.
func (_m *ErrorPassthroughRule) assignValues(columns []string, values []any) error {
if m, n := len(values), len(columns); m < n {
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
}
for i := range columns {
switch columns[i] {
case errorpassthroughrule.FieldID:
value, ok := values[i].(*sql.NullInt64)
if !ok {
return fmt.Errorf("unexpected type %T for field id", value)
}
_m.ID = int64(value.Int64)
case errorpassthroughrule.FieldCreatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field created_at", values[i])
} else if value.Valid {
_m.CreatedAt = value.Time
}
case errorpassthroughrule.FieldUpdatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field updated_at", values[i])
} else if value.Valid {
_m.UpdatedAt = value.Time
}
case errorpassthroughrule.FieldName:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field name", values[i])
} else if value.Valid {
_m.Name = value.String
}
case errorpassthroughrule.FieldEnabled:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field enabled", values[i])
} else if value.Valid {
_m.Enabled = value.Bool
}
case errorpassthroughrule.FieldPriority:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field priority", values[i])
} else if value.Valid {
_m.Priority = int(value.Int64)
}
case errorpassthroughrule.FieldErrorCodes:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field error_codes", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.ErrorCodes); err != nil {
return fmt.Errorf("unmarshal field error_codes: %w", err)
}
}
case errorpassthroughrule.FieldKeywords:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field keywords", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.Keywords); err != nil {
return fmt.Errorf("unmarshal field keywords: %w", err)
}
}
case errorpassthroughrule.FieldMatchMode:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field match_mode", values[i])
} else if value.Valid {
_m.MatchMode = value.String
}
case errorpassthroughrule.FieldPlatforms:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field platforms", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.Platforms); err != nil {
return fmt.Errorf("unmarshal field platforms: %w", err)
}
}
case errorpassthroughrule.FieldPassthroughCode:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field passthrough_code", values[i])
} else if value.Valid {
_m.PassthroughCode = value.Bool
}
case errorpassthroughrule.FieldResponseCode:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field response_code", values[i])
} else if value.Valid {
_m.ResponseCode = new(int)
*_m.ResponseCode = int(value.Int64)
}
case errorpassthroughrule.FieldPassthroughBody:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field passthrough_body", values[i])
} else if value.Valid {
_m.PassthroughBody = value.Bool
}
case errorpassthroughrule.FieldCustomMessage:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field custom_message", values[i])
} else if value.Valid {
_m.CustomMessage = new(string)
*_m.CustomMessage = value.String
}
case errorpassthroughrule.FieldDescription:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field description", values[i])
} else if value.Valid {
_m.Description = new(string)
*_m.Description = value.String
}
default:
_m.selectValues.Set(columns[i], values[i])
}
}
return nil
}
// Value returns the ent.Value that was dynamically selected and assigned to the ErrorPassthroughRule.
// This includes values selected through modifiers, order, etc.
func (_m *ErrorPassthroughRule) Value(name string) (ent.Value, error) {
return _m.selectValues.Get(name)
}
// Update returns a builder for updating this ErrorPassthroughRule.
// Note that you need to call ErrorPassthroughRule.Unwrap() before calling this method if this ErrorPassthroughRule
// was returned from a transaction, and the transaction was committed or rolled back.
func (_m *ErrorPassthroughRule) Update() *ErrorPassthroughRuleUpdateOne {
return NewErrorPassthroughRuleClient(_m.config).UpdateOne(_m)
}
// Unwrap unwraps the ErrorPassthroughRule entity that was returned from a transaction after it was closed,
// so that all future queries will be executed through the driver which created the transaction.
func (_m *ErrorPassthroughRule) Unwrap() *ErrorPassthroughRule {
_tx, ok := _m.config.driver.(*txDriver)
if !ok {
panic("ent: ErrorPassthroughRule is not a transactional entity")
}
_m.config.driver = _tx.drv
return _m
}
// String implements the fmt.Stringer.
func (_m *ErrorPassthroughRule) String() string {
var builder strings.Builder
builder.WriteString("ErrorPassthroughRule(")
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteString(", ")
builder.WriteString("updated_at=")
builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
builder.WriteString(", ")
builder.WriteString("name=")
builder.WriteString(_m.Name)
builder.WriteString(", ")
builder.WriteString("enabled=")
builder.WriteString(fmt.Sprintf("%v", _m.Enabled))
builder.WriteString(", ")
builder.WriteString("priority=")
builder.WriteString(fmt.Sprintf("%v", _m.Priority))
builder.WriteString(", ")
builder.WriteString("error_codes=")
builder.WriteString(fmt.Sprintf("%v", _m.ErrorCodes))
builder.WriteString(", ")
builder.WriteString("keywords=")
builder.WriteString(fmt.Sprintf("%v", _m.Keywords))
builder.WriteString(", ")
builder.WriteString("match_mode=")
builder.WriteString(_m.MatchMode)
builder.WriteString(", ")
builder.WriteString("platforms=")
builder.WriteString(fmt.Sprintf("%v", _m.Platforms))
builder.WriteString(", ")
builder.WriteString("passthrough_code=")
builder.WriteString(fmt.Sprintf("%v", _m.PassthroughCode))
builder.WriteString(", ")
if v := _m.ResponseCode; v != nil {
builder.WriteString("response_code=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
builder.WriteString("passthrough_body=")
builder.WriteString(fmt.Sprintf("%v", _m.PassthroughBody))
builder.WriteString(", ")
if v := _m.CustomMessage; v != nil {
builder.WriteString("custom_message=")
builder.WriteString(*v)
}
builder.WriteString(", ")
if v := _m.Description; v != nil {
builder.WriteString("description=")
builder.WriteString(*v)
}
builder.WriteByte(')')
return builder.String()
}
// ErrorPassthroughRules is a parsable slice of ErrorPassthroughRule.
type ErrorPassthroughRules []*ErrorPassthroughRule

View File

@@ -0,0 +1,161 @@
// Code generated by ent, DO NOT EDIT.
package errorpassthroughrule
import (
"time"
"entgo.io/ent/dialect/sql"
)
const (
// Label holds the string label denoting the errorpassthroughrule type in the database.
Label = "error_passthrough_rule"
// FieldID holds the string denoting the id field in the database.
FieldID = "id"
// FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at"
// FieldUpdatedAt holds the string denoting the updated_at field in the database.
FieldUpdatedAt = "updated_at"
// FieldName holds the string denoting the name field in the database.
FieldName = "name"
// FieldEnabled holds the string denoting the enabled field in the database.
FieldEnabled = "enabled"
// FieldPriority holds the string denoting the priority field in the database.
FieldPriority = "priority"
// FieldErrorCodes holds the string denoting the error_codes field in the database.
FieldErrorCodes = "error_codes"
// FieldKeywords holds the string denoting the keywords field in the database.
FieldKeywords = "keywords"
// FieldMatchMode holds the string denoting the match_mode field in the database.
FieldMatchMode = "match_mode"
// FieldPlatforms holds the string denoting the platforms field in the database.
FieldPlatforms = "platforms"
// FieldPassthroughCode holds the string denoting the passthrough_code field in the database.
FieldPassthroughCode = "passthrough_code"
// FieldResponseCode holds the string denoting the response_code field in the database.
FieldResponseCode = "response_code"
// FieldPassthroughBody holds the string denoting the passthrough_body field in the database.
FieldPassthroughBody = "passthrough_body"
// FieldCustomMessage holds the string denoting the custom_message field in the database.
FieldCustomMessage = "custom_message"
// FieldDescription holds the string denoting the description field in the database.
FieldDescription = "description"
// Table holds the table name of the errorpassthroughrule in the database.
Table = "error_passthrough_rules"
)
// Columns holds all SQL columns for errorpassthroughrule fields.
var Columns = []string{
FieldID,
FieldCreatedAt,
FieldUpdatedAt,
FieldName,
FieldEnabled,
FieldPriority,
FieldErrorCodes,
FieldKeywords,
FieldMatchMode,
FieldPlatforms,
FieldPassthroughCode,
FieldResponseCode,
FieldPassthroughBody,
FieldCustomMessage,
FieldDescription,
}
// ValidColumn reports if the column name is valid (part of the table columns).
func ValidColumn(column string) bool {
for i := range Columns {
if column == Columns[i] {
return true
}
}
return false
}
var (
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
DefaultUpdatedAt func() time.Time
// UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
UpdateDefaultUpdatedAt func() time.Time
// NameValidator is a validator for the "name" field. It is called by the builders before save.
NameValidator func(string) error
// DefaultEnabled holds the default value on creation for the "enabled" field.
DefaultEnabled bool
// DefaultPriority holds the default value on creation for the "priority" field.
DefaultPriority int
// DefaultMatchMode holds the default value on creation for the "match_mode" field.
DefaultMatchMode string
// MatchModeValidator is a validator for the "match_mode" field. It is called by the builders before save.
MatchModeValidator func(string) error
// DefaultPassthroughCode holds the default value on creation for the "passthrough_code" field.
DefaultPassthroughCode bool
// DefaultPassthroughBody holds the default value on creation for the "passthrough_body" field.
DefaultPassthroughBody bool
)
// OrderOption defines the ordering options for the ErrorPassthroughRule queries.
type OrderOption func(*sql.Selector)
// ByID orders the results by the id field.
func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
}
// ByUpdatedAt orders the results by the updated_at field.
func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
}
// ByName orders the results by the name field.
func ByName(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldName, opts...).ToFunc()
}
// ByEnabled orders the results by the enabled field.
func ByEnabled(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldEnabled, opts...).ToFunc()
}
// ByPriority orders the results by the priority field.
func ByPriority(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldPriority, opts...).ToFunc()
}
// ByMatchMode orders the results by the match_mode field.
func ByMatchMode(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldMatchMode, opts...).ToFunc()
}
// ByPassthroughCode orders the results by the passthrough_code field.
func ByPassthroughCode(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldPassthroughCode, opts...).ToFunc()
}
// ByResponseCode orders the results by the response_code field.
func ByResponseCode(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldResponseCode, opts...).ToFunc()
}
// ByPassthroughBody orders the results by the passthrough_body field.
func ByPassthroughBody(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldPassthroughBody, opts...).ToFunc()
}
// ByCustomMessage orders the results by the custom_message field.
func ByCustomMessage(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCustomMessage, opts...).ToFunc()
}
// ByDescription orders the results by the description field.
func ByDescription(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDescription, opts...).ToFunc()
}

View File

@@ -0,0 +1,635 @@
// Code generated by ent, DO NOT EDIT.
package errorpassthroughrule
import (
"time"
"entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// ID filters vertices based on their ID field.
func ID(id int64) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldID, id))
}
// IDEQ applies the EQ predicate on the ID field.
func IDEQ(id int64) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldID, id))
}
// IDNEQ applies the NEQ predicate on the ID field.
func IDNEQ(id int64) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldID, id))
}
// IDIn applies the In predicate on the ID field.
func IDIn(ids ...int64) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldIn(FieldID, ids...))
}
// IDNotIn applies the NotIn predicate on the ID field.
func IDNotIn(ids ...int64) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldID, ids...))
}
// IDGT applies the GT predicate on the ID field.
func IDGT(id int64) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldGT(FieldID, id))
}
// IDGTE applies the GTE predicate on the ID field.
func IDGTE(id int64) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldID, id))
}
// IDLT applies the LT predicate on the ID field.
func IDLT(id int64) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldLT(FieldID, id))
}
// IDLTE applies the LTE predicate on the ID field.
func IDLTE(id int64) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldID, id))
}
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
func CreatedAt(v time.Time) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldCreatedAt, v))
}
// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
func UpdatedAt(v time.Time) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldUpdatedAt, v))
}
// Name applies equality check predicate on the "name" field. It's identical to NameEQ.
func Name(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldName, v))
}
// Enabled applies equality check predicate on the "enabled" field. It's identical to EnabledEQ.
func Enabled(v bool) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldEnabled, v))
}
// Priority applies equality check predicate on the "priority" field. It's identical to PriorityEQ.
func Priority(v int) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldPriority, v))
}
// MatchMode applies equality check predicate on the "match_mode" field. It's identical to MatchModeEQ.
func MatchMode(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldMatchMode, v))
}
// PassthroughCode applies equality check predicate on the "passthrough_code" field. It's identical to PassthroughCodeEQ.
func PassthroughCode(v bool) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldPassthroughCode, v))
}
// ResponseCode applies equality check predicate on the "response_code" field. It's identical to ResponseCodeEQ.
func ResponseCode(v int) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldResponseCode, v))
}
// PassthroughBody applies equality check predicate on the "passthrough_body" field. It's identical to PassthroughBodyEQ.
func PassthroughBody(v bool) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldPassthroughBody, v))
}
// CustomMessage applies equality check predicate on the "custom_message" field. It's identical to CustomMessageEQ.
func CustomMessage(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldCustomMessage, v))
}
// Description applies equality check predicate on the "description" field. It's identical to DescriptionEQ.
func Description(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldDescription, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldCreatedAt, v))
}
// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
func CreatedAtNEQ(v time.Time) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldCreatedAt, v))
}
// CreatedAtIn applies the In predicate on the "created_at" field.
func CreatedAtIn(vs ...time.Time) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldIn(FieldCreatedAt, vs...))
}
// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
func CreatedAtNotIn(vs ...time.Time) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldCreatedAt, vs...))
}
// CreatedAtGT applies the GT predicate on the "created_at" field.
func CreatedAtGT(v time.Time) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldGT(FieldCreatedAt, v))
}
// CreatedAtGTE applies the GTE predicate on the "created_at" field.
func CreatedAtGTE(v time.Time) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldCreatedAt, v))
}
// CreatedAtLT applies the LT predicate on the "created_at" field.
func CreatedAtLT(v time.Time) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldLT(FieldCreatedAt, v))
}
// CreatedAtLTE applies the LTE predicate on the "created_at" field.
func CreatedAtLTE(v time.Time) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldCreatedAt, v))
}
// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
func UpdatedAtEQ(v time.Time) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldUpdatedAt, v))
}
// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
func UpdatedAtNEQ(v time.Time) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldUpdatedAt, v))
}
// UpdatedAtIn applies the In predicate on the "updated_at" field.
func UpdatedAtIn(vs ...time.Time) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldIn(FieldUpdatedAt, vs...))
}
// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
func UpdatedAtNotIn(vs ...time.Time) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldUpdatedAt, vs...))
}
// UpdatedAtGT applies the GT predicate on the "updated_at" field.
func UpdatedAtGT(v time.Time) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldGT(FieldUpdatedAt, v))
}
// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
func UpdatedAtGTE(v time.Time) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldUpdatedAt, v))
}
// UpdatedAtLT applies the LT predicate on the "updated_at" field.
func UpdatedAtLT(v time.Time) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldLT(FieldUpdatedAt, v))
}
// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
func UpdatedAtLTE(v time.Time) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldUpdatedAt, v))
}
// NameEQ applies the EQ predicate on the "name" field.
func NameEQ(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldName, v))
}
// NameNEQ applies the NEQ predicate on the "name" field.
func NameNEQ(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldName, v))
}
// NameIn applies the In predicate on the "name" field.
func NameIn(vs ...string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldIn(FieldName, vs...))
}
// NameNotIn applies the NotIn predicate on the "name" field.
func NameNotIn(vs ...string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldName, vs...))
}
// NameGT applies the GT predicate on the "name" field.
func NameGT(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldGT(FieldName, v))
}
// NameGTE applies the GTE predicate on the "name" field.
func NameGTE(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldName, v))
}
// NameLT applies the LT predicate on the "name" field.
func NameLT(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldLT(FieldName, v))
}
// NameLTE applies the LTE predicate on the "name" field.
func NameLTE(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldName, v))
}
// NameContains applies the Contains predicate on the "name" field.
func NameContains(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldContains(FieldName, v))
}
// NameHasPrefix applies the HasPrefix predicate on the "name" field.
func NameHasPrefix(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldHasPrefix(FieldName, v))
}
// NameHasSuffix applies the HasSuffix predicate on the "name" field.
func NameHasSuffix(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldHasSuffix(FieldName, v))
}
// NameEqualFold applies the EqualFold predicate on the "name" field.
func NameEqualFold(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEqualFold(FieldName, v))
}
// NameContainsFold applies the ContainsFold predicate on the "name" field.
func NameContainsFold(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldContainsFold(FieldName, v))
}
// EnabledEQ applies the EQ predicate on the "enabled" field.
func EnabledEQ(v bool) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldEnabled, v))
}
// EnabledNEQ applies the NEQ predicate on the "enabled" field.
func EnabledNEQ(v bool) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldEnabled, v))
}
// PriorityEQ applies the EQ predicate on the "priority" field.
func PriorityEQ(v int) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldPriority, v))
}
// PriorityNEQ applies the NEQ predicate on the "priority" field.
func PriorityNEQ(v int) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldPriority, v))
}
// PriorityIn applies the In predicate on the "priority" field.
func PriorityIn(vs ...int) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldIn(FieldPriority, vs...))
}
// PriorityNotIn applies the NotIn predicate on the "priority" field.
func PriorityNotIn(vs ...int) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldPriority, vs...))
}
// PriorityGT applies the GT predicate on the "priority" field.
func PriorityGT(v int) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldGT(FieldPriority, v))
}
// PriorityGTE applies the GTE predicate on the "priority" field.
func PriorityGTE(v int) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldPriority, v))
}
// PriorityLT applies the LT predicate on the "priority" field.
func PriorityLT(v int) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldLT(FieldPriority, v))
}
// PriorityLTE applies the LTE predicate on the "priority" field.
func PriorityLTE(v int) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldPriority, v))
}
// ErrorCodesIsNil applies the IsNil predicate on the "error_codes" field.
func ErrorCodesIsNil() predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldIsNull(FieldErrorCodes))
}
// ErrorCodesNotNil applies the NotNil predicate on the "error_codes" field.
func ErrorCodesNotNil() predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldNotNull(FieldErrorCodes))
}
// KeywordsIsNil applies the IsNil predicate on the "keywords" field.
func KeywordsIsNil() predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldIsNull(FieldKeywords))
}
// KeywordsNotNil applies the NotNil predicate on the "keywords" field.
func KeywordsNotNil() predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldNotNull(FieldKeywords))
}
// MatchModeEQ applies the EQ predicate on the "match_mode" field.
func MatchModeEQ(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldMatchMode, v))
}
// MatchModeNEQ applies the NEQ predicate on the "match_mode" field.
func MatchModeNEQ(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldMatchMode, v))
}
// MatchModeIn applies the In predicate on the "match_mode" field.
func MatchModeIn(vs ...string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldIn(FieldMatchMode, vs...))
}
// MatchModeNotIn applies the NotIn predicate on the "match_mode" field.
func MatchModeNotIn(vs ...string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldMatchMode, vs...))
}
// MatchModeGT applies the GT predicate on the "match_mode" field.
func MatchModeGT(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldGT(FieldMatchMode, v))
}
// MatchModeGTE applies the GTE predicate on the "match_mode" field.
func MatchModeGTE(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldMatchMode, v))
}
// MatchModeLT applies the LT predicate on the "match_mode" field.
func MatchModeLT(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldLT(FieldMatchMode, v))
}
// MatchModeLTE applies the LTE predicate on the "match_mode" field.
func MatchModeLTE(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldMatchMode, v))
}
// MatchModeContains applies the Contains predicate on the "match_mode" field.
func MatchModeContains(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldContains(FieldMatchMode, v))
}
// MatchModeHasPrefix applies the HasPrefix predicate on the "match_mode" field.
func MatchModeHasPrefix(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldHasPrefix(FieldMatchMode, v))
}
// MatchModeHasSuffix applies the HasSuffix predicate on the "match_mode" field.
func MatchModeHasSuffix(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldHasSuffix(FieldMatchMode, v))
}
// MatchModeEqualFold applies the EqualFold predicate on the "match_mode" field.
func MatchModeEqualFold(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEqualFold(FieldMatchMode, v))
}
// MatchModeContainsFold applies the ContainsFold predicate on the "match_mode" field.
func MatchModeContainsFold(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldContainsFold(FieldMatchMode, v))
}
// PlatformsIsNil applies the IsNil predicate on the "platforms" field.
func PlatformsIsNil() predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldIsNull(FieldPlatforms))
}
// PlatformsNotNil applies the NotNil predicate on the "platforms" field.
func PlatformsNotNil() predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldNotNull(FieldPlatforms))
}
// PassthroughCodeEQ applies the EQ predicate on the "passthrough_code" field.
func PassthroughCodeEQ(v bool) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldPassthroughCode, v))
}
// PassthroughCodeNEQ applies the NEQ predicate on the "passthrough_code" field.
func PassthroughCodeNEQ(v bool) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldPassthroughCode, v))
}
// ResponseCodeEQ applies the EQ predicate on the "response_code" field.
func ResponseCodeEQ(v int) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldResponseCode, v))
}
// ResponseCodeNEQ applies the NEQ predicate on the "response_code" field.
func ResponseCodeNEQ(v int) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldResponseCode, v))
}
// ResponseCodeIn applies the In predicate on the "response_code" field.
func ResponseCodeIn(vs ...int) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldIn(FieldResponseCode, vs...))
}
// ResponseCodeNotIn applies the NotIn predicate on the "response_code" field.
func ResponseCodeNotIn(vs ...int) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldResponseCode, vs...))
}
// ResponseCodeGT applies the GT predicate on the "response_code" field.
func ResponseCodeGT(v int) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldGT(FieldResponseCode, v))
}
// ResponseCodeGTE applies the GTE predicate on the "response_code" field.
func ResponseCodeGTE(v int) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldResponseCode, v))
}
// ResponseCodeLT applies the LT predicate on the "response_code" field.
func ResponseCodeLT(v int) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldLT(FieldResponseCode, v))
}
// ResponseCodeLTE applies the LTE predicate on the "response_code" field.
func ResponseCodeLTE(v int) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldResponseCode, v))
}
// ResponseCodeIsNil applies the IsNil predicate on the "response_code" field.
func ResponseCodeIsNil() predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldIsNull(FieldResponseCode))
}
// ResponseCodeNotNil applies the NotNil predicate on the "response_code" field.
func ResponseCodeNotNil() predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldNotNull(FieldResponseCode))
}
// PassthroughBodyEQ applies the EQ predicate on the "passthrough_body" field.
func PassthroughBodyEQ(v bool) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldPassthroughBody, v))
}
// PassthroughBodyNEQ applies the NEQ predicate on the "passthrough_body" field.
func PassthroughBodyNEQ(v bool) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldPassthroughBody, v))
}
// CustomMessageEQ applies the EQ predicate on the "custom_message" field.
func CustomMessageEQ(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldCustomMessage, v))
}
// CustomMessageNEQ applies the NEQ predicate on the "custom_message" field.
func CustomMessageNEQ(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldCustomMessage, v))
}
// CustomMessageIn applies the In predicate on the "custom_message" field.
func CustomMessageIn(vs ...string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldIn(FieldCustomMessage, vs...))
}
// CustomMessageNotIn applies the NotIn predicate on the "custom_message" field.
func CustomMessageNotIn(vs ...string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldCustomMessage, vs...))
}
// CustomMessageGT applies the GT predicate on the "custom_message" field.
func CustomMessageGT(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldGT(FieldCustomMessage, v))
}
// CustomMessageGTE applies the GTE predicate on the "custom_message" field.
func CustomMessageGTE(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldCustomMessage, v))
}
// CustomMessageLT applies the LT predicate on the "custom_message" field.
func CustomMessageLT(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldLT(FieldCustomMessage, v))
}
// CustomMessageLTE applies the LTE predicate on the "custom_message" field.
func CustomMessageLTE(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldCustomMessage, v))
}
// CustomMessageContains applies the Contains predicate on the "custom_message" field.
func CustomMessageContains(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldContains(FieldCustomMessage, v))
}
// CustomMessageHasPrefix applies the HasPrefix predicate on the "custom_message" field.
func CustomMessageHasPrefix(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldHasPrefix(FieldCustomMessage, v))
}
// CustomMessageHasSuffix applies the HasSuffix predicate on the "custom_message" field.
func CustomMessageHasSuffix(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldHasSuffix(FieldCustomMessage, v))
}
// CustomMessageIsNil applies the IsNil predicate on the "custom_message" field.
func CustomMessageIsNil() predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldIsNull(FieldCustomMessage))
}
// CustomMessageNotNil applies the NotNil predicate on the "custom_message" field.
func CustomMessageNotNil() predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldNotNull(FieldCustomMessage))
}
// CustomMessageEqualFold applies the EqualFold predicate on the "custom_message" field.
func CustomMessageEqualFold(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEqualFold(FieldCustomMessage, v))
}
// CustomMessageContainsFold applies the ContainsFold predicate on the "custom_message" field.
func CustomMessageContainsFold(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldContainsFold(FieldCustomMessage, v))
}
// DescriptionEQ applies the EQ predicate on the "description" field.
func DescriptionEQ(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldDescription, v))
}
// DescriptionNEQ applies the NEQ predicate on the "description" field.
func DescriptionNEQ(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldDescription, v))
}
// DescriptionIn applies the In predicate on the "description" field.
func DescriptionIn(vs ...string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldIn(FieldDescription, vs...))
}
// DescriptionNotIn applies the NotIn predicate on the "description" field.
func DescriptionNotIn(vs ...string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldDescription, vs...))
}
// DescriptionGT applies the GT predicate on the "description" field.
func DescriptionGT(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldGT(FieldDescription, v))
}
// DescriptionGTE applies the GTE predicate on the "description" field.
func DescriptionGTE(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldDescription, v))
}
// DescriptionLT applies the LT predicate on the "description" field.
func DescriptionLT(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldLT(FieldDescription, v))
}
// DescriptionLTE applies the LTE predicate on the "description" field.
func DescriptionLTE(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldDescription, v))
}
// DescriptionContains applies the Contains predicate on the "description" field.
func DescriptionContains(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldContains(FieldDescription, v))
}
// DescriptionHasPrefix applies the HasPrefix predicate on the "description" field.
func DescriptionHasPrefix(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldHasPrefix(FieldDescription, v))
}
// DescriptionHasSuffix applies the HasSuffix predicate on the "description" field.
func DescriptionHasSuffix(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldHasSuffix(FieldDescription, v))
}
// DescriptionIsNil applies the IsNil predicate on the "description" field.
func DescriptionIsNil() predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldIsNull(FieldDescription))
}
// DescriptionNotNil applies the NotNil predicate on the "description" field.
func DescriptionNotNil() predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldNotNull(FieldDescription))
}
// DescriptionEqualFold applies the EqualFold predicate on the "description" field.
func DescriptionEqualFold(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEqualFold(FieldDescription, v))
}
// DescriptionContainsFold applies the ContainsFold predicate on the "description" field.
func DescriptionContainsFold(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldContainsFold(FieldDescription, v))
}
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.ErrorPassthroughRule) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.AndPredicates(predicates...))
}
// Or groups predicates with the OR operator between them.
func Or(predicates ...predicate.ErrorPassthroughRule) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.OrPredicates(predicates...))
}
// Not applies the not operator on the given predicate.
func Not(p predicate.ErrorPassthroughRule) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.NotPredicates(p))
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,88 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// ErrorPassthroughRuleDelete is the builder for deleting a ErrorPassthroughRule entity.
type ErrorPassthroughRuleDelete struct {
config
hooks []Hook
mutation *ErrorPassthroughRuleMutation
}
// Where appends a list predicates to the ErrorPassthroughRuleDelete builder.
func (_d *ErrorPassthroughRuleDelete) Where(ps ...predicate.ErrorPassthroughRule) *ErrorPassthroughRuleDelete {
_d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query and returns how many vertices were deleted.
func (_d *ErrorPassthroughRuleDelete) Exec(ctx context.Context) (int, error) {
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *ErrorPassthroughRuleDelete) ExecX(ctx context.Context) int {
n, err := _d.Exec(ctx)
if err != nil {
panic(err)
}
return n
}
func (_d *ErrorPassthroughRuleDelete) sqlExec(ctx context.Context) (int, error) {
_spec := sqlgraph.NewDeleteSpec(errorpassthroughrule.Table, sqlgraph.NewFieldSpec(errorpassthroughrule.FieldID, field.TypeInt64))
if ps := _d.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
if err != nil && sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
_d.mutation.done = true
return affected, err
}
// ErrorPassthroughRuleDeleteOne is the builder for deleting a single ErrorPassthroughRule entity.
type ErrorPassthroughRuleDeleteOne struct {
_d *ErrorPassthroughRuleDelete
}
// Where appends a list predicates to the ErrorPassthroughRuleDelete builder.
func (_d *ErrorPassthroughRuleDeleteOne) Where(ps ...predicate.ErrorPassthroughRule) *ErrorPassthroughRuleDeleteOne {
_d._d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query.
func (_d *ErrorPassthroughRuleDeleteOne) Exec(ctx context.Context) error {
n, err := _d._d.Exec(ctx)
switch {
case err != nil:
return err
case n == 0:
return &NotFoundError{errorpassthroughrule.Label}
default:
return nil
}
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *ErrorPassthroughRuleDeleteOne) ExecX(ctx context.Context) {
if err := _d.Exec(ctx); err != nil {
panic(err)
}
}

View File

@@ -0,0 +1,564 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"fmt"
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// ErrorPassthroughRuleQuery is the builder for querying ErrorPassthroughRule entities.
type ErrorPassthroughRuleQuery struct {
config
ctx *QueryContext
order []errorpassthroughrule.OrderOption
inters []Interceptor
predicates []predicate.ErrorPassthroughRule
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
}
// Where adds a new predicate for the ErrorPassthroughRuleQuery builder.
func (_q *ErrorPassthroughRuleQuery) Where(ps ...predicate.ErrorPassthroughRule) *ErrorPassthroughRuleQuery {
_q.predicates = append(_q.predicates, ps...)
return _q
}
// Limit the number of records to be returned by this query.
func (_q *ErrorPassthroughRuleQuery) Limit(limit int) *ErrorPassthroughRuleQuery {
_q.ctx.Limit = &limit
return _q
}
// Offset to start from.
func (_q *ErrorPassthroughRuleQuery) Offset(offset int) *ErrorPassthroughRuleQuery {
_q.ctx.Offset = &offset
return _q
}
// Unique configures the query builder to filter duplicate records on query.
// By default, unique is set to true, and can be disabled using this method.
func (_q *ErrorPassthroughRuleQuery) Unique(unique bool) *ErrorPassthroughRuleQuery {
_q.ctx.Unique = &unique
return _q
}
// Order specifies how the records should be ordered.
func (_q *ErrorPassthroughRuleQuery) Order(o ...errorpassthroughrule.OrderOption) *ErrorPassthroughRuleQuery {
_q.order = append(_q.order, o...)
return _q
}
// First returns the first ErrorPassthroughRule entity from the query.
// Returns a *NotFoundError when no ErrorPassthroughRule was found.
func (_q *ErrorPassthroughRuleQuery) First(ctx context.Context) (*ErrorPassthroughRule, error) {
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
if err != nil {
return nil, err
}
if len(nodes) == 0 {
return nil, &NotFoundError{errorpassthroughrule.Label}
}
return nodes[0], nil
}
// FirstX is like First, but panics if an error occurs.
func (_q *ErrorPassthroughRuleQuery) FirstX(ctx context.Context) *ErrorPassthroughRule {
node, err := _q.First(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return node
}
// FirstID returns the first ErrorPassthroughRule ID from the query.
// Returns a *NotFoundError when no ErrorPassthroughRule ID was found.
func (_q *ErrorPassthroughRuleQuery) FirstID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
return
}
if len(ids) == 0 {
err = &NotFoundError{errorpassthroughrule.Label}
return
}
return ids[0], nil
}
// FirstIDX is like FirstID, but panics if an error occurs.
func (_q *ErrorPassthroughRuleQuery) FirstIDX(ctx context.Context) int64 {
id, err := _q.FirstID(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return id
}
// Only returns a single ErrorPassthroughRule entity found by the query, ensuring it only returns one.
// Returns a *NotSingularError when more than one ErrorPassthroughRule entity is found.
// Returns a *NotFoundError when no ErrorPassthroughRule entities are found.
func (_q *ErrorPassthroughRuleQuery) Only(ctx context.Context) (*ErrorPassthroughRule, error) {
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
if err != nil {
return nil, err
}
switch len(nodes) {
case 1:
return nodes[0], nil
case 0:
return nil, &NotFoundError{errorpassthroughrule.Label}
default:
return nil, &NotSingularError{errorpassthroughrule.Label}
}
}
// OnlyX is like Only, but panics if an error occurs.
func (_q *ErrorPassthroughRuleQuery) OnlyX(ctx context.Context) *ErrorPassthroughRule {
node, err := _q.Only(ctx)
if err != nil {
panic(err)
}
return node
}
// OnlyID is like Only, but returns the only ErrorPassthroughRule ID in the query.
// Returns a *NotSingularError when more than one ErrorPassthroughRule ID is found.
// Returns a *NotFoundError when no entities are found.
func (_q *ErrorPassthroughRuleQuery) OnlyID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
return
}
switch len(ids) {
case 1:
id = ids[0]
case 0:
err = &NotFoundError{errorpassthroughrule.Label}
default:
err = &NotSingularError{errorpassthroughrule.Label}
}
return
}
// OnlyIDX is like OnlyID, but panics if an error occurs.
func (_q *ErrorPassthroughRuleQuery) OnlyIDX(ctx context.Context) int64 {
id, err := _q.OnlyID(ctx)
if err != nil {
panic(err)
}
return id
}
// All executes the query and returns a list of ErrorPassthroughRules.
func (_q *ErrorPassthroughRuleQuery) All(ctx context.Context) ([]*ErrorPassthroughRule, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
qr := querierAll[[]*ErrorPassthroughRule, *ErrorPassthroughRuleQuery]()
return withInterceptors[[]*ErrorPassthroughRule](ctx, _q, qr, _q.inters)
}
// AllX is like All, but panics if an error occurs.
func (_q *ErrorPassthroughRuleQuery) AllX(ctx context.Context) []*ErrorPassthroughRule {
nodes, err := _q.All(ctx)
if err != nil {
panic(err)
}
return nodes
}
// IDs executes the query and returns a list of ErrorPassthroughRule IDs.
func (_q *ErrorPassthroughRuleQuery) IDs(ctx context.Context) (ids []int64, err error) {
if _q.ctx.Unique == nil && _q.path != nil {
_q.Unique(true)
}
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
if err = _q.Select(errorpassthroughrule.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
return ids, nil
}
// IDsX is like IDs, but panics if an error occurs.
func (_q *ErrorPassthroughRuleQuery) IDsX(ctx context.Context) []int64 {
ids, err := _q.IDs(ctx)
if err != nil {
panic(err)
}
return ids
}
// Count returns the count of the given query.
func (_q *ErrorPassthroughRuleQuery) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
if err := _q.prepareQuery(ctx); err != nil {
return 0, err
}
return withInterceptors[int](ctx, _q, querierCount[*ErrorPassthroughRuleQuery](), _q.inters)
}
// CountX is like Count, but panics if an error occurs.
func (_q *ErrorPassthroughRuleQuery) CountX(ctx context.Context) int {
count, err := _q.Count(ctx)
if err != nil {
panic(err)
}
return count
}
// Exist returns true if the query has elements in the graph.
func (_q *ErrorPassthroughRuleQuery) Exist(ctx context.Context) (bool, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
switch _, err := _q.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
}
// ExistX is like Exist, but panics if an error occurs.
func (_q *ErrorPassthroughRuleQuery) ExistX(ctx context.Context) bool {
exist, err := _q.Exist(ctx)
if err != nil {
panic(err)
}
return exist
}
// Clone returns a duplicate of the ErrorPassthroughRuleQuery builder, including all associated steps. It can be
// used to prepare common query builders and use them differently after the clone is made.
func (_q *ErrorPassthroughRuleQuery) Clone() *ErrorPassthroughRuleQuery {
if _q == nil {
return nil
}
return &ErrorPassthroughRuleQuery{
config: _q.config,
ctx: _q.ctx.Clone(),
order: append([]errorpassthroughrule.OrderOption{}, _q.order...),
inters: append([]Interceptor{}, _q.inters...),
predicates: append([]predicate.ErrorPassthroughRule{}, _q.predicates...),
// clone intermediate query.
sql: _q.sql.Clone(),
path: _q.path,
}
}
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// Count int `json:"count,omitempty"`
// }
//
// client.ErrorPassthroughRule.Query().
// GroupBy(errorpassthroughrule.FieldCreatedAt).
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (_q *ErrorPassthroughRuleQuery) GroupBy(field string, fields ...string) *ErrorPassthroughRuleGroupBy {
_q.ctx.Fields = append([]string{field}, fields...)
grbuild := &ErrorPassthroughRuleGroupBy{build: _q}
grbuild.flds = &_q.ctx.Fields
grbuild.label = errorpassthroughrule.Label
grbuild.scan = grbuild.Scan
return grbuild
}
// Select allows the selection one or more fields/columns for the given query,
// instead of selecting all fields in the entity.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// }
//
// client.ErrorPassthroughRule.Query().
// Select(errorpassthroughrule.FieldCreatedAt).
// Scan(ctx, &v)
func (_q *ErrorPassthroughRuleQuery) Select(fields ...string) *ErrorPassthroughRuleSelect {
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
sbuild := &ErrorPassthroughRuleSelect{ErrorPassthroughRuleQuery: _q}
sbuild.label = errorpassthroughrule.Label
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
return sbuild
}
// Aggregate returns a ErrorPassthroughRuleSelect configured with the given aggregations.
func (_q *ErrorPassthroughRuleQuery) Aggregate(fns ...AggregateFunc) *ErrorPassthroughRuleSelect {
return _q.Select().Aggregate(fns...)
}
func (_q *ErrorPassthroughRuleQuery) prepareQuery(ctx context.Context) error {
for _, inter := range _q.inters {
if inter == nil {
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
}
if trv, ok := inter.(Traverser); ok {
if err := trv.Traverse(ctx, _q); err != nil {
return err
}
}
}
for _, f := range _q.ctx.Fields {
if !errorpassthroughrule.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
}
if _q.path != nil {
prev, err := _q.path(ctx)
if err != nil {
return err
}
_q.sql = prev
}
return nil
}
func (_q *ErrorPassthroughRuleQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*ErrorPassthroughRule, error) {
var (
nodes = []*ErrorPassthroughRule{}
_spec = _q.querySpec()
)
_spec.ScanValues = func(columns []string) ([]any, error) {
return (*ErrorPassthroughRule).scanValues(nil, columns)
}
_spec.Assign = func(columns []string, values []any) error {
node := &ErrorPassthroughRule{config: _q.config}
nodes = append(nodes, node)
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
return nil, err
}
if len(nodes) == 0 {
return nodes, nil
}
return nodes, nil
}
func (_q *ErrorPassthroughRuleQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
}
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
}
func (_q *ErrorPassthroughRuleQuery) querySpec() *sqlgraph.QuerySpec {
_spec := sqlgraph.NewQuerySpec(errorpassthroughrule.Table, errorpassthroughrule.Columns, sqlgraph.NewFieldSpec(errorpassthroughrule.FieldID, field.TypeInt64))
_spec.From = _q.sql
if unique := _q.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if _q.path != nil {
_spec.Unique = true
}
if fields := _q.ctx.Fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, errorpassthroughrule.FieldID)
for i := range fields {
if fields[i] != errorpassthroughrule.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
}
}
}
if ps := _q.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if limit := _q.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := _q.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := _q.order; len(ps) > 0 {
_spec.Order = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
return _spec
}
func (_q *ErrorPassthroughRuleQuery) sqlQuery(ctx context.Context) *sql.Selector {
builder := sql.Dialect(_q.driver.Dialect())
t1 := builder.Table(errorpassthroughrule.Table)
columns := _q.ctx.Fields
if len(columns) == 0 {
columns = errorpassthroughrule.Columns
}
selector := builder.Select(t1.Columns(columns...)...).From(t1)
if _q.sql != nil {
selector = _q.sql
selector.Select(selector.Columns(columns...)...)
}
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
for _, p := range _q.order {
p(selector)
}
if offset := _q.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := _q.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *ErrorPassthroughRuleQuery) ForUpdate(opts ...sql.LockOption) *ErrorPassthroughRuleQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *ErrorPassthroughRuleQuery) ForShare(opts ...sql.LockOption) *ErrorPassthroughRuleQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// ErrorPassthroughRuleGroupBy is the group-by builder for ErrorPassthroughRule entities.
type ErrorPassthroughRuleGroupBy struct {
selector
build *ErrorPassthroughRuleQuery
}
// Aggregate adds the given aggregation functions to the group-by query.
func (_g *ErrorPassthroughRuleGroupBy) Aggregate(fns ...AggregateFunc) *ErrorPassthroughRuleGroupBy {
_g.fns = append(_g.fns, fns...)
return _g
}
// Scan applies the selector query and scans the result into the given value.
func (_g *ErrorPassthroughRuleGroupBy) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
if err := _g.build.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*ErrorPassthroughRuleQuery, *ErrorPassthroughRuleGroupBy](ctx, _g.build, _g, _g.build.inters, v)
}
func (_g *ErrorPassthroughRuleGroupBy) sqlScan(ctx context.Context, root *ErrorPassthroughRuleQuery, v any) error {
selector := root.sqlQuery(ctx).Select()
aggregation := make([]string, 0, len(_g.fns))
for _, fn := range _g.fns {
aggregation = append(aggregation, fn(selector))
}
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
for _, f := range *_g.flds {
columns = append(columns, selector.C(f))
}
columns = append(columns, aggregation...)
selector.Select(columns...)
}
selector.GroupBy(selector.Columns(*_g.flds...)...)
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// ErrorPassthroughRuleSelect is the builder for selecting fields of ErrorPassthroughRule entities.
type ErrorPassthroughRuleSelect struct {
*ErrorPassthroughRuleQuery
selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (_s *ErrorPassthroughRuleSelect) Aggregate(fns ...AggregateFunc) *ErrorPassthroughRuleSelect {
_s.fns = append(_s.fns, fns...)
return _s
}
// Scan applies the selector query and scans the result into the given value.
func (_s *ErrorPassthroughRuleSelect) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
if err := _s.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*ErrorPassthroughRuleQuery, *ErrorPassthroughRuleSelect](ctx, _s.ErrorPassthroughRuleQuery, _s, _s.inters, v)
}
func (_s *ErrorPassthroughRuleSelect) sqlScan(ctx context.Context, root *ErrorPassthroughRuleQuery, v any) error {
selector := root.sqlQuery(ctx)
aggregation := make([]string, 0, len(_s.fns))
for _, fn := range _s.fns {
aggregation = append(aggregation, fn(selector))
}
switch n := len(*_s.selector.flds); {
case n == 0 && len(aggregation) > 0:
selector.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
selector.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}

View File

@@ -0,0 +1,823 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/dialect/sql/sqljson"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// ErrorPassthroughRuleUpdate is the builder for updating ErrorPassthroughRule entities.
type ErrorPassthroughRuleUpdate struct {
config
hooks []Hook
mutation *ErrorPassthroughRuleMutation
}
// Where appends a list predicates to the ErrorPassthroughRuleUpdate builder.
func (_u *ErrorPassthroughRuleUpdate) Where(ps ...predicate.ErrorPassthroughRule) *ErrorPassthroughRuleUpdate {
_u.mutation.Where(ps...)
return _u
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *ErrorPassthroughRuleUpdate) SetUpdatedAt(v time.Time) *ErrorPassthroughRuleUpdate {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetName sets the "name" field.
func (_u *ErrorPassthroughRuleUpdate) SetName(v string) *ErrorPassthroughRuleUpdate {
_u.mutation.SetName(v)
return _u
}
// SetNillableName sets the "name" field if the given value is not nil.
func (_u *ErrorPassthroughRuleUpdate) SetNillableName(v *string) *ErrorPassthroughRuleUpdate {
if v != nil {
_u.SetName(*v)
}
return _u
}
// SetEnabled sets the "enabled" field.
func (_u *ErrorPassthroughRuleUpdate) SetEnabled(v bool) *ErrorPassthroughRuleUpdate {
_u.mutation.SetEnabled(v)
return _u
}
// SetNillableEnabled sets the "enabled" field if the given value is not nil.
func (_u *ErrorPassthroughRuleUpdate) SetNillableEnabled(v *bool) *ErrorPassthroughRuleUpdate {
if v != nil {
_u.SetEnabled(*v)
}
return _u
}
// SetPriority sets the "priority" field.
func (_u *ErrorPassthroughRuleUpdate) SetPriority(v int) *ErrorPassthroughRuleUpdate {
_u.mutation.ResetPriority()
_u.mutation.SetPriority(v)
return _u
}
// SetNillablePriority sets the "priority" field if the given value is not nil.
func (_u *ErrorPassthroughRuleUpdate) SetNillablePriority(v *int) *ErrorPassthroughRuleUpdate {
if v != nil {
_u.SetPriority(*v)
}
return _u
}
// AddPriority adds value to the "priority" field.
func (_u *ErrorPassthroughRuleUpdate) AddPriority(v int) *ErrorPassthroughRuleUpdate {
_u.mutation.AddPriority(v)
return _u
}
// SetErrorCodes sets the "error_codes" field.
func (_u *ErrorPassthroughRuleUpdate) SetErrorCodes(v []int) *ErrorPassthroughRuleUpdate {
_u.mutation.SetErrorCodes(v)
return _u
}
// AppendErrorCodes appends value to the "error_codes" field.
func (_u *ErrorPassthroughRuleUpdate) AppendErrorCodes(v []int) *ErrorPassthroughRuleUpdate {
_u.mutation.AppendErrorCodes(v)
return _u
}
// ClearErrorCodes clears the value of the "error_codes" field.
func (_u *ErrorPassthroughRuleUpdate) ClearErrorCodes() *ErrorPassthroughRuleUpdate {
_u.mutation.ClearErrorCodes()
return _u
}
// SetKeywords sets the "keywords" field.
func (_u *ErrorPassthroughRuleUpdate) SetKeywords(v []string) *ErrorPassthroughRuleUpdate {
_u.mutation.SetKeywords(v)
return _u
}
// AppendKeywords appends value to the "keywords" field.
func (_u *ErrorPassthroughRuleUpdate) AppendKeywords(v []string) *ErrorPassthroughRuleUpdate {
_u.mutation.AppendKeywords(v)
return _u
}
// ClearKeywords clears the value of the "keywords" field.
func (_u *ErrorPassthroughRuleUpdate) ClearKeywords() *ErrorPassthroughRuleUpdate {
_u.mutation.ClearKeywords()
return _u
}
// SetMatchMode sets the "match_mode" field.
func (_u *ErrorPassthroughRuleUpdate) SetMatchMode(v string) *ErrorPassthroughRuleUpdate {
_u.mutation.SetMatchMode(v)
return _u
}
// SetNillableMatchMode sets the "match_mode" field if the given value is not nil.
func (_u *ErrorPassthroughRuleUpdate) SetNillableMatchMode(v *string) *ErrorPassthroughRuleUpdate {
if v != nil {
_u.SetMatchMode(*v)
}
return _u
}
// SetPlatforms sets the "platforms" field.
func (_u *ErrorPassthroughRuleUpdate) SetPlatforms(v []string) *ErrorPassthroughRuleUpdate {
_u.mutation.SetPlatforms(v)
return _u
}
// AppendPlatforms appends value to the "platforms" field.
func (_u *ErrorPassthroughRuleUpdate) AppendPlatforms(v []string) *ErrorPassthroughRuleUpdate {
_u.mutation.AppendPlatforms(v)
return _u
}
// ClearPlatforms clears the value of the "platforms" field.
func (_u *ErrorPassthroughRuleUpdate) ClearPlatforms() *ErrorPassthroughRuleUpdate {
_u.mutation.ClearPlatforms()
return _u
}
// SetPassthroughCode sets the "passthrough_code" field.
func (_u *ErrorPassthroughRuleUpdate) SetPassthroughCode(v bool) *ErrorPassthroughRuleUpdate {
_u.mutation.SetPassthroughCode(v)
return _u
}
// SetNillablePassthroughCode sets the "passthrough_code" field if the given value is not nil.
func (_u *ErrorPassthroughRuleUpdate) SetNillablePassthroughCode(v *bool) *ErrorPassthroughRuleUpdate {
if v != nil {
_u.SetPassthroughCode(*v)
}
return _u
}
// SetResponseCode sets the "response_code" field.
func (_u *ErrorPassthroughRuleUpdate) SetResponseCode(v int) *ErrorPassthroughRuleUpdate {
_u.mutation.ResetResponseCode()
_u.mutation.SetResponseCode(v)
return _u
}
// SetNillableResponseCode sets the "response_code" field if the given value is not nil.
func (_u *ErrorPassthroughRuleUpdate) SetNillableResponseCode(v *int) *ErrorPassthroughRuleUpdate {
if v != nil {
_u.SetResponseCode(*v)
}
return _u
}
// AddResponseCode adds value to the "response_code" field.
func (_u *ErrorPassthroughRuleUpdate) AddResponseCode(v int) *ErrorPassthroughRuleUpdate {
_u.mutation.AddResponseCode(v)
return _u
}
// ClearResponseCode clears the value of the "response_code" field.
func (_u *ErrorPassthroughRuleUpdate) ClearResponseCode() *ErrorPassthroughRuleUpdate {
_u.mutation.ClearResponseCode()
return _u
}
// SetPassthroughBody sets the "passthrough_body" field.
func (_u *ErrorPassthroughRuleUpdate) SetPassthroughBody(v bool) *ErrorPassthroughRuleUpdate {
_u.mutation.SetPassthroughBody(v)
return _u
}
// SetNillablePassthroughBody sets the "passthrough_body" field if the given value is not nil.
func (_u *ErrorPassthroughRuleUpdate) SetNillablePassthroughBody(v *bool) *ErrorPassthroughRuleUpdate {
if v != nil {
_u.SetPassthroughBody(*v)
}
return _u
}
// SetCustomMessage sets the "custom_message" field.
func (_u *ErrorPassthroughRuleUpdate) SetCustomMessage(v string) *ErrorPassthroughRuleUpdate {
_u.mutation.SetCustomMessage(v)
return _u
}
// SetNillableCustomMessage sets the "custom_message" field if the given value is not nil.
func (_u *ErrorPassthroughRuleUpdate) SetNillableCustomMessage(v *string) *ErrorPassthroughRuleUpdate {
if v != nil {
_u.SetCustomMessage(*v)
}
return _u
}
// ClearCustomMessage clears the value of the "custom_message" field.
func (_u *ErrorPassthroughRuleUpdate) ClearCustomMessage() *ErrorPassthroughRuleUpdate {
_u.mutation.ClearCustomMessage()
return _u
}
// SetDescription sets the "description" field.
func (_u *ErrorPassthroughRuleUpdate) SetDescription(v string) *ErrorPassthroughRuleUpdate {
_u.mutation.SetDescription(v)
return _u
}
// SetNillableDescription sets the "description" field if the given value is not nil.
func (_u *ErrorPassthroughRuleUpdate) SetNillableDescription(v *string) *ErrorPassthroughRuleUpdate {
if v != nil {
_u.SetDescription(*v)
}
return _u
}
// ClearDescription clears the value of the "description" field.
func (_u *ErrorPassthroughRuleUpdate) ClearDescription() *ErrorPassthroughRuleUpdate {
_u.mutation.ClearDescription()
return _u
}
// Mutation returns the ErrorPassthroughRuleMutation object of the builder.
func (_u *ErrorPassthroughRuleUpdate) Mutation() *ErrorPassthroughRuleMutation {
return _u.mutation
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *ErrorPassthroughRuleUpdate) Save(ctx context.Context) (int, error) {
_u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *ErrorPassthroughRuleUpdate) SaveX(ctx context.Context) int {
affected, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return affected
}
// Exec executes the query.
func (_u *ErrorPassthroughRuleUpdate) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *ErrorPassthroughRuleUpdate) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *ErrorPassthroughRuleUpdate) defaults() {
if _, ok := _u.mutation.UpdatedAt(); !ok {
v := errorpassthroughrule.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *ErrorPassthroughRuleUpdate) check() error {
if v, ok := _u.mutation.Name(); ok {
if err := errorpassthroughrule.NameValidator(v); err != nil {
return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ErrorPassthroughRule.name": %w`, err)}
}
}
if v, ok := _u.mutation.MatchMode(); ok {
if err := errorpassthroughrule.MatchModeValidator(v); err != nil {
return &ValidationError{Name: "match_mode", err: fmt.Errorf(`ent: validator failed for field "ErrorPassthroughRule.match_mode": %w`, err)}
}
}
return nil
}
func (_u *ErrorPassthroughRuleUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(errorpassthroughrule.Table, errorpassthroughrule.Columns, sqlgraph.NewFieldSpec(errorpassthroughrule.FieldID, field.TypeInt64))
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(errorpassthroughrule.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.Name(); ok {
_spec.SetField(errorpassthroughrule.FieldName, field.TypeString, value)
}
if value, ok := _u.mutation.Enabled(); ok {
_spec.SetField(errorpassthroughrule.FieldEnabled, field.TypeBool, value)
}
if value, ok := _u.mutation.Priority(); ok {
_spec.SetField(errorpassthroughrule.FieldPriority, field.TypeInt, value)
}
if value, ok := _u.mutation.AddedPriority(); ok {
_spec.AddField(errorpassthroughrule.FieldPriority, field.TypeInt, value)
}
if value, ok := _u.mutation.ErrorCodes(); ok {
_spec.SetField(errorpassthroughrule.FieldErrorCodes, field.TypeJSON, value)
}
if value, ok := _u.mutation.AppendedErrorCodes(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, errorpassthroughrule.FieldErrorCodes, value)
})
}
if _u.mutation.ErrorCodesCleared() {
_spec.ClearField(errorpassthroughrule.FieldErrorCodes, field.TypeJSON)
}
if value, ok := _u.mutation.Keywords(); ok {
_spec.SetField(errorpassthroughrule.FieldKeywords, field.TypeJSON, value)
}
if value, ok := _u.mutation.AppendedKeywords(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, errorpassthroughrule.FieldKeywords, value)
})
}
if _u.mutation.KeywordsCleared() {
_spec.ClearField(errorpassthroughrule.FieldKeywords, field.TypeJSON)
}
if value, ok := _u.mutation.MatchMode(); ok {
_spec.SetField(errorpassthroughrule.FieldMatchMode, field.TypeString, value)
}
if value, ok := _u.mutation.Platforms(); ok {
_spec.SetField(errorpassthroughrule.FieldPlatforms, field.TypeJSON, value)
}
if value, ok := _u.mutation.AppendedPlatforms(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, errorpassthroughrule.FieldPlatforms, value)
})
}
if _u.mutation.PlatformsCleared() {
_spec.ClearField(errorpassthroughrule.FieldPlatforms, field.TypeJSON)
}
if value, ok := _u.mutation.PassthroughCode(); ok {
_spec.SetField(errorpassthroughrule.FieldPassthroughCode, field.TypeBool, value)
}
if value, ok := _u.mutation.ResponseCode(); ok {
_spec.SetField(errorpassthroughrule.FieldResponseCode, field.TypeInt, value)
}
if value, ok := _u.mutation.AddedResponseCode(); ok {
_spec.AddField(errorpassthroughrule.FieldResponseCode, field.TypeInt, value)
}
if _u.mutation.ResponseCodeCleared() {
_spec.ClearField(errorpassthroughrule.FieldResponseCode, field.TypeInt)
}
if value, ok := _u.mutation.PassthroughBody(); ok {
_spec.SetField(errorpassthroughrule.FieldPassthroughBody, field.TypeBool, value)
}
if value, ok := _u.mutation.CustomMessage(); ok {
_spec.SetField(errorpassthroughrule.FieldCustomMessage, field.TypeString, value)
}
if _u.mutation.CustomMessageCleared() {
_spec.ClearField(errorpassthroughrule.FieldCustomMessage, field.TypeString)
}
if value, ok := _u.mutation.Description(); ok {
_spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value)
}
if _u.mutation.DescriptionCleared() {
_spec.ClearField(errorpassthroughrule.FieldDescription, field.TypeString)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{errorpassthroughrule.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return 0, err
}
_u.mutation.done = true
return _node, nil
}
// ErrorPassthroughRuleUpdateOne is the builder for updating a single ErrorPassthroughRule entity.
type ErrorPassthroughRuleUpdateOne struct {
config
fields []string
hooks []Hook
mutation *ErrorPassthroughRuleMutation
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *ErrorPassthroughRuleUpdateOne) SetUpdatedAt(v time.Time) *ErrorPassthroughRuleUpdateOne {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetName sets the "name" field.
func (_u *ErrorPassthroughRuleUpdateOne) SetName(v string) *ErrorPassthroughRuleUpdateOne {
_u.mutation.SetName(v)
return _u
}
// SetNillableName sets the "name" field if the given value is not nil.
func (_u *ErrorPassthroughRuleUpdateOne) SetNillableName(v *string) *ErrorPassthroughRuleUpdateOne {
if v != nil {
_u.SetName(*v)
}
return _u
}
// SetEnabled sets the "enabled" field.
func (_u *ErrorPassthroughRuleUpdateOne) SetEnabled(v bool) *ErrorPassthroughRuleUpdateOne {
_u.mutation.SetEnabled(v)
return _u
}
// SetNillableEnabled sets the "enabled" field if the given value is not nil.
func (_u *ErrorPassthroughRuleUpdateOne) SetNillableEnabled(v *bool) *ErrorPassthroughRuleUpdateOne {
if v != nil {
_u.SetEnabled(*v)
}
return _u
}
// SetPriority sets the "priority" field.
func (_u *ErrorPassthroughRuleUpdateOne) SetPriority(v int) *ErrorPassthroughRuleUpdateOne {
_u.mutation.ResetPriority()
_u.mutation.SetPriority(v)
return _u
}
// SetNillablePriority sets the "priority" field if the given value is not nil.
func (_u *ErrorPassthroughRuleUpdateOne) SetNillablePriority(v *int) *ErrorPassthroughRuleUpdateOne {
if v != nil {
_u.SetPriority(*v)
}
return _u
}
// AddPriority adds value to the "priority" field.
func (_u *ErrorPassthroughRuleUpdateOne) AddPriority(v int) *ErrorPassthroughRuleUpdateOne {
_u.mutation.AddPriority(v)
return _u
}
// SetErrorCodes sets the "error_codes" field.
func (_u *ErrorPassthroughRuleUpdateOne) SetErrorCodes(v []int) *ErrorPassthroughRuleUpdateOne {
_u.mutation.SetErrorCodes(v)
return _u
}
// AppendErrorCodes appends value to the "error_codes" field.
func (_u *ErrorPassthroughRuleUpdateOne) AppendErrorCodes(v []int) *ErrorPassthroughRuleUpdateOne {
_u.mutation.AppendErrorCodes(v)
return _u
}
// ClearErrorCodes clears the value of the "error_codes" field.
func (_u *ErrorPassthroughRuleUpdateOne) ClearErrorCodes() *ErrorPassthroughRuleUpdateOne {
_u.mutation.ClearErrorCodes()
return _u
}
// SetKeywords sets the "keywords" field.
func (_u *ErrorPassthroughRuleUpdateOne) SetKeywords(v []string) *ErrorPassthroughRuleUpdateOne {
_u.mutation.SetKeywords(v)
return _u
}
// AppendKeywords appends value to the "keywords" field.
func (_u *ErrorPassthroughRuleUpdateOne) AppendKeywords(v []string) *ErrorPassthroughRuleUpdateOne {
_u.mutation.AppendKeywords(v)
return _u
}
// ClearKeywords clears the value of the "keywords" field.
func (_u *ErrorPassthroughRuleUpdateOne) ClearKeywords() *ErrorPassthroughRuleUpdateOne {
_u.mutation.ClearKeywords()
return _u
}
// SetMatchMode sets the "match_mode" field.
func (_u *ErrorPassthroughRuleUpdateOne) SetMatchMode(v string) *ErrorPassthroughRuleUpdateOne {
_u.mutation.SetMatchMode(v)
return _u
}
// SetNillableMatchMode sets the "match_mode" field if the given value is not nil.
func (_u *ErrorPassthroughRuleUpdateOne) SetNillableMatchMode(v *string) *ErrorPassthroughRuleUpdateOne {
if v != nil {
_u.SetMatchMode(*v)
}
return _u
}
// SetPlatforms sets the "platforms" field.
func (_u *ErrorPassthroughRuleUpdateOne) SetPlatforms(v []string) *ErrorPassthroughRuleUpdateOne {
_u.mutation.SetPlatforms(v)
return _u
}
// AppendPlatforms appends value to the "platforms" field.
func (_u *ErrorPassthroughRuleUpdateOne) AppendPlatforms(v []string) *ErrorPassthroughRuleUpdateOne {
_u.mutation.AppendPlatforms(v)
return _u
}
// ClearPlatforms clears the value of the "platforms" field.
func (_u *ErrorPassthroughRuleUpdateOne) ClearPlatforms() *ErrorPassthroughRuleUpdateOne {
_u.mutation.ClearPlatforms()
return _u
}
// SetPassthroughCode sets the "passthrough_code" field.
func (_u *ErrorPassthroughRuleUpdateOne) SetPassthroughCode(v bool) *ErrorPassthroughRuleUpdateOne {
_u.mutation.SetPassthroughCode(v)
return _u
}
// SetNillablePassthroughCode sets the "passthrough_code" field if the given value is not nil.
func (_u *ErrorPassthroughRuleUpdateOne) SetNillablePassthroughCode(v *bool) *ErrorPassthroughRuleUpdateOne {
if v != nil {
_u.SetPassthroughCode(*v)
}
return _u
}
// SetResponseCode sets the "response_code" field.
func (_u *ErrorPassthroughRuleUpdateOne) SetResponseCode(v int) *ErrorPassthroughRuleUpdateOne {
_u.mutation.ResetResponseCode()
_u.mutation.SetResponseCode(v)
return _u
}
// SetNillableResponseCode sets the "response_code" field if the given value is not nil.
func (_u *ErrorPassthroughRuleUpdateOne) SetNillableResponseCode(v *int) *ErrorPassthroughRuleUpdateOne {
if v != nil {
_u.SetResponseCode(*v)
}
return _u
}
// AddResponseCode adds value to the "response_code" field.
func (_u *ErrorPassthroughRuleUpdateOne) AddResponseCode(v int) *ErrorPassthroughRuleUpdateOne {
_u.mutation.AddResponseCode(v)
return _u
}
// ClearResponseCode clears the value of the "response_code" field.
func (_u *ErrorPassthroughRuleUpdateOne) ClearResponseCode() *ErrorPassthroughRuleUpdateOne {
_u.mutation.ClearResponseCode()
return _u
}
// SetPassthroughBody sets the "passthrough_body" field.
func (_u *ErrorPassthroughRuleUpdateOne) SetPassthroughBody(v bool) *ErrorPassthroughRuleUpdateOne {
_u.mutation.SetPassthroughBody(v)
return _u
}
// SetNillablePassthroughBody sets the "passthrough_body" field if the given value is not nil.
func (_u *ErrorPassthroughRuleUpdateOne) SetNillablePassthroughBody(v *bool) *ErrorPassthroughRuleUpdateOne {
if v != nil {
_u.SetPassthroughBody(*v)
}
return _u
}
// SetCustomMessage sets the "custom_message" field.
func (_u *ErrorPassthroughRuleUpdateOne) SetCustomMessage(v string) *ErrorPassthroughRuleUpdateOne {
_u.mutation.SetCustomMessage(v)
return _u
}
// SetNillableCustomMessage sets the "custom_message" field if the given value is not nil.
func (_u *ErrorPassthroughRuleUpdateOne) SetNillableCustomMessage(v *string) *ErrorPassthroughRuleUpdateOne {
if v != nil {
_u.SetCustomMessage(*v)
}
return _u
}
// ClearCustomMessage clears the value of the "custom_message" field.
func (_u *ErrorPassthroughRuleUpdateOne) ClearCustomMessage() *ErrorPassthroughRuleUpdateOne {
_u.mutation.ClearCustomMessage()
return _u
}
// SetDescription sets the "description" field.
func (_u *ErrorPassthroughRuleUpdateOne) SetDescription(v string) *ErrorPassthroughRuleUpdateOne {
_u.mutation.SetDescription(v)
return _u
}
// SetNillableDescription sets the "description" field if the given value is not nil.
func (_u *ErrorPassthroughRuleUpdateOne) SetNillableDescription(v *string) *ErrorPassthroughRuleUpdateOne {
if v != nil {
_u.SetDescription(*v)
}
return _u
}
// ClearDescription clears the value of the "description" field.
func (_u *ErrorPassthroughRuleUpdateOne) ClearDescription() *ErrorPassthroughRuleUpdateOne {
_u.mutation.ClearDescription()
return _u
}
// Mutation returns the ErrorPassthroughRuleMutation object of the builder.
func (_u *ErrorPassthroughRuleUpdateOne) Mutation() *ErrorPassthroughRuleMutation {
return _u.mutation
}
// Where appends a list predicates to the ErrorPassthroughRuleUpdate builder.
func (_u *ErrorPassthroughRuleUpdateOne) Where(ps ...predicate.ErrorPassthroughRule) *ErrorPassthroughRuleUpdateOne {
_u.mutation.Where(ps...)
return _u
}
// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func (_u *ErrorPassthroughRuleUpdateOne) Select(field string, fields ...string) *ErrorPassthroughRuleUpdateOne {
_u.fields = append([]string{field}, fields...)
return _u
}
// Save executes the query and returns the updated ErrorPassthroughRule entity.
func (_u *ErrorPassthroughRuleUpdateOne) Save(ctx context.Context) (*ErrorPassthroughRule, error) {
_u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *ErrorPassthroughRuleUpdateOne) SaveX(ctx context.Context) *ErrorPassthroughRule {
node, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return node
}
// Exec executes the query on the entity.
func (_u *ErrorPassthroughRuleUpdateOne) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *ErrorPassthroughRuleUpdateOne) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *ErrorPassthroughRuleUpdateOne) defaults() {
if _, ok := _u.mutation.UpdatedAt(); !ok {
v := errorpassthroughrule.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *ErrorPassthroughRuleUpdateOne) check() error {
if v, ok := _u.mutation.Name(); ok {
if err := errorpassthroughrule.NameValidator(v); err != nil {
return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ErrorPassthroughRule.name": %w`, err)}
}
}
if v, ok := _u.mutation.MatchMode(); ok {
if err := errorpassthroughrule.MatchModeValidator(v); err != nil {
return &ValidationError{Name: "match_mode", err: fmt.Errorf(`ent: validator failed for field "ErrorPassthroughRule.match_mode": %w`, err)}
}
}
return nil
}
func (_u *ErrorPassthroughRuleUpdateOne) sqlSave(ctx context.Context) (_node *ErrorPassthroughRule, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(errorpassthroughrule.Table, errorpassthroughrule.Columns, sqlgraph.NewFieldSpec(errorpassthroughrule.FieldID, field.TypeInt64))
id, ok := _u.mutation.ID()
if !ok {
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "ErrorPassthroughRule.id" for update`)}
}
_spec.Node.ID.Value = id
if fields := _u.fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, errorpassthroughrule.FieldID)
for _, f := range fields {
if !errorpassthroughrule.ValidColumn(f) {
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
if f != errorpassthroughrule.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, f)
}
}
}
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(errorpassthroughrule.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.Name(); ok {
_spec.SetField(errorpassthroughrule.FieldName, field.TypeString, value)
}
if value, ok := _u.mutation.Enabled(); ok {
_spec.SetField(errorpassthroughrule.FieldEnabled, field.TypeBool, value)
}
if value, ok := _u.mutation.Priority(); ok {
_spec.SetField(errorpassthroughrule.FieldPriority, field.TypeInt, value)
}
if value, ok := _u.mutation.AddedPriority(); ok {
_spec.AddField(errorpassthroughrule.FieldPriority, field.TypeInt, value)
}
if value, ok := _u.mutation.ErrorCodes(); ok {
_spec.SetField(errorpassthroughrule.FieldErrorCodes, field.TypeJSON, value)
}
if value, ok := _u.mutation.AppendedErrorCodes(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, errorpassthroughrule.FieldErrorCodes, value)
})
}
if _u.mutation.ErrorCodesCleared() {
_spec.ClearField(errorpassthroughrule.FieldErrorCodes, field.TypeJSON)
}
if value, ok := _u.mutation.Keywords(); ok {
_spec.SetField(errorpassthroughrule.FieldKeywords, field.TypeJSON, value)
}
if value, ok := _u.mutation.AppendedKeywords(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, errorpassthroughrule.FieldKeywords, value)
})
}
if _u.mutation.KeywordsCleared() {
_spec.ClearField(errorpassthroughrule.FieldKeywords, field.TypeJSON)
}
if value, ok := _u.mutation.MatchMode(); ok {
_spec.SetField(errorpassthroughrule.FieldMatchMode, field.TypeString, value)
}
if value, ok := _u.mutation.Platforms(); ok {
_spec.SetField(errorpassthroughrule.FieldPlatforms, field.TypeJSON, value)
}
if value, ok := _u.mutation.AppendedPlatforms(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, errorpassthroughrule.FieldPlatforms, value)
})
}
if _u.mutation.PlatformsCleared() {
_spec.ClearField(errorpassthroughrule.FieldPlatforms, field.TypeJSON)
}
if value, ok := _u.mutation.PassthroughCode(); ok {
_spec.SetField(errorpassthroughrule.FieldPassthroughCode, field.TypeBool, value)
}
if value, ok := _u.mutation.ResponseCode(); ok {
_spec.SetField(errorpassthroughrule.FieldResponseCode, field.TypeInt, value)
}
if value, ok := _u.mutation.AddedResponseCode(); ok {
_spec.AddField(errorpassthroughrule.FieldResponseCode, field.TypeInt, value)
}
if _u.mutation.ResponseCodeCleared() {
_spec.ClearField(errorpassthroughrule.FieldResponseCode, field.TypeInt)
}
if value, ok := _u.mutation.PassthroughBody(); ok {
_spec.SetField(errorpassthroughrule.FieldPassthroughBody, field.TypeBool, value)
}
if value, ok := _u.mutation.CustomMessage(); ok {
_spec.SetField(errorpassthroughrule.FieldCustomMessage, field.TypeString, value)
}
if _u.mutation.CustomMessageCleared() {
_spec.ClearField(errorpassthroughrule.FieldCustomMessage, field.TypeString)
}
if value, ok := _u.mutation.Description(); ok {
_spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value)
}
if _u.mutation.DescriptionCleared() {
_spec.ClearField(errorpassthroughrule.FieldDescription, field.TypeString)
}
_node = &ErrorPassthroughRule{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{errorpassthroughrule.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
_u.mutation.done = true
return _node, nil
}

View File

@@ -69,6 +69,18 @@ func (f AnnouncementReadFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.V
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AnnouncementReadMutation", m) return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AnnouncementReadMutation", m)
} }
// The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary
// function as ErrorPassthroughRule mutator.
type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleMutation) (ent.Value, error)
// Mutate calls f(ctx, m).
func (f ErrorPassthroughRuleFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
if mv, ok := m.(*ent.ErrorPassthroughRuleMutation); ok {
return f(ctx, mv)
}
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.ErrorPassthroughRuleMutation", m)
}
// The GroupFunc type is an adapter to allow the use of ordinary // The GroupFunc type is an adapter to allow the use of ordinary
// function as Group mutator. // function as Group mutator.
type GroupFunc func(context.Context, *ent.GroupMutation) (ent.Value, error) type GroupFunc func(context.Context, *ent.GroupMutation) (ent.Value, error)

View File

@@ -13,6 +13,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocode"
@@ -220,6 +221,33 @@ func (f TraverseAnnouncementRead) Traverse(ctx context.Context, q ent.Query) err
return fmt.Errorf("unexpected query type %T. expect *ent.AnnouncementReadQuery", q) return fmt.Errorf("unexpected query type %T. expect *ent.AnnouncementReadQuery", q)
} }
// The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary function as a Querier.
type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleQuery) (ent.Value, error)
// Query calls f(ctx, q).
func (f ErrorPassthroughRuleFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
if q, ok := q.(*ent.ErrorPassthroughRuleQuery); ok {
return f(ctx, q)
}
return nil, fmt.Errorf("unexpected query type %T. expect *ent.ErrorPassthroughRuleQuery", q)
}
// The TraverseErrorPassthroughRule type is an adapter to allow the use of ordinary function as Traverser.
type TraverseErrorPassthroughRule func(context.Context, *ent.ErrorPassthroughRuleQuery) error
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
func (f TraverseErrorPassthroughRule) Intercept(next ent.Querier) ent.Querier {
return next
}
// Traverse calls f(ctx, q).
func (f TraverseErrorPassthroughRule) Traverse(ctx context.Context, q ent.Query) error {
if q, ok := q.(*ent.ErrorPassthroughRuleQuery); ok {
return f(ctx, q)
}
return fmt.Errorf("unexpected query type %T. expect *ent.ErrorPassthroughRuleQuery", q)
}
// The GroupFunc type is an adapter to allow the use of ordinary function as a Querier. // The GroupFunc type is an adapter to allow the use of ordinary function as a Querier.
type GroupFunc func(context.Context, *ent.GroupQuery) (ent.Value, error) type GroupFunc func(context.Context, *ent.GroupQuery) (ent.Value, error)
@@ -584,6 +612,8 @@ func NewQuery(q ent.Query) (Query, error) {
return &query[*ent.AnnouncementQuery, predicate.Announcement, announcement.OrderOption]{typ: ent.TypeAnnouncement, tq: q}, nil return &query[*ent.AnnouncementQuery, predicate.Announcement, announcement.OrderOption]{typ: ent.TypeAnnouncement, tq: q}, nil
case *ent.AnnouncementReadQuery: case *ent.AnnouncementReadQuery:
return &query[*ent.AnnouncementReadQuery, predicate.AnnouncementRead, announcementread.OrderOption]{typ: ent.TypeAnnouncementRead, tq: q}, nil return &query[*ent.AnnouncementReadQuery, predicate.AnnouncementRead, announcementread.OrderOption]{typ: ent.TypeAnnouncementRead, tq: q}, nil
case *ent.ErrorPassthroughRuleQuery:
return &query[*ent.ErrorPassthroughRuleQuery, predicate.ErrorPassthroughRule, errorpassthroughrule.OrderOption]{typ: ent.TypeErrorPassthroughRule, tq: q}, nil
case *ent.GroupQuery: case *ent.GroupQuery:
return &query[*ent.GroupQuery, predicate.Group, group.OrderOption]{typ: ent.TypeGroup, tq: q}, nil return &query[*ent.GroupQuery, predicate.Group, group.OrderOption]{typ: ent.TypeGroup, tq: q}, nil
case *ent.PromoCodeQuery: case *ent.PromoCodeQuery:

View File

@@ -309,6 +309,42 @@ var (
}, },
}, },
} }
// ErrorPassthroughRulesColumns holds the columns for the "error_passthrough_rules" table.
ErrorPassthroughRulesColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "name", Type: field.TypeString, Size: 100},
{Name: "enabled", Type: field.TypeBool, Default: true},
{Name: "priority", Type: field.TypeInt, Default: 0},
{Name: "error_codes", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "keywords", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "match_mode", Type: field.TypeString, Size: 10, Default: "any"},
{Name: "platforms", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "passthrough_code", Type: field.TypeBool, Default: true},
{Name: "response_code", Type: field.TypeInt, Nullable: true},
{Name: "passthrough_body", Type: field.TypeBool, Default: true},
{Name: "custom_message", Type: field.TypeString, Nullable: true, Size: 2147483647},
{Name: "description", Type: field.TypeString, Nullable: true, Size: 2147483647},
}
// ErrorPassthroughRulesTable holds the schema information for the "error_passthrough_rules" table.
ErrorPassthroughRulesTable = &schema.Table{
Name: "error_passthrough_rules",
Columns: ErrorPassthroughRulesColumns,
PrimaryKey: []*schema.Column{ErrorPassthroughRulesColumns[0]},
Indexes: []*schema.Index{
{
Name: "errorpassthroughrule_enabled",
Unique: false,
Columns: []*schema.Column{ErrorPassthroughRulesColumns[4]},
},
{
Name: "errorpassthroughrule_priority",
Unique: false,
Columns: []*schema.Column{ErrorPassthroughRulesColumns[5]},
},
},
}
// GroupsColumns holds the columns for the "groups" table. // GroupsColumns holds the columns for the "groups" table.
GroupsColumns = []*schema.Column{ GroupsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true}, {Name: "id", Type: field.TypeInt64, Increment: true},
@@ -955,6 +991,7 @@ var (
AccountGroupsTable, AccountGroupsTable,
AnnouncementsTable, AnnouncementsTable,
AnnouncementReadsTable, AnnouncementReadsTable,
ErrorPassthroughRulesTable,
GroupsTable, GroupsTable,
PromoCodesTable, PromoCodesTable,
PromoCodeUsagesTable, PromoCodeUsagesTable,
@@ -994,6 +1031,9 @@ func init() {
AnnouncementReadsTable.Annotation = &entsql.Annotation{ AnnouncementReadsTable.Annotation = &entsql.Annotation{
Table: "announcement_reads", Table: "announcement_reads",
} }
ErrorPassthroughRulesTable.Annotation = &entsql.Annotation{
Table: "error_passthrough_rules",
}
GroupsTable.Annotation = &entsql.Annotation{ GroupsTable.Annotation = &entsql.Annotation{
Table: "groups", Table: "groups",
} }

File diff suppressed because it is too large Load Diff

View File

@@ -21,6 +21,9 @@ type Announcement func(*sql.Selector)
// AnnouncementRead is the predicate function for announcementread builders. // AnnouncementRead is the predicate function for announcementread builders.
type AnnouncementRead func(*sql.Selector) type AnnouncementRead func(*sql.Selector)
// ErrorPassthroughRule is the predicate function for errorpassthroughrule builders.
type ErrorPassthroughRule func(*sql.Selector)
// Group is the predicate function for group builders. // Group is the predicate function for group builders.
type Group func(*sql.Selector) type Group func(*sql.Selector)

View File

@@ -10,6 +10,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/promocodeusage"
@@ -270,6 +271,61 @@ func init() {
announcementreadDescCreatedAt := announcementreadFields[3].Descriptor() announcementreadDescCreatedAt := announcementreadFields[3].Descriptor()
// announcementread.DefaultCreatedAt holds the default value on creation for the created_at field. // announcementread.DefaultCreatedAt holds the default value on creation for the created_at field.
announcementread.DefaultCreatedAt = announcementreadDescCreatedAt.Default.(func() time.Time) announcementread.DefaultCreatedAt = announcementreadDescCreatedAt.Default.(func() time.Time)
errorpassthroughruleMixin := schema.ErrorPassthroughRule{}.Mixin()
errorpassthroughruleMixinFields0 := errorpassthroughruleMixin[0].Fields()
_ = errorpassthroughruleMixinFields0
errorpassthroughruleFields := schema.ErrorPassthroughRule{}.Fields()
_ = errorpassthroughruleFields
// errorpassthroughruleDescCreatedAt is the schema descriptor for created_at field.
errorpassthroughruleDescCreatedAt := errorpassthroughruleMixinFields0[0].Descriptor()
// errorpassthroughrule.DefaultCreatedAt holds the default value on creation for the created_at field.
errorpassthroughrule.DefaultCreatedAt = errorpassthroughruleDescCreatedAt.Default.(func() time.Time)
// errorpassthroughruleDescUpdatedAt is the schema descriptor for updated_at field.
errorpassthroughruleDescUpdatedAt := errorpassthroughruleMixinFields0[1].Descriptor()
// errorpassthroughrule.DefaultUpdatedAt holds the default value on creation for the updated_at field.
errorpassthroughrule.DefaultUpdatedAt = errorpassthroughruleDescUpdatedAt.Default.(func() time.Time)
// errorpassthroughrule.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
errorpassthroughrule.UpdateDefaultUpdatedAt = errorpassthroughruleDescUpdatedAt.UpdateDefault.(func() time.Time)
// errorpassthroughruleDescName is the schema descriptor for name field.
errorpassthroughruleDescName := errorpassthroughruleFields[0].Descriptor()
// errorpassthroughrule.NameValidator is a validator for the "name" field. It is called by the builders before save.
errorpassthroughrule.NameValidator = func() func(string) error {
validators := errorpassthroughruleDescName.Validators
fns := [...]func(string) error{
validators[0].(func(string) error),
validators[1].(func(string) error),
}
return func(name string) error {
for _, fn := range fns {
if err := fn(name); err != nil {
return err
}
}
return nil
}
}()
// errorpassthroughruleDescEnabled is the schema descriptor for enabled field.
errorpassthroughruleDescEnabled := errorpassthroughruleFields[1].Descriptor()
// errorpassthroughrule.DefaultEnabled holds the default value on creation for the enabled field.
errorpassthroughrule.DefaultEnabled = errorpassthroughruleDescEnabled.Default.(bool)
// errorpassthroughruleDescPriority is the schema descriptor for priority field.
errorpassthroughruleDescPriority := errorpassthroughruleFields[2].Descriptor()
// errorpassthroughrule.DefaultPriority holds the default value on creation for the priority field.
errorpassthroughrule.DefaultPriority = errorpassthroughruleDescPriority.Default.(int)
// errorpassthroughruleDescMatchMode is the schema descriptor for match_mode field.
errorpassthroughruleDescMatchMode := errorpassthroughruleFields[5].Descriptor()
// errorpassthroughrule.DefaultMatchMode holds the default value on creation for the match_mode field.
errorpassthroughrule.DefaultMatchMode = errorpassthroughruleDescMatchMode.Default.(string)
// errorpassthroughrule.MatchModeValidator is a validator for the "match_mode" field. It is called by the builders before save.
errorpassthroughrule.MatchModeValidator = errorpassthroughruleDescMatchMode.Validators[0].(func(string) error)
// errorpassthroughruleDescPassthroughCode is the schema descriptor for passthrough_code field.
errorpassthroughruleDescPassthroughCode := errorpassthroughruleFields[7].Descriptor()
// errorpassthroughrule.DefaultPassthroughCode holds the default value on creation for the passthrough_code field.
errorpassthroughrule.DefaultPassthroughCode = errorpassthroughruleDescPassthroughCode.Default.(bool)
// errorpassthroughruleDescPassthroughBody is the schema descriptor for passthrough_body field.
errorpassthroughruleDescPassthroughBody := errorpassthroughruleFields[9].Descriptor()
// errorpassthroughrule.DefaultPassthroughBody holds the default value on creation for the passthrough_body field.
errorpassthroughrule.DefaultPassthroughBody = errorpassthroughruleDescPassthroughBody.Default.(bool)
groupMixin := schema.Group{}.Mixin() groupMixin := schema.Group{}.Mixin()
groupMixinHooks1 := groupMixin[1].Hooks() groupMixinHooks1 := groupMixin[1].Hooks()
group.Hooks[0] = groupMixinHooks1[0] group.Hooks[0] = groupMixinHooks1[0]

View File

@@ -0,0 +1,121 @@
// Package schema 定义 Ent ORM 的数据库 schema。
package schema
import (
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
)
// ErrorPassthroughRule 定义全局错误透传规则的 schema。
//
// 错误透传规则用于控制上游错误如何返回给客户端:
// - 匹配条件:错误码 + 关键词组合
// - 响应行为:透传原始信息 或 自定义错误信息
// - 响应状态码:可指定返回给客户端的状态码
// - 平台范围规则适用的平台Anthropic、OpenAI、Gemini、Antigravity
type ErrorPassthroughRule struct {
ent.Schema
}
// Annotations 返回 schema 的注解配置。
func (ErrorPassthroughRule) Annotations() []schema.Annotation {
return []schema.Annotation{
entsql.Annotation{Table: "error_passthrough_rules"},
}
}
// Mixin 返回该 schema 使用的混入组件。
func (ErrorPassthroughRule) Mixin() []ent.Mixin {
return []ent.Mixin{
mixins.TimeMixin{},
}
}
// Fields 定义错误透传规则实体的所有字段。
func (ErrorPassthroughRule) Fields() []ent.Field {
return []ent.Field{
// name: 规则名称,用于在界面中标识规则
field.String("name").
MaxLen(100).
NotEmpty(),
// enabled: 是否启用该规则
field.Bool("enabled").
Default(true),
// priority: 规则优先级,数值越小优先级越高
// 匹配时按优先级顺序检查,命中第一个匹配的规则
field.Int("priority").
Default(0),
// error_codes: 匹配的错误码列表OR关系
// 例如:[422, 400] 表示匹配 422 或 400 错误码
field.JSON("error_codes", []int{}).
Optional().
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
// keywords: 匹配的关键词列表OR关系
// 例如:["context limit", "model not supported"]
// 关键词匹配不区分大小写
field.JSON("keywords", []string{}).
Optional().
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
// match_mode: 匹配模式
// - "any": 错误码匹配 OR 关键词匹配(任一条件满足即可)
// - "all": 错误码匹配 AND 关键词匹配(所有条件都必须满足)
field.String("match_mode").
MaxLen(10).
Default("any"),
// platforms: 适用平台列表
// 例如:["anthropic", "openai", "gemini", "antigravity"]
// 空列表表示适用于所有平台
field.JSON("platforms", []string{}).
Optional().
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
// passthrough_code: 是否透传上游原始状态码
// true: 使用上游返回的状态码
// false: 使用 response_code 指定的状态码
field.Bool("passthrough_code").
Default(true),
// response_code: 自定义响应状态码
// 当 passthrough_code=false 时使用此状态码
field.Int("response_code").
Optional().
Nillable(),
// passthrough_body: 是否透传上游原始错误信息
// true: 使用上游返回的错误信息
// false: 使用 custom_message 指定的错误信息
field.Bool("passthrough_body").
Default(true),
// custom_message: 自定义错误信息
// 当 passthrough_body=false 时使用此错误信息
field.Text("custom_message").
Optional().
Nillable(),
// description: 规则描述,用于说明规则的用途
field.Text("description").
Optional().
Nillable(),
}
}
// Indexes 定义数据库索引,优化查询性能。
func (ErrorPassthroughRule) Indexes() []ent.Index {
return []ent.Index{
index.Fields("enabled"), // 筛选启用的规则
index.Fields("priority"), // 按优先级排序
}
}

View File

@@ -24,6 +24,8 @@ type Tx struct {
Announcement *AnnouncementClient Announcement *AnnouncementClient
// AnnouncementRead is the client for interacting with the AnnouncementRead builders. // AnnouncementRead is the client for interacting with the AnnouncementRead builders.
AnnouncementRead *AnnouncementReadClient AnnouncementRead *AnnouncementReadClient
// ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders.
ErrorPassthroughRule *ErrorPassthroughRuleClient
// Group is the client for interacting with the Group builders. // Group is the client for interacting with the Group builders.
Group *GroupClient Group *GroupClient
// PromoCode is the client for interacting with the PromoCode builders. // PromoCode is the client for interacting with the PromoCode builders.
@@ -186,6 +188,7 @@ func (tx *Tx) init() {
tx.AccountGroup = NewAccountGroupClient(tx.config) tx.AccountGroup = NewAccountGroupClient(tx.config)
tx.Announcement = NewAnnouncementClient(tx.config) tx.Announcement = NewAnnouncementClient(tx.config)
tx.AnnouncementRead = NewAnnouncementReadClient(tx.config) tx.AnnouncementRead = NewAnnouncementReadClient(tx.config)
tx.ErrorPassthroughRule = NewErrorPassthroughRuleClient(tx.config)
tx.Group = NewGroupClient(tx.config) tx.Group = NewGroupClient(tx.config)
tx.PromoCode = NewPromoCodeClient(tx.config) tx.PromoCode = NewPromoCodeClient(tx.config)
tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config) tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config)

View File

@@ -25,10 +25,10 @@ require (
github.com/tidwall/gjson v1.18.0 github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5 github.com/tidwall/sjson v1.2.5
github.com/zeromicro/go-zero v1.9.4 github.com/zeromicro/go-zero v1.9.4
golang.org/x/crypto v0.46.0 golang.org/x/crypto v0.47.0
golang.org/x/net v0.48.0 golang.org/x/net v0.49.0
golang.org/x/sync v0.19.0 golang.org/x/sync v0.19.0
golang.org/x/term v0.38.0 golang.org/x/term v0.39.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
modernc.org/sqlite v1.44.3 modernc.org/sqlite v1.44.3
) )
@@ -75,12 +75,10 @@ require (
github.com/goccy/go-json v0.10.2 // indirect github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-cmp v0.7.0 // indirect github.com/google/go-cmp v0.7.0 // indirect
github.com/google/go-querystring v1.1.0 // indirect github.com/google/go-querystring v1.1.0 // indirect
github.com/google/subcommands v1.2.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect
github.com/hashicorp/hcl/v2 v2.18.1 // indirect github.com/hashicorp/hcl/v2 v2.18.1 // indirect
github.com/icholy/digest v1.1.0 // indirect github.com/icholy/digest v1.1.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.18.2 // indirect github.com/klauspost/compress v1.18.2 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect
@@ -89,7 +87,6 @@ require (
github.com/magiconair/properties v1.8.10 // indirect github.com/magiconair/properties v1.8.10 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-runewidth v0.0.15 // indirect
github.com/mdelapenya/tlscert v0.2.0 // indirect github.com/mdelapenya/tlscert v0.2.0 // indirect
github.com/mitchellh/go-wordwrap v1.0.1 // indirect github.com/mitchellh/go-wordwrap v1.0.1 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect
@@ -104,7 +101,6 @@ require (
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/morikuni/aec v1.0.0 // indirect github.com/morikuni/aec v1.0.0 // indirect
github.com/ncruces/go-strftime v1.0.0 // indirect github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/olekukonko/tablewriter v0.0.5 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect
@@ -114,7 +110,6 @@ require (
github.com/quic-go/qpack v0.6.0 // indirect github.com/quic-go/qpack v0.6.0 // indirect
github.com/quic-go/quic-go v0.57.1 // indirect github.com/quic-go/quic-go v0.57.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect github.com/sirupsen/logrus v1.9.3 // indirect
@@ -122,7 +117,6 @@ require (
github.com/spaolacci/murmur3 v1.1.0 // indirect github.com/spaolacci/murmur3 v1.1.0 // indirect
github.com/spf13/afero v1.11.0 // indirect github.com/spf13/afero v1.11.0 // indirect
github.com/spf13/cast v1.6.0 // indirect github.com/spf13/cast v1.6.0 // indirect
github.com/spf13/cobra v1.7.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/pflag v1.0.5 // indirect
github.com/subosito/gotenv v1.6.0 // indirect github.com/subosito/gotenv v1.6.0 // indirect
github.com/testcontainers/testcontainers-go v0.40.0 // indirect github.com/testcontainers/testcontainers-go v0.40.0 // indirect
@@ -146,10 +140,9 @@ require (
go.uber.org/multierr v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect
golang.org/x/arch v0.3.0 // indirect golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/mod v0.30.0 // indirect golang.org/x/mod v0.31.0 // indirect
golang.org/x/sys v0.39.0 // indirect golang.org/x/sys v0.40.0 // indirect
golang.org/x/text v0.32.0 // indirect golang.org/x/text v0.33.0 // indirect
golang.org/x/tools v0.39.0 // indirect
google.golang.org/grpc v1.75.1 // indirect google.golang.org/grpc v1.75.1 // indirect
google.golang.org/protobuf v1.36.10 // indirect google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect

View File

@@ -46,7 +46,6 @@ github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpS
github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw= github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw=
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA= github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc= github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -117,8 +116,6 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4= github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
@@ -138,8 +135,6 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y= github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI= github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00= github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@@ -175,7 +170,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
@@ -246,7 +240,6 @@ github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ= github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ=
github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4=
github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE=
@@ -350,14 +343,14 @@ go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTV
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk= golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc= golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -369,20 +362,16 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q= golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY=
golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg= golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww=
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ= golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
golang.org/x/tools/go/expect v0.1.1-deprecated h1:jpBZDwmgPhXsKZC6WhL20P4b/wmnpsEAGHaNy0n/rJM=
golang.org/x/tools/go/expect v0.1.1-deprecated/go.mod h1:eihoPOH+FgIqa3FpoTwguz/bVUSGBlGQU67vpBeOrBY=
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated h1:1h2MnaIAIXISqTFKdENegdpAgUXz6NrPEsbIeWaBRvM=
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated/go.mod h1:RVAQXBGNv1ib0J382/DPCRS/BPnsGebyM1Gj5VSDpG8=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ= google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ=
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU= google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU=

View File

@@ -145,12 +145,24 @@ type PricingConfig struct {
} }
type ServerConfig struct { type ServerConfig struct {
Host string `mapstructure:"host"` Host string `mapstructure:"host"`
Port int `mapstructure:"port"` Port int `mapstructure:"port"`
Mode string `mapstructure:"mode"` // debug/release Mode string `mapstructure:"mode"` // debug/release
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒) ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒) IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表CIDR/IP TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表CIDR/IP
MaxRequestBodySize int64 `mapstructure:"max_request_body_size"` // 全局最大请求体限制
H2C H2CConfig `mapstructure:"h2c"` // HTTP/2 Cleartext 配置
}
// H2CConfig HTTP/2 Cleartext 配置
type H2CConfig struct {
Enabled bool `mapstructure:"enabled"` // 是否启用 H2C
MaxConcurrentStreams uint32 `mapstructure:"max_concurrent_streams"` // 最大并发流数量
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲超时(秒)
MaxReadFrameSize int `mapstructure:"max_read_frame_size"` // 最大帧大小(字节)
MaxUploadBufferPerConnection int `mapstructure:"max_upload_buffer_per_connection"` // 每个连接的上传缓冲区(字节)
MaxUploadBufferPerStream int `mapstructure:"max_upload_buffer_per_stream"` // 每个流的上传缓冲区(字节)
} }
type CORSConfig struct { type CORSConfig struct {
@@ -532,6 +544,13 @@ type OpsMetricsCollectorCacheConfig struct {
type JWTConfig struct { type JWTConfig struct {
Secret string `mapstructure:"secret"` Secret string `mapstructure:"secret"`
ExpireHour int `mapstructure:"expire_hour"` ExpireHour int `mapstructure:"expire_hour"`
// AccessTokenExpireMinutes: Access Token有效期分钟默认15分钟
// 短有效期减少被盗用风险配合Refresh Token实现无感续期
AccessTokenExpireMinutes int `mapstructure:"access_token_expire_minutes"`
// RefreshTokenExpireDays: Refresh Token有效期默认30天
RefreshTokenExpireDays int `mapstructure:"refresh_token_expire_days"`
// RefreshWindowMinutes: 刷新窗口分钟在Access Token过期前多久开始允许刷新
RefreshWindowMinutes int `mapstructure:"refresh_window_minutes"`
} }
// TotpConfig TOTP 双因素认证配置 // TotpConfig TOTP 双因素认证配置
@@ -752,6 +771,14 @@ func setDefaults() {
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头 viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时 viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
viper.SetDefault("server.trusted_proxies", []string{}) viper.SetDefault("server.trusted_proxies", []string{})
viper.SetDefault("server.max_request_body_size", int64(100*1024*1024))
// H2C 默认配置
viper.SetDefault("server.h2c.enabled", false)
viper.SetDefault("server.h2c.max_concurrent_streams", uint32(50)) // 50 个并发流
viper.SetDefault("server.h2c.idle_timeout", 75) // 75 秒
viper.SetDefault("server.h2c.max_read_frame_size", 1<<20) // 1MB够用
viper.SetDefault("server.h2c.max_upload_buffer_per_connection", 2<<20) // 2MB
viper.SetDefault("server.h2c.max_upload_buffer_per_stream", 512<<10) // 512KB
// CORS // CORS
viper.SetDefault("cors.allowed_origins", []string{}) viper.SetDefault("cors.allowed_origins", []string{})
@@ -848,6 +875,9 @@ func setDefaults() {
// JWT // JWT
viper.SetDefault("jwt.secret", "") viper.SetDefault("jwt.secret", "")
viper.SetDefault("jwt.expire_hour", 24) viper.SetDefault("jwt.expire_hour", 24)
viper.SetDefault("jwt.access_token_expire_minutes", 360) // 6小时Access Token有效期
viper.SetDefault("jwt.refresh_token_expire_days", 30) // 30天Refresh Token有效期
viper.SetDefault("jwt.refresh_window_minutes", 2) // 过期前2分钟开始允许刷新
// TOTP // TOTP
viper.SetDefault("totp.encryption_key", "") viper.SetDefault("totp.encryption_key", "")
@@ -1009,6 +1039,22 @@ func (c *Config) Validate() error {
if c.JWT.ExpireHour > 24 { if c.JWT.ExpireHour > 24 {
log.Printf("Warning: jwt.expire_hour is %d hours (> 24). Consider shorter expiration for security.", c.JWT.ExpireHour) log.Printf("Warning: jwt.expire_hour is %d hours (> 24). Consider shorter expiration for security.", c.JWT.ExpireHour)
} }
// JWT Refresh Token配置验证
if c.JWT.AccessTokenExpireMinutes <= 0 {
return fmt.Errorf("jwt.access_token_expire_minutes must be positive")
}
if c.JWT.AccessTokenExpireMinutes > 720 {
log.Printf("Warning: jwt.access_token_expire_minutes is %d (> 720). Consider shorter expiration for security.", c.JWT.AccessTokenExpireMinutes)
}
if c.JWT.RefreshTokenExpireDays <= 0 {
return fmt.Errorf("jwt.refresh_token_expire_days must be positive")
}
if c.JWT.RefreshTokenExpireDays > 90 {
log.Printf("Warning: jwt.refresh_token_expire_days is %d (> 90). Consider shorter expiration for security.", c.JWT.RefreshTokenExpireDays)
}
if c.JWT.RefreshWindowMinutes < 0 {
return fmt.Errorf("jwt.refresh_window_minutes must be non-negative")
}
if c.Security.CSP.Enabled && strings.TrimSpace(c.Security.CSP.Policy) == "" { if c.Security.CSP.Enabled && strings.TrimSpace(c.Security.CSP.Policy) == "" {
return fmt.Errorf("security.csp.policy is required when CSP is enabled") return fmt.Errorf("security.csp.policy is required when CSP is enabled")
} }

View File

@@ -0,0 +1,273 @@
package admin
import (
"strconv"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// ErrorPassthroughHandler 处理错误透传规则的 HTTP 请求
type ErrorPassthroughHandler struct {
service *service.ErrorPassthroughService
}
// NewErrorPassthroughHandler 创建错误透传规则处理器
func NewErrorPassthroughHandler(service *service.ErrorPassthroughService) *ErrorPassthroughHandler {
return &ErrorPassthroughHandler{service: service}
}
// CreateErrorPassthroughRuleRequest 创建规则请求
type CreateErrorPassthroughRuleRequest struct {
Name string `json:"name" binding:"required"`
Enabled *bool `json:"enabled"`
Priority int `json:"priority"`
ErrorCodes []int `json:"error_codes"`
Keywords []string `json:"keywords"`
MatchMode string `json:"match_mode"`
Platforms []string `json:"platforms"`
PassthroughCode *bool `json:"passthrough_code"`
ResponseCode *int `json:"response_code"`
PassthroughBody *bool `json:"passthrough_body"`
CustomMessage *string `json:"custom_message"`
Description *string `json:"description"`
}
// UpdateErrorPassthroughRuleRequest 更新规则请求(部分更新,所有字段可选)
type UpdateErrorPassthroughRuleRequest struct {
Name *string `json:"name"`
Enabled *bool `json:"enabled"`
Priority *int `json:"priority"`
ErrorCodes []int `json:"error_codes"`
Keywords []string `json:"keywords"`
MatchMode *string `json:"match_mode"`
Platforms []string `json:"platforms"`
PassthroughCode *bool `json:"passthrough_code"`
ResponseCode *int `json:"response_code"`
PassthroughBody *bool `json:"passthrough_body"`
CustomMessage *string `json:"custom_message"`
Description *string `json:"description"`
}
// List 获取所有规则
// GET /api/v1/admin/error-passthrough-rules
func (h *ErrorPassthroughHandler) List(c *gin.Context) {
rules, err := h.service.List(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, rules)
}
// GetByID 根据 ID 获取规则
// GET /api/v1/admin/error-passthrough-rules/:id
func (h *ErrorPassthroughHandler) GetByID(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid rule ID")
return
}
rule, err := h.service.GetByID(c.Request.Context(), id)
if err != nil {
response.ErrorFrom(c, err)
return
}
if rule == nil {
response.NotFound(c, "Rule not found")
return
}
response.Success(c, rule)
}
// Create 创建规则
// POST /api/v1/admin/error-passthrough-rules
func (h *ErrorPassthroughHandler) Create(c *gin.Context) {
var req CreateErrorPassthroughRuleRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
rule := &model.ErrorPassthroughRule{
Name: req.Name,
Priority: req.Priority,
ErrorCodes: req.ErrorCodes,
Keywords: req.Keywords,
Platforms: req.Platforms,
}
// 设置默认值
if req.Enabled != nil {
rule.Enabled = *req.Enabled
} else {
rule.Enabled = true
}
if req.MatchMode != "" {
rule.MatchMode = req.MatchMode
} else {
rule.MatchMode = model.MatchModeAny
}
if req.PassthroughCode != nil {
rule.PassthroughCode = *req.PassthroughCode
} else {
rule.PassthroughCode = true
}
if req.PassthroughBody != nil {
rule.PassthroughBody = *req.PassthroughBody
} else {
rule.PassthroughBody = true
}
rule.ResponseCode = req.ResponseCode
rule.CustomMessage = req.CustomMessage
rule.Description = req.Description
// 确保切片不为 nil
if rule.ErrorCodes == nil {
rule.ErrorCodes = []int{}
}
if rule.Keywords == nil {
rule.Keywords = []string{}
}
if rule.Platforms == nil {
rule.Platforms = []string{}
}
created, err := h.service.Create(c.Request.Context(), rule)
if err != nil {
if _, ok := err.(*model.ValidationError); ok {
response.BadRequest(c, err.Error())
return
}
response.ErrorFrom(c, err)
return
}
response.Success(c, created)
}
// Update 更新规则(支持部分更新)
// PUT /api/v1/admin/error-passthrough-rules/:id
func (h *ErrorPassthroughHandler) Update(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid rule ID")
return
}
var req UpdateErrorPassthroughRuleRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// 先获取现有规则
existing, err := h.service.GetByID(c.Request.Context(), id)
if err != nil {
response.ErrorFrom(c, err)
return
}
if existing == nil {
response.NotFound(c, "Rule not found")
return
}
// 部分更新:只更新请求中提供的字段
rule := &model.ErrorPassthroughRule{
ID: id,
Name: existing.Name,
Enabled: existing.Enabled,
Priority: existing.Priority,
ErrorCodes: existing.ErrorCodes,
Keywords: existing.Keywords,
MatchMode: existing.MatchMode,
Platforms: existing.Platforms,
PassthroughCode: existing.PassthroughCode,
ResponseCode: existing.ResponseCode,
PassthroughBody: existing.PassthroughBody,
CustomMessage: existing.CustomMessage,
Description: existing.Description,
}
// 应用请求中提供的更新
if req.Name != nil {
rule.Name = *req.Name
}
if req.Enabled != nil {
rule.Enabled = *req.Enabled
}
if req.Priority != nil {
rule.Priority = *req.Priority
}
if req.ErrorCodes != nil {
rule.ErrorCodes = req.ErrorCodes
}
if req.Keywords != nil {
rule.Keywords = req.Keywords
}
if req.MatchMode != nil {
rule.MatchMode = *req.MatchMode
}
if req.Platforms != nil {
rule.Platforms = req.Platforms
}
if req.PassthroughCode != nil {
rule.PassthroughCode = *req.PassthroughCode
}
if req.ResponseCode != nil {
rule.ResponseCode = req.ResponseCode
}
if req.PassthroughBody != nil {
rule.PassthroughBody = *req.PassthroughBody
}
if req.CustomMessage != nil {
rule.CustomMessage = req.CustomMessage
}
if req.Description != nil {
rule.Description = req.Description
}
// 确保切片不为 nil
if rule.ErrorCodes == nil {
rule.ErrorCodes = []int{}
}
if rule.Keywords == nil {
rule.Keywords = []string{}
}
if rule.Platforms == nil {
rule.Platforms = []string{}
}
updated, err := h.service.Update(c.Request.Context(), rule)
if err != nil {
if _, ok := err.(*model.ValidationError); ok {
response.BadRequest(c, err.Error())
return
}
response.ErrorFrom(c, err)
return
}
response.Success(c, updated)
}
// Delete 删除规则
// DELETE /api/v1/admin/error-passthrough-rules/:id
func (h *ErrorPassthroughHandler) Delete(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid rule ID")
return
}
if err := h.service.Delete(c.Request.Context(), id); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Rule deleted successfully"})
}

View File

@@ -45,6 +45,9 @@ type UpdateUserRequest struct {
Concurrency *int `json:"concurrency"` Concurrency *int `json:"concurrency"`
Status string `json:"status" binding:"omitempty,oneof=active disabled"` Status string `json:"status" binding:"omitempty,oneof=active disabled"`
AllowedGroups *[]int64 `json:"allowed_groups"` AllowedGroups *[]int64 `json:"allowed_groups"`
// GroupRates 用户专属分组倍率配置
// map[groupID]*ratenil 表示删除该分组的专属倍率
GroupRates map[int64]*float64 `json:"group_rates"`
} }
// UpdateBalanceRequest represents balance update request // UpdateBalanceRequest represents balance update request
@@ -183,6 +186,7 @@ func (h *UserHandler) Update(c *gin.Context) {
Concurrency: req.Concurrency, Concurrency: req.Concurrency,
Status: req.Status, Status: req.Status,
AllowedGroups: req.AllowedGroups, AllowedGroups: req.AllowedGroups,
GroupRates: req.GroupRates,
}) })
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)

View File

@@ -243,3 +243,21 @@ func (h *APIKeyHandler) GetAvailableGroups(c *gin.Context) {
} }
response.Success(c, out) response.Success(c, out)
} }
// GetUserGroupRates 获取当前用户的专属分组倍率配置
// GET /api/v1/groups/rates
func (h *APIKeyHandler) GetUserGroupRates(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
rates, err := h.apiKeyService.GetUserGroupRates(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, rates)
}

View File

@@ -68,9 +68,39 @@ type LoginRequest struct {
// AuthResponse 认证响应格式(匹配前端期望) // AuthResponse 认证响应格式(匹配前端期望)
type AuthResponse struct { type AuthResponse struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
TokenType string `json:"token_type"` RefreshToken string `json:"refresh_token,omitempty"` // 新增Refresh Token
User *dto.User `json:"user"` ExpiresIn int `json:"expires_in,omitempty"` // 新增Access Token有效期
TokenType string `json:"token_type"`
User *dto.User `json:"user"`
}
// respondWithTokenPair 生成 Token 对并返回认证响应
// 如果 Token 对生成失败,回退到只返回 Access Token向后兼容
func (h *AuthHandler) respondWithTokenPair(c *gin.Context, user *service.User) {
tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "")
if err != nil {
slog.Error("failed to generate token pair", "error", err, "user_id", user.ID)
// 回退到只返回Access Token
token, tokenErr := h.authService.GenerateToken(user)
if tokenErr != nil {
response.InternalError(c, "Failed to generate token")
return
}
response.Success(c, AuthResponse{
AccessToken: token,
TokenType: "Bearer",
User: dto.UserFromService(user),
})
return
}
response.Success(c, AuthResponse{
AccessToken: tokenPair.AccessToken,
RefreshToken: tokenPair.RefreshToken,
ExpiresIn: tokenPair.ExpiresIn,
TokenType: "Bearer",
User: dto.UserFromService(user),
})
} }
// Register handles user registration // Register handles user registration
@@ -90,17 +120,13 @@ func (h *AuthHandler) Register(c *gin.Context) {
} }
} }
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode) _, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Success(c, AuthResponse{ h.respondWithTokenPair(c, user)
AccessToken: token,
TokenType: "Bearer",
User: dto.UserFromService(user),
})
} }
// SendVerifyCode 发送邮箱验证码 // SendVerifyCode 发送邮箱验证码
@@ -150,6 +176,7 @@ func (h *AuthHandler) Login(c *gin.Context) {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
_ = token // token 由 authService.Login 返回但此处由 respondWithTokenPair 重新生成
// Check if TOTP 2FA is enabled for this user // Check if TOTP 2FA is enabled for this user
if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled { if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled {
@@ -168,11 +195,7 @@ func (h *AuthHandler) Login(c *gin.Context) {
return return
} }
response.Success(c, AuthResponse{ h.respondWithTokenPair(c, user)
AccessToken: token,
TokenType: "Bearer",
User: dto.UserFromService(user),
})
} }
// TotpLoginResponse represents the response when 2FA is required // TotpLoginResponse represents the response when 2FA is required
@@ -238,18 +261,7 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
return return
} }
// Generate the JWT token h.respondWithTokenPair(c, user)
token, err := h.authService.GenerateToken(user)
if err != nil {
response.InternalError(c, "Failed to generate token")
return
}
response.Success(c, AuthResponse{
AccessToken: token,
TokenType: "Bearer",
User: dto.UserFromService(user),
})
} }
// GetCurrentUser handles getting current authenticated user // GetCurrentUser handles getting current authenticated user
@@ -491,3 +503,96 @@ func (h *AuthHandler) ResetPassword(c *gin.Context) {
Message: "Your password has been reset successfully. You can now log in with your new password.", Message: "Your password has been reset successfully. You can now log in with your new password.",
}) })
} }
// ==================== Token Refresh Endpoints ====================
// RefreshTokenRequest 刷新Token请求
type RefreshTokenRequest struct {
RefreshToken string `json:"refresh_token" binding:"required"`
}
// RefreshTokenResponse 刷新Token响应
type RefreshTokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"` // Access Token有效期
TokenType string `json:"token_type"`
}
// RefreshToken 刷新Token
// POST /api/v1/auth/refresh
func (h *AuthHandler) RefreshToken(c *gin.Context) {
var req RefreshTokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
tokenPair, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, RefreshTokenResponse{
AccessToken: tokenPair.AccessToken,
RefreshToken: tokenPair.RefreshToken,
ExpiresIn: tokenPair.ExpiresIn,
TokenType: "Bearer",
})
}
// LogoutRequest 登出请求
type LogoutRequest struct {
RefreshToken string `json:"refresh_token,omitempty"` // 可选撤销指定的Refresh Token
}
// LogoutResponse 登出响应
type LogoutResponse struct {
Message string `json:"message"`
}
// Logout 用户登出
// POST /api/v1/auth/logout
func (h *AuthHandler) Logout(c *gin.Context) {
var req LogoutRequest
// 允许空请求体(向后兼容)
_ = c.ShouldBindJSON(&req)
// 如果提供了Refresh Token撤销它
if req.RefreshToken != "" {
if err := h.authService.RevokeRefreshToken(c.Request.Context(), req.RefreshToken); err != nil {
slog.Debug("failed to revoke refresh token", "error", err)
// 不影响登出流程
}
}
response.Success(c, LogoutResponse{
Message: "Logged out successfully",
})
}
// RevokeAllSessionsResponse 撤销所有会话响应
type RevokeAllSessionsResponse struct {
Message string `json:"message"`
}
// RevokeAllSessions 撤销当前用户的所有会话
// POST /api/v1/auth/revoke-all-sessions
func (h *AuthHandler) RevokeAllSessions(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
if err := h.authService.RevokeAllUserSessions(c.Request.Context(), subject.UserID); err != nil {
slog.Error("failed to revoke all sessions", "user_id", subject.UserID, "error", err)
response.InternalError(c, "Failed to revoke sessions")
return
}
response.Success(c, RevokeAllSessionsResponse{
Message: "All sessions have been revoked. Please log in again.",
})
}

View File

@@ -211,7 +211,7 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
email = linuxDoSyntheticEmail(subject) email = linuxDoSyntheticEmail(subject)
} }
jwtToken, _, err := h.authService.LoginOrRegisterOAuth(c.Request.Context(), email, username) tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username)
if err != nil { if err != nil {
// 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。 // 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
@@ -219,7 +219,9 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
} }
fragment := url.Values{} fragment := url.Values{}
fragment.Set("access_token", jwtToken) fragment.Set("access_token", tokenPair.AccessToken)
fragment.Set("refresh_token", tokenPair.RefreshToken)
fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn))
fragment.Set("token_type", "Bearer") fragment.Set("token_type", "Bearer")
fragment.Set("redirect", redirectTo) fragment.Set("redirect", redirectTo)
redirectWithFragment(c, frontendCallback, fragment) redirectWithFragment(c, frontendCallback, fragment)

View File

@@ -58,8 +58,9 @@ func UserFromServiceAdmin(u *service.User) *AdminUser {
return nil return nil
} }
return &AdminUser{ return &AdminUser{
User: *base, User: *base,
Notes: u.Notes, Notes: u.Notes,
GroupRates: u.GroupRates,
} }
} }

View File

@@ -29,6 +29,9 @@ type AdminUser struct {
User User
Notes string `json:"notes"` Notes string `json:"notes"`
// GroupRates 用户专属分组倍率配置
// map[groupID]rateMultiplier
GroupRates map[int64]float64 `json:"group_rates,omitempty"`
} }
type APIKey struct { type APIKey struct {

View File

@@ -33,6 +33,7 @@ type GatewayHandler struct {
billingCacheService *service.BillingCacheService billingCacheService *service.BillingCacheService
usageService *service.UsageService usageService *service.UsageService
apiKeyService *service.APIKeyService apiKeyService *service.APIKeyService
errorPassthroughService *service.ErrorPassthroughService
concurrencyHelper *ConcurrencyHelper concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int maxAccountSwitches int
maxAccountSwitchesGemini int maxAccountSwitchesGemini int
@@ -49,6 +50,7 @@ func NewGatewayHandler(
billingCacheService *service.BillingCacheService, billingCacheService *service.BillingCacheService,
usageService *service.UsageService, usageService *service.UsageService,
apiKeyService *service.APIKeyService, apiKeyService *service.APIKeyService,
errorPassthroughService *service.ErrorPassthroughService,
cfg *config.Config, cfg *config.Config,
) *GatewayHandler { ) *GatewayHandler {
pingInterval := time.Duration(0) pingInterval := time.Duration(0)
@@ -71,6 +73,7 @@ func NewGatewayHandler(
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
usageService: usageService, usageService: usageService,
apiKeyService: apiKeyService, apiKeyService: apiKeyService,
errorPassthroughService: errorPassthroughService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
maxAccountSwitches: maxAccountSwitches, maxAccountSwitches: maxAccountSwitches,
maxAccountSwitchesGemini: maxAccountSwitchesGemini, maxAccountSwitchesGemini: maxAccountSwitchesGemini,
@@ -203,7 +206,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
maxAccountSwitches := h.maxAccountSwitchesGemini maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0 switchCount := 0
failedAccountIDs := make(map[int64]struct{}) failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0 var lastFailoverErr *service.UpstreamFailoverError
for { for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制 selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
@@ -212,7 +215,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return return
} }
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) if lastFailoverErr != nil {
h.handleFailoverExhausted(c, lastFailoverErr, service.PlatformGemini, streamStarted)
} else {
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
}
return return
} }
account := selection.Account account := selection.Account
@@ -303,9 +310,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
var failoverErr *service.UpstreamFailoverError var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) { if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{} failedAccountIDs[account.ID] = struct{}{}
lastFailoverStatus = failoverErr.StatusCode lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches { if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted)
return return
} }
switchCount++ switchCount++
@@ -354,7 +361,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
maxAccountSwitches := h.maxAccountSwitches maxAccountSwitches := h.maxAccountSwitches
switchCount := 0 switchCount := 0
failedAccountIDs := make(map[int64]struct{}) failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0 var lastFailoverErr *service.UpstreamFailoverError
retryWithFallback := false retryWithFallback := false
for { for {
@@ -365,7 +372,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return return
} }
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) if lastFailoverErr != nil {
h.handleFailoverExhausted(c, lastFailoverErr, platform, streamStarted)
} else {
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
}
return return
} }
account := selection.Account account := selection.Account
@@ -489,9 +500,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
var failoverErr *service.UpstreamFailoverError var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) { if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{} failedAccountIDs[account.ID] = struct{}{}
lastFailoverStatus = failoverErr.StatusCode lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches { if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted)
return return
} }
switchCount++ switchCount++
@@ -629,10 +640,10 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
return return
} }
// Best-effort: 获取用量统计,失败不影响基础响应 // Best-effort: 获取用量统计(按当前 API Key 过滤),失败不影响基础响应
var usageData gin.H var usageData gin.H
if h.usageService != nil { if h.usageService != nil {
dashStats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), subject.UserID) dashStats, err := h.usageService.GetAPIKeyDashboardStats(c.Request.Context(), apiKey.ID)
if err == nil && dashStats != nil { if err == nil && dashStats != nil {
usageData = gin.H{ usageData = gin.H{
"today": gin.H{ "today": gin.H{
@@ -768,7 +779,37 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
} }
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) { func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
statusCode := failoverErr.StatusCode
responseBody := failoverErr.ResponseBody
// 先检查透传规则
if h.errorPassthroughService != nil && len(responseBody) > 0 {
if rule := h.errorPassthroughService.MatchRule(platform, statusCode, responseBody); rule != nil {
// 确定响应状态码
respCode := statusCode
if !rule.PassthroughCode && rule.ResponseCode != nil {
respCode = *rule.ResponseCode
}
// 确定响应消息
msg := service.ExtractUpstreamErrorMessage(responseBody)
if !rule.PassthroughBody && rule.CustomMessage != nil {
msg = *rule.CustomMessage
}
h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted)
return
}
}
// 使用默认的错误映射
status, errType, errMsg := h.mapUpstreamError(statusCode)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
}
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
func (h *GatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
status, errType, errMsg := h.mapUpstreamError(statusCode) status, errType, errMsg := h.mapUpstreamError(statusCode)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
} }

View File

@@ -253,7 +253,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
maxAccountSwitches := h.maxAccountSwitchesGemini maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0 switchCount := 0
failedAccountIDs := make(map[int64]struct{}) failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0 var lastFailoverErr *service.UpstreamFailoverError
for { for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制 selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制
@@ -262,7 +262,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return return
} }
handleGeminiFailoverExhausted(c, lastFailoverStatus) h.handleGeminiFailoverExhausted(c, lastFailoverErr)
return return
} }
account := selection.Account account := selection.Account
@@ -353,11 +353,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if errors.As(err, &failoverErr) { if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{} failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches { if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode lastFailoverErr = failoverErr
handleGeminiFailoverExhausted(c, lastFailoverStatus) h.handleGeminiFailoverExhausted(c, lastFailoverErr)
return return
} }
lastFailoverStatus = failoverErr.StatusCode lastFailoverErr = failoverErr
switchCount++ switchCount++
log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
continue continue
@@ -414,7 +414,36 @@ func parseGeminiModelAction(rest string) (model string, action string, err error
return "", "", &pathParseError{"invalid model action path"} return "", "", &pathParseError{"invalid model action path"}
} }
func handleGeminiFailoverExhausted(c *gin.Context, statusCode int) { func (h *GatewayHandler) handleGeminiFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError) {
if failoverErr == nil {
googleError(c, http.StatusBadGateway, "Upstream request failed")
return
}
statusCode := failoverErr.StatusCode
responseBody := failoverErr.ResponseBody
// 先检查透传规则
if h.errorPassthroughService != nil && len(responseBody) > 0 {
if rule := h.errorPassthroughService.MatchRule(service.PlatformGemini, statusCode, responseBody); rule != nil {
// 确定响应状态码
respCode := statusCode
if !rule.PassthroughCode && rule.ResponseCode != nil {
respCode = *rule.ResponseCode
}
// 确定响应消息
msg := service.ExtractUpstreamErrorMessage(responseBody)
if !rule.PassthroughBody && rule.CustomMessage != nil {
msg = *rule.CustomMessage
}
googleError(c, respCode, msg)
return
}
}
// 使用默认的错误映射
status, message := mapGeminiUpstreamError(statusCode) status, message := mapGeminiUpstreamError(statusCode)
googleError(c, status, message) googleError(c, status, message)
} }

View File

@@ -24,6 +24,7 @@ type AdminHandlers struct {
Subscription *admin.SubscriptionHandler Subscription *admin.SubscriptionHandler
Usage *admin.UsageHandler Usage *admin.UsageHandler
UserAttribute *admin.UserAttributeHandler UserAttribute *admin.UserAttributeHandler
ErrorPassthrough *admin.ErrorPassthroughHandler
} }
// Handlers contains all HTTP handlers // Handlers contains all HTTP handlers

View File

@@ -22,11 +22,12 @@ import (
// OpenAIGatewayHandler handles OpenAI API gateway requests // OpenAIGatewayHandler handles OpenAI API gateway requests
type OpenAIGatewayHandler struct { type OpenAIGatewayHandler struct {
gatewayService *service.OpenAIGatewayService gatewayService *service.OpenAIGatewayService
billingCacheService *service.BillingCacheService billingCacheService *service.BillingCacheService
apiKeyService *service.APIKeyService apiKeyService *service.APIKeyService
concurrencyHelper *ConcurrencyHelper errorPassthroughService *service.ErrorPassthroughService
maxAccountSwitches int concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
} }
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler // NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
@@ -35,6 +36,7 @@ func NewOpenAIGatewayHandler(
concurrencyService *service.ConcurrencyService, concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService, billingCacheService *service.BillingCacheService,
apiKeyService *service.APIKeyService, apiKeyService *service.APIKeyService,
errorPassthroughService *service.ErrorPassthroughService,
cfg *config.Config, cfg *config.Config,
) *OpenAIGatewayHandler { ) *OpenAIGatewayHandler {
pingInterval := time.Duration(0) pingInterval := time.Duration(0)
@@ -46,11 +48,12 @@ func NewOpenAIGatewayHandler(
} }
} }
return &OpenAIGatewayHandler{ return &OpenAIGatewayHandler{
gatewayService: gatewayService, gatewayService: gatewayService,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
apiKeyService: apiKeyService, apiKeyService: apiKeyService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), errorPassthroughService: errorPassthroughService,
maxAccountSwitches: maxAccountSwitches, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches,
} }
} }
@@ -201,7 +204,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
maxAccountSwitches := h.maxAccountSwitches maxAccountSwitches := h.maxAccountSwitches
switchCount := 0 switchCount := 0
failedAccountIDs := make(map[int64]struct{}) failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0 var lastFailoverErr *service.UpstreamFailoverError
for { for {
// Select account supporting the requested model // Select account supporting the requested model
@@ -213,7 +216,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return return
} }
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) if lastFailoverErr != nil {
h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
} else {
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
}
return return
} }
account := selection.Account account := selection.Account
@@ -278,12 +285,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
var failoverErr *service.UpstreamFailoverError var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) { if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{} failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches { if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode h.handleFailoverExhausted(c, failoverErr, streamStarted)
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return return
} }
lastFailoverStatus = failoverErr.StatusCode
switchCount++ switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
continue continue
@@ -324,7 +330,37 @@ func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error,
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
} }
func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) { func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, streamStarted bool) {
statusCode := failoverErr.StatusCode
responseBody := failoverErr.ResponseBody
// 先检查透传规则
if h.errorPassthroughService != nil && len(responseBody) > 0 {
if rule := h.errorPassthroughService.MatchRule("openai", statusCode, responseBody); rule != nil {
// 确定响应状态码
respCode := statusCode
if !rule.PassthroughCode && rule.ResponseCode != nil {
respCode = *rule.ResponseCode
}
// 确定响应消息
msg := service.ExtractUpstreamErrorMessage(responseBody)
if !rule.PassthroughBody && rule.CustomMessage != nil {
msg = *rule.CustomMessage
}
h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted)
return
}
}
// 使用默认的错误映射
status, errType, errMsg := h.mapUpstreamError(statusCode)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
}
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
func (h *OpenAIGatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
status, errType, errMsg := h.mapUpstreamError(statusCode) status, errType, errMsg := h.mapUpstreamError(statusCode)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
} }

View File

@@ -27,6 +27,7 @@ func ProvideAdminHandlers(
subscriptionHandler *admin.SubscriptionHandler, subscriptionHandler *admin.SubscriptionHandler,
usageHandler *admin.UsageHandler, usageHandler *admin.UsageHandler,
userAttributeHandler *admin.UserAttributeHandler, userAttributeHandler *admin.UserAttributeHandler,
errorPassthroughHandler *admin.ErrorPassthroughHandler,
) *AdminHandlers { ) *AdminHandlers {
return &AdminHandlers{ return &AdminHandlers{
Dashboard: dashboardHandler, Dashboard: dashboardHandler,
@@ -47,6 +48,7 @@ func ProvideAdminHandlers(
Subscription: subscriptionHandler, Subscription: subscriptionHandler,
Usage: usageHandler, Usage: usageHandler,
UserAttribute: userAttributeHandler, UserAttribute: userAttributeHandler,
ErrorPassthrough: errorPassthroughHandler,
} }
} }
@@ -128,6 +130,7 @@ var ProviderSet = wire.NewSet(
admin.NewSubscriptionHandler, admin.NewSubscriptionHandler,
admin.NewUsageHandler, admin.NewUsageHandler,
admin.NewUserAttributeHandler, admin.NewUserAttributeHandler,
admin.NewErrorPassthroughHandler,
// AdminHandlers and Handlers constructors // AdminHandlers and Handlers constructors
ProvideAdminHandlers, ProvideAdminHandlers,

View File

@@ -0,0 +1,74 @@
// Package model 定义服务层使用的数据模型。
package model
import "time"
// ErrorPassthroughRule 全局错误透传规则
// 用于控制上游错误如何返回给客户端
type ErrorPassthroughRule struct {
ID int64 `json:"id"`
Name string `json:"name"` // 规则名称
Enabled bool `json:"enabled"` // 是否启用
Priority int `json:"priority"` // 优先级(数字越小优先级越高)
ErrorCodes []int `json:"error_codes"` // 匹配的错误码列表OR关系
Keywords []string `json:"keywords"` // 匹配的关键词列表OR关系
MatchMode string `json:"match_mode"` // "any"(任一条件) 或 "all"(所有条件)
Platforms []string `json:"platforms"` // 适用平台列表
PassthroughCode bool `json:"passthrough_code"` // 是否透传原始状态码
ResponseCode *int `json:"response_code"` // 自定义状态码passthrough_code=false 时使用)
PassthroughBody bool `json:"passthrough_body"` // 是否透传原始错误信息
CustomMessage *string `json:"custom_message"` // 自定义错误信息passthrough_body=false 时使用)
Description *string `json:"description"` // 规则描述
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// MatchModeAny 表示任一条件匹配即可
const MatchModeAny = "any"
// MatchModeAll 表示所有条件都必须匹配
const MatchModeAll = "all"
// 支持的平台常量
const (
PlatformAnthropic = "anthropic"
PlatformOpenAI = "openai"
PlatformGemini = "gemini"
PlatformAntigravity = "antigravity"
)
// AllPlatforms 返回所有支持的平台列表
func AllPlatforms() []string {
return []string{PlatformAnthropic, PlatformOpenAI, PlatformGemini, PlatformAntigravity}
}
// Validate 验证规则配置的有效性
func (r *ErrorPassthroughRule) Validate() error {
if r.Name == "" {
return &ValidationError{Field: "name", Message: "name is required"}
}
if r.MatchMode != MatchModeAny && r.MatchMode != MatchModeAll {
return &ValidationError{Field: "match_mode", Message: "match_mode must be 'any' or 'all'"}
}
// 至少需要配置一个匹配条件(错误码或关键词)
if len(r.ErrorCodes) == 0 && len(r.Keywords) == 0 {
return &ValidationError{Field: "conditions", Message: "at least one error_code or keyword is required"}
}
if !r.PassthroughCode && (r.ResponseCode == nil || *r.ResponseCode <= 0) {
return &ValidationError{Field: "response_code", Message: "response_code is required when passthrough_code is false"}
}
if !r.PassthroughBody && (r.CustomMessage == nil || *r.CustomMessage == "") {
return &ValidationError{Field: "custom_message", Message: "custom_message is required when passthrough_body is false"}
}
return nil
}
// ValidationError 表示验证错误
type ValidationError struct {
Field string
Message string
}
func (e *ValidationError) Error() string {
return e.Field + ": " + e.Message
}

View File

@@ -0,0 +1,109 @@
// Package googleapi provides helpers for Google-style API responses.
package googleapi
import (
"encoding/json"
"fmt"
"strings"
)
// ErrorResponse represents a Google API error response
type ErrorResponse struct {
Error ErrorDetail `json:"error"`
}
// ErrorDetail contains the error details from Google API
type ErrorDetail struct {
Code int `json:"code"`
Message string `json:"message"`
Status string `json:"status"`
Details []json.RawMessage `json:"details,omitempty"`
}
// ErrorDetailInfo contains additional error information
type ErrorDetailInfo struct {
Type string `json:"@type"`
Reason string `json:"reason,omitempty"`
Domain string `json:"domain,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
}
// ErrorHelp contains help links
type ErrorHelp struct {
Type string `json:"@type"`
Links []HelpLink `json:"links,omitempty"`
}
// HelpLink represents a help link
type HelpLink struct {
Description string `json:"description"`
URL string `json:"url"`
}
// ParseError parses a Google API error response and extracts key information
func ParseError(body string) (*ErrorResponse, error) {
var errResp ErrorResponse
if err := json.Unmarshal([]byte(body), &errResp); err != nil {
return nil, fmt.Errorf("failed to parse error response: %w", err)
}
return &errResp, nil
}
// ExtractActivationURL extracts the API activation URL from error details
func ExtractActivationURL(body string) string {
var errResp ErrorResponse
if err := json.Unmarshal([]byte(body), &errResp); err != nil {
return ""
}
// Check error details for activation URL
for _, detailRaw := range errResp.Error.Details {
// Parse as ErrorDetailInfo
var info ErrorDetailInfo
if err := json.Unmarshal(detailRaw, &info); err == nil {
if info.Metadata != nil {
if activationURL, ok := info.Metadata["activationUrl"]; ok && activationURL != "" {
return activationURL
}
}
}
// Parse as ErrorHelp
var help ErrorHelp
if err := json.Unmarshal(detailRaw, &help); err == nil {
for _, link := range help.Links {
if strings.Contains(link.Description, "activation") ||
strings.Contains(link.Description, "API activation") ||
strings.Contains(link.URL, "/apis/api/") {
return link.URL
}
}
}
}
return ""
}
// IsServiceDisabledError checks if the error is a SERVICE_DISABLED error
func IsServiceDisabledError(body string) bool {
var errResp ErrorResponse
if err := json.Unmarshal([]byte(body), &errResp); err != nil {
return false
}
// Check if it's a 403 PERMISSION_DENIED with SERVICE_DISABLED reason
if errResp.Error.Code != 403 || errResp.Error.Status != "PERMISSION_DENIED" {
return false
}
for _, detailRaw := range errResp.Error.Details {
var info ErrorDetailInfo
if err := json.Unmarshal(detailRaw, &info); err == nil {
if info.Reason == "SERVICE_DISABLED" {
return true
}
}
}
return false
}

View File

@@ -0,0 +1,143 @@
package googleapi
import (
"testing"
)
func TestExtractActivationURL(t *testing.T) {
// Test case from the user's error message
errorBody := `{
"error": {
"code": 403,
"message": "Gemini for Google Cloud API has not been used in project project-6eca5881-ab73-4736-843 before or it is disabled. Enable it by visiting https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843 then retry. If you enabled this API recently, wait a few minutes for the action to propagate to our systems and retry.",
"status": "PERMISSION_DENIED",
"details": [
{
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
"reason": "SERVICE_DISABLED",
"domain": "googleapis.com",
"metadata": {
"service": "cloudaicompanion.googleapis.com",
"activationUrl": "https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843",
"consumer": "projects/project-6eca5881-ab73-4736-843",
"serviceTitle": "Gemini for Google Cloud API",
"containerInfo": "project-6eca5881-ab73-4736-843"
}
},
{
"@type": "type.googleapis.com/google.rpc.LocalizedMessage",
"locale": "en-US",
"message": "Gemini for Google Cloud API has not been used in project project-6eca5881-ab73-4736-843 before or it is disabled. Enable it by visiting https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843 then retry. If you enabled this API recently, wait a few minutes for the action to propagate to our systems and retry."
},
{
"@type": "type.googleapis.com/google.rpc.Help",
"links": [
{
"description": "Google developers console API activation",
"url": "https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843"
}
]
}
]
}
}`
activationURL := ExtractActivationURL(errorBody)
expectedURL := "https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843"
if activationURL != expectedURL {
t.Errorf("Expected activation URL %s, got %s", expectedURL, activationURL)
}
}
func TestIsServiceDisabledError(t *testing.T) {
tests := []struct {
name string
body string
expected bool
}{
{
name: "SERVICE_DISABLED error",
body: `{
"error": {
"code": 403,
"status": "PERMISSION_DENIED",
"details": [
{
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
"reason": "SERVICE_DISABLED"
}
]
}
}`,
expected: true,
},
{
name: "Other 403 error",
body: `{
"error": {
"code": 403,
"status": "PERMISSION_DENIED",
"details": [
{
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
"reason": "OTHER_REASON"
}
]
}
}`,
expected: false,
},
{
name: "404 error",
body: `{
"error": {
"code": 404,
"status": "NOT_FOUND"
}
}`,
expected: false,
},
{
name: "Invalid JSON",
body: `invalid json`,
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsServiceDisabledError(tt.body)
if result != tt.expected {
t.Errorf("Expected %v, got %v", tt.expected, result)
}
})
}
}
func TestParseError(t *testing.T) {
errorBody := `{
"error": {
"code": 403,
"message": "API not enabled",
"status": "PERMISSION_DENIED"
}
}`
errResp, err := ParseError(errorBody)
if err != nil {
t.Fatalf("Failed to parse error: %v", err)
}
if errResp.Error.Code != 403 {
t.Errorf("Expected code 403, got %d", errResp.Error.Code)
}
if errResp.Error.Status != "PERMISSION_DENIED" {
t.Errorf("Expected status PERMISSION_DENIED, got %s", errResp.Error.Status)
}
if errResp.Error.Message != "API not enabled" {
t.Errorf("Expected message 'API not enabled', got %s", errResp.Error.Message)
}
}

View File

@@ -0,0 +1,128 @@
package repository
import (
"context"
"encoding/json"
"log"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const (
errorPassthroughCacheKey = "error_passthrough_rules"
errorPassthroughPubSubKey = "error_passthrough_rules_updated"
errorPassthroughCacheTTL = 24 * time.Hour
)
type errorPassthroughCache struct {
rdb *redis.Client
localCache []*model.ErrorPassthroughRule
localMu sync.RWMutex
}
// NewErrorPassthroughCache 创建错误透传规则缓存
func NewErrorPassthroughCache(rdb *redis.Client) service.ErrorPassthroughCache {
return &errorPassthroughCache{
rdb: rdb,
}
}
// Get 从缓存获取规则列表
func (c *errorPassthroughCache) Get(ctx context.Context) ([]*model.ErrorPassthroughRule, bool) {
// 先检查本地缓存
c.localMu.RLock()
if c.localCache != nil {
rules := c.localCache
c.localMu.RUnlock()
return rules, true
}
c.localMu.RUnlock()
// 从 Redis 获取
data, err := c.rdb.Get(ctx, errorPassthroughCacheKey).Bytes()
if err != nil {
if err != redis.Nil {
log.Printf("[ErrorPassthroughCache] Failed to get from Redis: %v", err)
}
return nil, false
}
var rules []*model.ErrorPassthroughRule
if err := json.Unmarshal(data, &rules); err != nil {
log.Printf("[ErrorPassthroughCache] Failed to unmarshal rules: %v", err)
return nil, false
}
// 更新本地缓存
c.localMu.Lock()
c.localCache = rules
c.localMu.Unlock()
return rules, true
}
// Set 设置缓存
func (c *errorPassthroughCache) Set(ctx context.Context, rules []*model.ErrorPassthroughRule) error {
data, err := json.Marshal(rules)
if err != nil {
return err
}
if err := c.rdb.Set(ctx, errorPassthroughCacheKey, data, errorPassthroughCacheTTL).Err(); err != nil {
return err
}
// 更新本地缓存
c.localMu.Lock()
c.localCache = rules
c.localMu.Unlock()
return nil
}
// Invalidate 使缓存失效
func (c *errorPassthroughCache) Invalidate(ctx context.Context) error {
// 清除本地缓存
c.localMu.Lock()
c.localCache = nil
c.localMu.Unlock()
// 清除 Redis 缓存
return c.rdb.Del(ctx, errorPassthroughCacheKey).Err()
}
// NotifyUpdate 通知其他实例刷新缓存
func (c *errorPassthroughCache) NotifyUpdate(ctx context.Context) error {
return c.rdb.Publish(ctx, errorPassthroughPubSubKey, "refresh").Err()
}
// SubscribeUpdates 订阅缓存更新通知
func (c *errorPassthroughCache) SubscribeUpdates(ctx context.Context, handler func()) {
go func() {
sub := c.rdb.Subscribe(ctx, errorPassthroughPubSubKey)
defer func() { _ = sub.Close() }()
ch := sub.Channel()
for {
select {
case <-ctx.Done():
return
case msg := <-ch:
if msg == nil {
return
}
// 清除本地缓存,下次访问时会从 Redis 或数据库重新加载
c.localMu.Lock()
c.localCache = nil
c.localMu.Unlock()
// 调用处理函数
handler()
}
}
}()
}

View File

@@ -0,0 +1,178 @@
package repository
import (
"context"
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type errorPassthroughRepository struct {
client *ent.Client
}
// NewErrorPassthroughRepository 创建错误透传规则仓库
func NewErrorPassthroughRepository(client *ent.Client) service.ErrorPassthroughRepository {
return &errorPassthroughRepository{client: client}
}
// List 获取所有规则
func (r *errorPassthroughRepository) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) {
rules, err := r.client.ErrorPassthroughRule.Query().
Order(ent.Asc(errorpassthroughrule.FieldPriority)).
All(ctx)
if err != nil {
return nil, err
}
result := make([]*model.ErrorPassthroughRule, len(rules))
for i, rule := range rules {
result[i] = r.toModel(rule)
}
return result, nil
}
// GetByID 根据 ID 获取规则
func (r *errorPassthroughRepository) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) {
rule, err := r.client.ErrorPassthroughRule.Get(ctx, id)
if err != nil {
if ent.IsNotFound(err) {
return nil, nil
}
return nil, err
}
return r.toModel(rule), nil
}
// Create 创建规则
func (r *errorPassthroughRepository) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
builder := r.client.ErrorPassthroughRule.Create().
SetName(rule.Name).
SetEnabled(rule.Enabled).
SetPriority(rule.Priority).
SetMatchMode(rule.MatchMode).
SetPassthroughCode(rule.PassthroughCode).
SetPassthroughBody(rule.PassthroughBody)
if len(rule.ErrorCodes) > 0 {
builder.SetErrorCodes(rule.ErrorCodes)
}
if len(rule.Keywords) > 0 {
builder.SetKeywords(rule.Keywords)
}
if len(rule.Platforms) > 0 {
builder.SetPlatforms(rule.Platforms)
}
if rule.ResponseCode != nil {
builder.SetResponseCode(*rule.ResponseCode)
}
if rule.CustomMessage != nil {
builder.SetCustomMessage(*rule.CustomMessage)
}
if rule.Description != nil {
builder.SetDescription(*rule.Description)
}
created, err := builder.Save(ctx)
if err != nil {
return nil, err
}
return r.toModel(created), nil
}
// Update 更新规则
func (r *errorPassthroughRepository) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
builder := r.client.ErrorPassthroughRule.UpdateOneID(rule.ID).
SetName(rule.Name).
SetEnabled(rule.Enabled).
SetPriority(rule.Priority).
SetMatchMode(rule.MatchMode).
SetPassthroughCode(rule.PassthroughCode).
SetPassthroughBody(rule.PassthroughBody)
// 处理可选字段
if len(rule.ErrorCodes) > 0 {
builder.SetErrorCodes(rule.ErrorCodes)
} else {
builder.ClearErrorCodes()
}
if len(rule.Keywords) > 0 {
builder.SetKeywords(rule.Keywords)
} else {
builder.ClearKeywords()
}
if len(rule.Platforms) > 0 {
builder.SetPlatforms(rule.Platforms)
} else {
builder.ClearPlatforms()
}
if rule.ResponseCode != nil {
builder.SetResponseCode(*rule.ResponseCode)
} else {
builder.ClearResponseCode()
}
if rule.CustomMessage != nil {
builder.SetCustomMessage(*rule.CustomMessage)
} else {
builder.ClearCustomMessage()
}
if rule.Description != nil {
builder.SetDescription(*rule.Description)
} else {
builder.ClearDescription()
}
updated, err := builder.Save(ctx)
if err != nil {
return nil, err
}
return r.toModel(updated), nil
}
// Delete 删除规则
func (r *errorPassthroughRepository) Delete(ctx context.Context, id int64) error {
return r.client.ErrorPassthroughRule.DeleteOneID(id).Exec(ctx)
}
// toModel 将 Ent 实体转换为服务模型
func (r *errorPassthroughRepository) toModel(e *ent.ErrorPassthroughRule) *model.ErrorPassthroughRule {
rule := &model.ErrorPassthroughRule{
ID: int64(e.ID),
Name: e.Name,
Enabled: e.Enabled,
Priority: e.Priority,
ErrorCodes: e.ErrorCodes,
Keywords: e.Keywords,
MatchMode: e.MatchMode,
Platforms: e.Platforms,
PassthroughCode: e.PassthroughCode,
PassthroughBody: e.PassthroughBody,
CreatedAt: e.CreatedAt,
UpdatedAt: e.UpdatedAt,
}
if e.ResponseCode != nil {
rule.ResponseCode = e.ResponseCode
}
if e.CustomMessage != nil {
rule.CustomMessage = e.CustomMessage
}
if e.Description != nil {
rule.Description = e.Description
}
// 确保切片不为 nil
if rule.ErrorCodes == nil {
rule.ErrorCodes = []int{}
}
if rule.Keywords == nil {
rule.Keywords = []string{}
}
if rule.Platforms == nil {
rule.Platforms = []string{}
}
return rule
}

View File

@@ -6,6 +6,7 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/imroc/req/v3" "github.com/imroc/req/v3"
@@ -38,9 +39,20 @@ func (c *geminiCliCodeAssistClient) LoadCodeAssist(ctx context.Context, accessTo
return nil, fmt.Errorf("request failed: %w", err) return nil, fmt.Errorf("request failed: %w", err)
} }
if !resp.IsSuccessState() { if !resp.IsSuccessState() {
body := geminicli.SanitizeBodyForLogs(resp.String()) body := resp.String()
fmt.Printf("[CodeAssist] LoadCodeAssist failed: status %d, body: %s\n", resp.StatusCode, body) sanitizedBody := geminicli.SanitizeBodyForLogs(body)
return nil, fmt.Errorf("loadCodeAssist failed: status %d, body: %s", resp.StatusCode, body) fmt.Printf("[CodeAssist] LoadCodeAssist failed: status %d, body: %s\n", resp.StatusCode, sanitizedBody)
// Check if this is a SERVICE_DISABLED error and extract activation URL
if googleapi.IsServiceDisabledError(body) {
activationURL := googleapi.ExtractActivationURL(body)
if activationURL != "" {
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it by visiting: %s\n\nAfter enabling the API, wait a few minutes for the changes to propagate, then try again", activationURL)
}
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it in the Google Cloud Console at: https://console.cloud.google.com/apis/library/cloudaicompanion.googleapis.com")
}
return nil, fmt.Errorf("loadCodeAssist failed: status %d, body: %s", resp.StatusCode, sanitizedBody)
} }
fmt.Printf("[CodeAssist] LoadCodeAssist success: status %d, response: %+v\n", resp.StatusCode, out) fmt.Printf("[CodeAssist] LoadCodeAssist success: status %d, response: %+v\n", resp.StatusCode, out)
return &out, nil return &out, nil
@@ -67,9 +79,20 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken
return nil, fmt.Errorf("request failed: %w", err) return nil, fmt.Errorf("request failed: %w", err)
} }
if !resp.IsSuccessState() { if !resp.IsSuccessState() {
body := geminicli.SanitizeBodyForLogs(resp.String()) body := resp.String()
fmt.Printf("[CodeAssist] OnboardUser failed: status %d, body: %s\n", resp.StatusCode, body) sanitizedBody := geminicli.SanitizeBodyForLogs(body)
return nil, fmt.Errorf("onboardUser failed: status %d, body: %s", resp.StatusCode, body) fmt.Printf("[CodeAssist] OnboardUser failed: status %d, body: %s\n", resp.StatusCode, sanitizedBody)
// Check if this is a SERVICE_DISABLED error and extract activation URL
if googleapi.IsServiceDisabledError(body) {
activationURL := googleapi.ExtractActivationURL(body)
if activationURL != "" {
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it by visiting: %s\n\nAfter enabling the API, wait a few minutes for the changes to propagate, then try again", activationURL)
}
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it in the Google Cloud Console at: https://console.cloud.google.com/apis/library/cloudaicompanion.googleapis.com")
}
return nil, fmt.Errorf("onboardUser failed: status %d, body: %s", resp.StatusCode, sanitizedBody)
} }
fmt.Printf("[CodeAssist] OnboardUser success: status %d, response: %+v\n", resp.StatusCode, out) fmt.Printf("[CodeAssist] OnboardUser success: status %d, response: %+v\n", resp.StatusCode, out)
return &out, nil return &out, nil

View File

@@ -0,0 +1,158 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const (
refreshTokenKeyPrefix = "refresh_token:"
userRefreshTokensPrefix = "user_refresh_tokens:"
tokenFamilyPrefix = "token_family:"
)
// refreshTokenKey generates the Redis key for a refresh token.
func refreshTokenKey(tokenHash string) string {
return refreshTokenKeyPrefix + tokenHash
}
// userRefreshTokensKey generates the Redis key for user's token set.
func userRefreshTokensKey(userID int64) string {
return fmt.Sprintf("%s%d", userRefreshTokensPrefix, userID)
}
// tokenFamilyKey generates the Redis key for token family set.
func tokenFamilyKey(familyID string) string {
return tokenFamilyPrefix + familyID
}
type refreshTokenCache struct {
rdb *redis.Client
}
// NewRefreshTokenCache creates a new RefreshTokenCache implementation.
func NewRefreshTokenCache(rdb *redis.Client) service.RefreshTokenCache {
return &refreshTokenCache{rdb: rdb}
}
func (c *refreshTokenCache) StoreRefreshToken(ctx context.Context, tokenHash string, data *service.RefreshTokenData, ttl time.Duration) error {
key := refreshTokenKey(tokenHash)
val, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("marshal refresh token data: %w", err)
}
return c.rdb.Set(ctx, key, val, ttl).Err()
}
func (c *refreshTokenCache) GetRefreshToken(ctx context.Context, tokenHash string) (*service.RefreshTokenData, error) {
key := refreshTokenKey(tokenHash)
val, err := c.rdb.Get(ctx, key).Result()
if err != nil {
if err == redis.Nil {
return nil, service.ErrRefreshTokenNotFound
}
return nil, err
}
var data service.RefreshTokenData
if err := json.Unmarshal([]byte(val), &data); err != nil {
return nil, fmt.Errorf("unmarshal refresh token data: %w", err)
}
return &data, nil
}
func (c *refreshTokenCache) DeleteRefreshToken(ctx context.Context, tokenHash string) error {
key := refreshTokenKey(tokenHash)
return c.rdb.Del(ctx, key).Err()
}
func (c *refreshTokenCache) DeleteUserRefreshTokens(ctx context.Context, userID int64) error {
// Get all token hashes for this user
tokenHashes, err := c.GetUserTokenHashes(ctx, userID)
if err != nil && err != redis.Nil {
return fmt.Errorf("get user token hashes: %w", err)
}
if len(tokenHashes) == 0 {
return nil
}
// Build keys to delete
keys := make([]string, 0, len(tokenHashes)+1)
for _, hash := range tokenHashes {
keys = append(keys, refreshTokenKey(hash))
}
keys = append(keys, userRefreshTokensKey(userID))
// Delete all keys in a pipeline
pipe := c.rdb.Pipeline()
for _, key := range keys {
pipe.Del(ctx, key)
}
_, err = pipe.Exec(ctx)
return err
}
func (c *refreshTokenCache) DeleteTokenFamily(ctx context.Context, familyID string) error {
// Get all token hashes in this family
tokenHashes, err := c.GetFamilyTokenHashes(ctx, familyID)
if err != nil && err != redis.Nil {
return fmt.Errorf("get family token hashes: %w", err)
}
if len(tokenHashes) == 0 {
return nil
}
// Build keys to delete
keys := make([]string, 0, len(tokenHashes)+1)
for _, hash := range tokenHashes {
keys = append(keys, refreshTokenKey(hash))
}
keys = append(keys, tokenFamilyKey(familyID))
// Delete all keys in a pipeline
pipe := c.rdb.Pipeline()
for _, key := range keys {
pipe.Del(ctx, key)
}
_, err = pipe.Exec(ctx)
return err
}
func (c *refreshTokenCache) AddToUserTokenSet(ctx context.Context, userID int64, tokenHash string, ttl time.Duration) error {
key := userRefreshTokensKey(userID)
pipe := c.rdb.Pipeline()
pipe.SAdd(ctx, key, tokenHash)
pipe.Expire(ctx, key, ttl)
_, err := pipe.Exec(ctx)
return err
}
func (c *refreshTokenCache) AddToFamilyTokenSet(ctx context.Context, familyID string, tokenHash string, ttl time.Duration) error {
key := tokenFamilyKey(familyID)
pipe := c.rdb.Pipeline()
pipe.SAdd(ctx, key, tokenHash)
pipe.Expire(ctx, key, ttl)
_, err := pipe.Exec(ctx)
return err
}
func (c *refreshTokenCache) GetUserTokenHashes(ctx context.Context, userID int64) ([]string, error) {
key := userRefreshTokensKey(userID)
return c.rdb.SMembers(ctx, key).Result()
}
func (c *refreshTokenCache) GetFamilyTokenHashes(ctx context.Context, familyID string) ([]string, error) {
key := tokenFamilyKey(familyID)
return c.rdb.SMembers(ctx, key).Result()
}
func (c *refreshTokenCache) IsTokenInFamily(ctx context.Context, familyID string, tokenHash string) (bool, error) {
key := tokenFamilyKey(familyID)
return c.rdb.SIsMember(ctx, key, tokenHash).Result()
}

View File

@@ -3,6 +3,7 @@ package repository
import ( import (
"context" "context"
"fmt" "fmt"
"log"
"strconv" "strconv"
"time" "time"
@@ -153,6 +154,21 @@ func NewSessionLimitCache(rdb *redis.Client, defaultIdleTimeoutMinutes int) serv
if defaultIdleTimeoutMinutes <= 0 { if defaultIdleTimeoutMinutes <= 0 {
defaultIdleTimeoutMinutes = 5 // 默认 5 分钟 defaultIdleTimeoutMinutes = 5 // 默认 5 分钟
} }
// 预加载 Lua 脚本到 Redis避免 Pipeline 中出现 NOSCRIPT 错误
ctx := context.Background()
scripts := []*redis.Script{
registerSessionScript,
refreshSessionScript,
getActiveSessionCountScript,
isSessionActiveScript,
}
for _, script := range scripts {
if err := script.Load(ctx, rdb).Err(); err != nil {
log.Printf("[SessionLimitCache] Failed to preload Lua script: %v", err)
}
}
return &sessionLimitCache{ return &sessionLimitCache{
rdb: rdb, rdb: rdb,
defaultIdleTimeout: time.Duration(defaultIdleTimeoutMinutes) * time.Minute, defaultIdleTimeout: time.Duration(defaultIdleTimeoutMinutes) * time.Minute,

View File

@@ -1128,6 +1128,107 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
return stats, nil return stats, nil
} }
// getPerformanceStatsByAPIKey 获取指定 API Key 的 RPM 和 TPM近5分钟平均值
func (r *usageLogRepository) getPerformanceStatsByAPIKey(ctx context.Context, apiKeyID int64) (rpm, tpm int64, err error) {
fiveMinutesAgo := time.Now().Add(-5 * time.Minute)
query := `
SELECT
COUNT(*) as request_count,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as token_count
FROM usage_logs
WHERE created_at >= $1 AND api_key_id = $2`
args := []any{fiveMinutesAgo, apiKeyID}
var requestCount int64
var tokenCount int64
if err := scanSingleRow(ctx, r.sql, query, args, &requestCount, &tokenCount); err != nil {
return 0, 0, err
}
return requestCount / 5, tokenCount / 5, nil
}
// GetAPIKeyDashboardStats 获取指定 API Key 的仪表盘统计(按 api_key_id 过滤)
func (r *usageLogRepository) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*UserDashboardStats, error) {
stats := &UserDashboardStats{}
today := timezone.Today()
// API Key 维度不需要统计 key 数量,设为 1
stats.TotalAPIKeys = 1
stats.ActiveAPIKeys = 1
// 累计 Token 统计
totalStatsQuery := `
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs
WHERE api_key_id = $1
`
if err := scanSingleRow(
ctx,
r.sql,
totalStatsQuery,
[]any{apiKeyID},
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheCreationTokens,
&stats.TotalCacheReadTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
// 今日 Token 统计
todayStatsQuery := `
SELECT
COUNT(*) as today_requests,
COALESCE(SUM(input_tokens), 0) as today_input_tokens,
COALESCE(SUM(output_tokens), 0) as today_output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as today_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as today_cache_read_tokens,
COALESCE(SUM(total_cost), 0) as today_cost,
COALESCE(SUM(actual_cost), 0) as today_actual_cost
FROM usage_logs
WHERE api_key_id = $1 AND created_at >= $2
`
if err := scanSingleRow(
ctx,
r.sql,
todayStatsQuery,
[]any{apiKeyID, today},
&stats.TodayRequests,
&stats.TodayInputTokens,
&stats.TodayOutputTokens,
&stats.TodayCacheCreationTokens,
&stats.TodayCacheReadTokens,
&stats.TodayCost,
&stats.TodayActualCost,
); err != nil {
return nil, err
}
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
// 性能指标RPM 和 TPM最近5分钟按 API Key 过滤)
rpm, tpm, err := r.getPerformanceStatsByAPIKey(ctx, apiKeyID)
if err != nil {
return nil, err
}
stats.Rpm = rpm
stats.Tpm = tpm
return stats, nil
}
// GetUserUsageTrendByUserID 获取指定用户的使用趋势 // GetUserUsageTrendByUserID 获取指定用户的使用趋势
func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) { func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) {
dateFormat := "YYYY-MM-DD" dateFormat := "YYYY-MM-DD"

View File

@@ -0,0 +1,113 @@
package repository
import (
"context"
"database/sql"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type userGroupRateRepository struct {
sql sqlExecutor
}
// NewUserGroupRateRepository 创建用户专属分组倍率仓储
func NewUserGroupRateRepository(sqlDB *sql.DB) service.UserGroupRateRepository {
return &userGroupRateRepository{sql: sqlDB}
}
// GetByUserID 获取用户的所有专属分组倍率
func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) {
query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1`
rows, err := r.sql.QueryContext(ctx, query, userID)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
result := make(map[int64]float64)
for rows.Next() {
var groupID int64
var rate float64
if err := rows.Scan(&groupID, &rate); err != nil {
return nil, err
}
result[groupID] = rate
}
if err := rows.Err(); err != nil {
return nil, err
}
return result, nil
}
// GetByUserAndGroup 获取用户在特定分组的专属倍率
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
var rate float64
err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rate)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &rate, nil
}
// SyncUserGroupRates 同步用户的分组专属倍率
func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error {
if len(rates) == 0 {
// 如果传入空 map删除该用户的所有专属倍率
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
return err
}
// 分离需要删除和需要 upsert 的记录
var toDelete []int64
toUpsert := make(map[int64]float64)
for groupID, rate := range rates {
if rate == nil {
toDelete = append(toDelete, groupID)
} else {
toUpsert[groupID] = *rate
}
}
// 删除指定的记录
for _, groupID := range toDelete {
_, err := r.sql.ExecContext(ctx,
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`,
userID, groupID)
if err != nil {
return err
}
}
// Upsert 记录
now := time.Now()
for groupID, rate := range toUpsert {
_, err := r.sql.ExecContext(ctx, `
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
VALUES ($1, $2, $3, $4, $4)
ON CONFLICT (user_id, group_id) DO UPDATE SET rate_multiplier = $3, updated_at = $4
`, userID, groupID, rate, now)
if err != nil {
return err
}
}
return nil
}
// DeleteByGroupID 删除指定分组的所有用户专属倍率
func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error {
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID)
return err
}
// DeleteByUserID 删除指定用户的所有专属倍率
func (r *userGroupRateRepository) DeleteByUserID(ctx context.Context, userID int64) error {
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
return err
}

View File

@@ -67,6 +67,8 @@ var ProviderSet = wire.NewSet(
NewUserSubscriptionRepository, NewUserSubscriptionRepository,
NewUserAttributeDefinitionRepository, NewUserAttributeDefinitionRepository,
NewUserAttributeValueRepository, NewUserAttributeValueRepository,
NewUserGroupRateRepository,
NewErrorPassthroughRepository,
// Cache implementations // Cache implementations
NewGatewayCache, NewGatewayCache,
@@ -86,6 +88,8 @@ var ProviderSet = wire.NewSet(
NewSchedulerOutboxRepository, NewSchedulerOutboxRepository,
NewProxyLatencyCache, NewProxyLatencyCache,
NewTotpCache, NewTotpCache,
NewRefreshTokenCache,
NewErrorPassthroughCache,
// Encryptors // Encryptors
NewAESEncryptor, NewAESEncryptor,

View File

@@ -598,7 +598,7 @@ func newContractDeps(t *testing.T) *contractDeps {
} }
userService := service.NewUserService(userRepo, nil) userService := service.NewUserService(userRepo, nil)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg) apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg)
usageRepo := newStubUsageLogRepo() usageRepo := newStubUsageLogRepo()
usageService := service.NewUsageService(usageRepo, userRepo, nil, nil) usageService := service.NewUsageService(usageRepo, userRepo, nil, nil)
@@ -612,7 +612,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo() settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg) settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil) adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
@@ -1619,6 +1619,10 @@ func (r *stubUsageLogRepo) GetUserDashboardStats(ctx context.Context, userID int
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubUsageLogRepo) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*usagestats.UserDashboardStats, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) { func (r *stubUsageLogRepo) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }

View File

@@ -58,10 +58,39 @@ func ProvideRouter(
// ProvideHTTPServer 提供 HTTP 服务器 // ProvideHTTPServer 提供 HTTP 服务器
func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server { func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server {
handler := h2c.NewHandler(router, &http2.Server{}) httpHandler := http.Handler(router)
globalMaxSize := cfg.Server.MaxRequestBodySize
if globalMaxSize <= 0 {
globalMaxSize = cfg.Gateway.MaxBodySize
}
if globalMaxSize > 0 {
httpHandler = http.MaxBytesHandler(httpHandler, globalMaxSize)
log.Printf("Global max request body size: %d bytes (%.2f MB)", globalMaxSize, float64(globalMaxSize)/(1<<20))
}
// 根据配置决定是否启用 H2C
if cfg.Server.H2C.Enabled {
h2cConfig := cfg.Server.H2C
httpHandler = h2c.NewHandler(router, &http2.Server{
MaxConcurrentStreams: h2cConfig.MaxConcurrentStreams,
IdleTimeout: time.Duration(h2cConfig.IdleTimeout) * time.Second,
MaxReadFrameSize: uint32(h2cConfig.MaxReadFrameSize),
MaxUploadBufferPerConnection: int32(h2cConfig.MaxUploadBufferPerConnection),
MaxUploadBufferPerStream: int32(h2cConfig.MaxUploadBufferPerStream),
})
log.Printf("HTTP/2 Cleartext (h2c) enabled: max_concurrent_streams=%d, idle_timeout=%ds, max_read_frame_size=%d, max_upload_buffer_per_connection=%d, max_upload_buffer_per_stream=%d",
h2cConfig.MaxConcurrentStreams,
h2cConfig.IdleTimeout,
h2cConfig.MaxReadFrameSize,
h2cConfig.MaxUploadBufferPerConnection,
h2cConfig.MaxUploadBufferPerStream,
)
}
return &http.Server{ return &http.Server{
Addr: cfg.Server.Address(), Addr: cfg.Server.Address(),
Handler: handler, Handler: httpHandler,
// ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击 // ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击
ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second, ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second,
// IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源 // IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源

View File

@@ -93,6 +93,7 @@ func newTestAPIKeyService(repo service.APIKeyRepository) *service.APIKeyService
nil, // userRepo (unused in GetByKey) nil, // userRepo (unused in GetByKey)
nil, // groupRepo nil, // groupRepo
nil, // userSubRepo nil, // userSubRepo
nil, // userGroupRateRepo
nil, // cache nil, // cache
&config.Config{}, &config.Config{},
) )
@@ -187,6 +188,7 @@ func TestApiKeyAuthWithSubscriptionGoogleSetsGroupContext(t *testing.T) {
nil, nil,
nil, nil,
nil, nil,
nil,
&config.Config{RunMode: config.RunModeSimple}, &config.Config{RunMode: config.RunModeSimple},
) )

View File

@@ -59,7 +59,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) { t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeSimple} cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil) subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
@@ -73,7 +73,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("standard_mode_enforces_quota_check", func(t *testing.T) { t.Run("standard_mode_enforces_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeStandard} cfg := &config.Config{RunMode: config.RunModeStandard}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
now := time.Now() now := time.Now()
sub := &service.UserSubscription{ sub := &service.UserSubscription{
@@ -150,7 +150,7 @@ func TestAPIKeyAuthSetsGroupContext(t *testing.T) {
} }
cfg := &config.Config{RunMode: config.RunModeSimple} cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
router := gin.New() router := gin.New()
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg))) router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
router.GET("/t", func(c *gin.Context) { router.GET("/t", func(c *gin.Context) {
@@ -208,7 +208,7 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) {
} }
cfg := &config.Config{RunMode: config.RunModeSimple} cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
router := gin.New() router := gin.New()
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg))) router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))

View File

@@ -34,12 +34,16 @@ func Logger() gin.HandlerFunc {
// 客户端IP // 客户端IP
clientIP := c.ClientIP() clientIP := c.ClientIP()
// 日志格式: [时间] 状态码 | 延迟 | IP | 方法 路径 // 协议版本
log.Printf("[GIN] %v | %3d | %13v | %15s | %-7s %s", protocol := c.Request.Proto
// 日志格式: [时间] 状态码 | 延迟 | IP | 协议 | 方法 路径
log.Printf("[GIN] %v | %3d | %13v | %15s | %-6s | %-7s %s",
endTime.Format("2006/01/02 - 15:04:05"), endTime.Format("2006/01/02 - 15:04:05"),
statusCode, statusCode,
latency, latency,
clientIP, clientIP,
protocol,
method, method,
path, path,
) )

View File

@@ -67,6 +67,9 @@ func RegisterAdminRoutes(
// 用户属性管理 // 用户属性管理
registerUserAttributeRoutes(admin, h) registerUserAttributeRoutes(admin, h)
// 错误透传规则管理
registerErrorPassthroughRoutes(admin, h)
} }
} }
@@ -387,3 +390,14 @@ func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
attrs.DELETE("/:id", h.Admin.UserAttribute.DeleteDefinition) attrs.DELETE("/:id", h.Admin.UserAttribute.DeleteDefinition)
} }
} }
func registerErrorPassthroughRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
rules := admin.Group("/error-passthrough-rules")
{
rules.GET("", h.Admin.ErrorPassthrough.List)
rules.GET("/:id", h.Admin.ErrorPassthrough.GetByID)
rules.POST("", h.Admin.ErrorPassthrough.Create)
rules.PUT("/:id", h.Admin.ErrorPassthrough.Update)
rules.DELETE("/:id", h.Admin.ErrorPassthrough.Delete)
}
}

View File

@@ -28,6 +28,12 @@ func RegisterAuthRoutes(
auth.POST("/login", h.Auth.Login) auth.POST("/login", h.Auth.Login)
auth.POST("/login/2fa", h.Auth.Login2FA) auth.POST("/login/2fa", h.Auth.Login2FA)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode) auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
// Token刷新接口添加速率限制每分钟最多 30 次Redis 故障时 fail-close
auth.POST("/refresh", rateLimiter.LimitWithOptions("refresh-token", 30, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.RefreshToken)
// 登出接口公开允许未认证用户调用以撤销Refresh Token
auth.POST("/logout", h.Auth.Logout)
// 优惠码验证接口添加速率限制:每分钟最多 10 次Redis 故障时 fail-close // 优惠码验证接口添加速率限制:每分钟最多 10 次Redis 故障时 fail-close
auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{ auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose, FailureMode: middleware.RateLimitFailClose,
@@ -59,5 +65,7 @@ func RegisterAuthRoutes(
authenticated.Use(gin.HandlerFunc(jwtAuth)) authenticated.Use(gin.HandlerFunc(jwtAuth))
{ {
authenticated.GET("/auth/me", h.Auth.GetCurrentUser) authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
// 撤销所有会话(需要认证)
authenticated.POST("/auth/revoke-all-sessions", h.Auth.RevokeAllSessions)
} }
} }

View File

@@ -49,6 +49,7 @@ func RegisterUserRoutes(
groups := authenticated.Group("/groups") groups := authenticated.Group("/groups")
{ {
groups.GET("/available", h.APIKey.GetAvailableGroups) groups.GET("/available", h.APIKey.GetAvailableGroups)
groups.GET("/rates", h.APIKey.GetUserGroupRates)
} }
// 使用记录 // 使用记录

View File

@@ -41,6 +41,7 @@ type UsageLogRepository interface {
// User dashboard stats // User dashboard stats
GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error)
GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*usagestats.UserDashboardStats, error)
GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error)
GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error)

View File

@@ -93,6 +93,9 @@ type UpdateUserInput struct {
Concurrency *int // 使用指针区分"未提供"和"设置为0" Concurrency *int // 使用指针区分"未提供"和"设置为0"
Status string Status string
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组" AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
// GroupRates 用户专属分组倍率配置
// map[groupID]*ratenil 表示删除该分组的专属倍率
GroupRates map[int64]*float64
} }
type CreateGroupInput struct { type CreateGroupInput struct {
@@ -304,6 +307,7 @@ type adminServiceImpl struct {
proxyRepo ProxyRepository proxyRepo ProxyRepository
apiKeyRepo APIKeyRepository apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository redeemCodeRepo RedeemCodeRepository
userGroupRateRepo UserGroupRateRepository
billingCacheService *BillingCacheService billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber proxyProber ProxyExitInfoProber
proxyLatencyCache ProxyLatencyCache proxyLatencyCache ProxyLatencyCache
@@ -319,6 +323,7 @@ func NewAdminService(
proxyRepo ProxyRepository, proxyRepo ProxyRepository,
apiKeyRepo APIKeyRepository, apiKeyRepo APIKeyRepository,
redeemCodeRepo RedeemCodeRepository, redeemCodeRepo RedeemCodeRepository,
userGroupRateRepo UserGroupRateRepository,
billingCacheService *BillingCacheService, billingCacheService *BillingCacheService,
proxyProber ProxyExitInfoProber, proxyProber ProxyExitInfoProber,
proxyLatencyCache ProxyLatencyCache, proxyLatencyCache ProxyLatencyCache,
@@ -332,6 +337,7 @@ func NewAdminService(
proxyRepo: proxyRepo, proxyRepo: proxyRepo,
apiKeyRepo: apiKeyRepo, apiKeyRepo: apiKeyRepo,
redeemCodeRepo: redeemCodeRepo, redeemCodeRepo: redeemCodeRepo,
userGroupRateRepo: userGroupRateRepo,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
proxyProber: proxyProber, proxyProber: proxyProber,
proxyLatencyCache: proxyLatencyCache, proxyLatencyCache: proxyLatencyCache,
@@ -346,11 +352,35 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
// 批量加载用户专属分组倍率
if s.userGroupRateRepo != nil && len(users) > 0 {
for i := range users {
rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID)
if err != nil {
log.Printf("failed to load user group rates: user_id=%d err=%v", users[i].ID, err)
continue
}
users[i].GroupRates = rates
}
}
return users, result.Total, nil return users, result.Total, nil
} }
func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) { func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) {
return s.userRepo.GetByID(ctx, id) user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
// 加载用户专属分组倍率
if s.userGroupRateRepo != nil {
rates, err := s.userGroupRateRepo.GetByUserID(ctx, id)
if err != nil {
log.Printf("failed to load user group rates: user_id=%d err=%v", id, err)
} else {
user.GroupRates = rates
}
}
return user, nil
} }
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) { func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) {
@@ -419,6 +449,14 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
if err := s.userRepo.Update(ctx, user); err != nil { if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err return nil, err
} }
// 同步用户专属分组倍率
if input.GroupRates != nil && s.userGroupRateRepo != nil {
if err := s.userGroupRateRepo.SyncUserGroupRates(ctx, user.ID, input.GroupRates); err != nil {
log.Printf("failed to sync user group rates: user_id=%d err=%v", user.ID, err)
}
}
if s.authCacheInvalidator != nil { if s.authCacheInvalidator != nil {
if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole { if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID) s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID)
@@ -974,6 +1012,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
if err != nil { if err != nil {
return err return err
} }
// 注意user_group_rate_multipliers 表通过外键 ON DELETE CASCADE 自动清理
// 事务成功后,异步失效受影响用户的订阅缓存 // 事务成功后,异步失效受影响用户的订阅缓存
if len(affectedUserIDs) > 0 && s.billingCacheService != nil { if len(affectedUserIDs) > 0 && s.billingCacheService != nil {

View File

@@ -1106,7 +1106,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
Message: upstreamMsg, Message: upstreamMsg,
Detail: upstreamDetail, Detail: upstreamDetail,
}) })
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
} }
return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody) return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody)
@@ -1779,6 +1779,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
// 处理错误响应 // 处理错误响应
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
contentType := resp.Header.Get("Content-Type")
// 尽早关闭原始响应体,释放连接;后续逻辑仍可能需要读取 body因此用内存副本重新包装。 // 尽早关闭原始响应体,释放连接;后续逻辑仍可能需要读取 body因此用内存副本重新包装。
_ = resp.Body.Close() _ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody)) resp.Body = io.NopCloser(bytes.NewReader(respBody))
@@ -1849,10 +1850,8 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
Message: upstreamMsg, Message: upstreamMsg,
Detail: upstreamDetail, Detail: upstreamDetail,
}) })
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: unwrappedForOps}
} }
contentType := resp.Header.Get("Content-Type")
if contentType == "" { if contentType == "" {
contentType = "application/json" contentType = "application/json"
} }

View File

@@ -115,15 +115,16 @@ type UpdateAPIKeyRequest struct {
// APIKeyService API Key服务 // APIKeyService API Key服务
type APIKeyService struct { type APIKeyService struct {
apiKeyRepo APIKeyRepository apiKeyRepo APIKeyRepository
userRepo UserRepository userRepo UserRepository
groupRepo GroupRepository groupRepo GroupRepository
userSubRepo UserSubscriptionRepository userSubRepo UserSubscriptionRepository
cache APIKeyCache userGroupRateRepo UserGroupRateRepository
cfg *config.Config cache APIKeyCache
authCacheL1 *ristretto.Cache cfg *config.Config
authCfg apiKeyAuthCacheConfig authCacheL1 *ristretto.Cache
authGroup singleflight.Group authCfg apiKeyAuthCacheConfig
authGroup singleflight.Group
} }
// NewAPIKeyService 创建API Key服务实例 // NewAPIKeyService 创建API Key服务实例
@@ -132,16 +133,18 @@ func NewAPIKeyService(
userRepo UserRepository, userRepo UserRepository,
groupRepo GroupRepository, groupRepo GroupRepository,
userSubRepo UserSubscriptionRepository, userSubRepo UserSubscriptionRepository,
userGroupRateRepo UserGroupRateRepository,
cache APIKeyCache, cache APIKeyCache,
cfg *config.Config, cfg *config.Config,
) *APIKeyService { ) *APIKeyService {
svc := &APIKeyService{ svc := &APIKeyService{
apiKeyRepo: apiKeyRepo, apiKeyRepo: apiKeyRepo,
userRepo: userRepo, userRepo: userRepo,
groupRepo: groupRepo, groupRepo: groupRepo,
userSubRepo: userSubRepo, userSubRepo: userSubRepo,
cache: cache, userGroupRateRepo: userGroupRateRepo,
cfg: cfg, cache: cache,
cfg: cfg,
} }
svc.initAuthCache(cfg) svc.initAuthCache(cfg)
return svc return svc
@@ -627,6 +630,19 @@ func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword
return keys, nil return keys, nil
} }
// GetUserGroupRates 获取用户的专属分组倍率配置
// 返回 map[groupID]rateMultiplier
func (s *APIKeyService) GetUserGroupRates(ctx context.Context, userID int64) (map[int64]float64, error) {
if s.userGroupRateRepo == nil {
return nil, nil
}
rates, err := s.userGroupRateRepo.GetByUserID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user group rates: %w", err)
}
return rates, nil
}
// CheckAPIKeyQuotaAndExpiry checks if the API key is valid for use (not expired, quota not exhausted) // CheckAPIKeyQuotaAndExpiry checks if the API key is valid for use (not expired, quota not exhausted)
// Returns nil if valid, error if invalid // Returns nil if valid, error if invalid
func (s *APIKeyService) CheckAPIKeyQuotaAndExpiry(apiKey *APIKey) error { func (s *APIKeyService) CheckAPIKeyQuotaAndExpiry(apiKey *APIKey) error {

View File

@@ -167,7 +167,7 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
NegativeTTLSeconds: 30, NegativeTTLSeconds: 30,
}, },
} }
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
groupID := int64(9) groupID := int64(9)
cacheEntry := &APIKeyAuthCacheEntry{ cacheEntry := &APIKeyAuthCacheEntry{
@@ -223,7 +223,7 @@ func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) {
NegativeTTLSeconds: 30, NegativeTTLSeconds: 30,
}, },
} }
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return &APIKeyAuthCacheEntry{NotFound: true}, nil return &APIKeyAuthCacheEntry{NotFound: true}, nil
} }
@@ -256,7 +256,7 @@ func TestAPIKeyService_GetByKey_CacheMissStoresL2(t *testing.T) {
NegativeTTLSeconds: 30, NegativeTTLSeconds: 30,
}, },
} }
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return nil, redis.Nil return nil, redis.Nil
} }
@@ -293,7 +293,7 @@ func TestAPIKeyService_GetByKey_UsesL1Cache(t *testing.T) {
L1TTLSeconds: 60, L1TTLSeconds: 60,
}, },
} }
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
require.NotNil(t, svc.authCacheL1) require.NotNil(t, svc.authCacheL1)
_, err := svc.GetByKey(context.Background(), "k-l1") _, err := svc.GetByKey(context.Background(), "k-l1")
@@ -320,7 +320,7 @@ func TestAPIKeyService_InvalidateAuthCacheByUserID(t *testing.T) {
NegativeTTLSeconds: 30, NegativeTTLSeconds: 30,
}, },
} }
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
svc.InvalidateAuthCacheByUserID(context.Background(), 7) svc.InvalidateAuthCacheByUserID(context.Background(), 7)
require.Len(t, cache.deleteAuthKeys, 2) require.Len(t, cache.deleteAuthKeys, 2)
@@ -338,7 +338,7 @@ func TestAPIKeyService_InvalidateAuthCacheByGroupID(t *testing.T) {
L2TTLSeconds: 60, L2TTLSeconds: 60,
}, },
} }
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
svc.InvalidateAuthCacheByGroupID(context.Background(), 9) svc.InvalidateAuthCacheByGroupID(context.Background(), 9)
require.Len(t, cache.deleteAuthKeys, 2) require.Len(t, cache.deleteAuthKeys, 2)
@@ -356,7 +356,7 @@ func TestAPIKeyService_InvalidateAuthCacheByKey(t *testing.T) {
L2TTLSeconds: 60, L2TTLSeconds: 60,
}, },
} }
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
svc.InvalidateAuthCacheByKey(context.Background(), "k1") svc.InvalidateAuthCacheByKey(context.Background(), "k1")
require.Len(t, cache.deleteAuthKeys, 1) require.Len(t, cache.deleteAuthKeys, 1)
@@ -375,7 +375,7 @@ func TestAPIKeyService_GetByKey_CachesNegativeOnRepoMiss(t *testing.T) {
NegativeTTLSeconds: 30, NegativeTTLSeconds: 30,
}, },
} }
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return nil, redis.Nil return nil, redis.Nil
} }
@@ -411,7 +411,7 @@ func TestAPIKeyService_GetByKey_SingleflightCollapses(t *testing.T) {
Singleflight: true, Singleflight: true,
}, },
} }
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
start := make(chan struct{}) start := make(chan struct{})
wg := sync.WaitGroup{} wg := sync.WaitGroup{}

View File

@@ -3,6 +3,7 @@ package service
import ( import (
"context" "context"
"crypto/rand" "crypto/rand"
"crypto/sha256"
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
@@ -25,8 +26,12 @@ var (
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved") ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token") ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired") ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
ErrAccessTokenExpired = infraerrors.Unauthorized("ACCESS_TOKEN_EXPIRED", "access token has expired")
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large") ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked") ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
ErrRefreshTokenInvalid = infraerrors.Unauthorized("REFRESH_TOKEN_INVALID", "invalid refresh token")
ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired")
ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused")
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required") ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable") ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
@@ -37,6 +42,9 @@ var (
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。 // maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
const maxTokenLength = 8192 const maxTokenLength = 8192
// refreshTokenPrefix is the prefix for refresh tokens to distinguish them from access tokens.
const refreshTokenPrefix = "rt_"
// JWTClaims JWT载荷数据 // JWTClaims JWT载荷数据
type JWTClaims struct { type JWTClaims struct {
UserID int64 `json:"user_id"` UserID int64 `json:"user_id"`
@@ -50,6 +58,7 @@ type JWTClaims struct {
type AuthService struct { type AuthService struct {
userRepo UserRepository userRepo UserRepository
redeemRepo RedeemCodeRepository redeemRepo RedeemCodeRepository
refreshTokenCache RefreshTokenCache
cfg *config.Config cfg *config.Config
settingService *SettingService settingService *SettingService
emailService *EmailService emailService *EmailService
@@ -62,6 +71,7 @@ type AuthService struct {
func NewAuthService( func NewAuthService(
userRepo UserRepository, userRepo UserRepository,
redeemRepo RedeemCodeRepository, redeemRepo RedeemCodeRepository,
refreshTokenCache RefreshTokenCache,
cfg *config.Config, cfg *config.Config,
settingService *SettingService, settingService *SettingService,
emailService *EmailService, emailService *EmailService,
@@ -72,6 +82,7 @@ func NewAuthService(
return &AuthService{ return &AuthService{
userRepo: userRepo, userRepo: userRepo,
redeemRepo: redeemRepo, redeemRepo: redeemRepo,
refreshTokenCache: refreshTokenCache,
cfg: cfg, cfg: cfg,
settingService: settingService, settingService: settingService,
emailService: emailService, emailService: emailService,
@@ -481,6 +492,100 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
return token, user, nil return token, user, nil
} }
// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair
// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username string) (*TokenPair, *User, error) {
// 检查 refreshTokenCache 是否可用
if s.refreshTokenCache == nil {
return nil, nil, errors.New("refresh token cache not configured")
}
email = strings.TrimSpace(email)
if email == "" || len(email) > 255 {
return nil, nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
}
if _, err := mail.ParseAddress(email); err != nil {
return nil, nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
}
username = strings.TrimSpace(username)
if len([]rune(username)) > 100 {
username = string([]rune(username)[:100])
}
user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil {
if errors.Is(err, ErrUserNotFound) {
// OAuth 首次登录视为注册
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return nil, nil, ErrRegDisabled
}
randomPassword, err := randomHexString(32)
if err != nil {
log.Printf("[Auth] Failed to generate random password for oauth signup: %v", err)
return nil, nil, ErrServiceUnavailable
}
hashedPassword, err := s.HashPassword(randomPassword)
if err != nil {
return nil, nil, fmt.Errorf("hash password: %w", err)
}
defaultBalance := s.cfg.Default.UserBalance
defaultConcurrency := s.cfg.Default.UserConcurrency
if s.settingService != nil {
defaultBalance = s.settingService.GetDefaultBalance(ctx)
defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
}
newUser := &User{
Email: email,
Username: username,
PasswordHash: hashedPassword,
Role: RoleUser,
Balance: defaultBalance,
Concurrency: defaultConcurrency,
Status: StatusActive,
}
if err := s.userRepo.Create(ctx, newUser); err != nil {
if errors.Is(err, ErrEmailExists) {
user, err = s.userRepo.GetByEmail(ctx, email)
if err != nil {
log.Printf("[Auth] Database error getting user after conflict: %v", err)
return nil, nil, ErrServiceUnavailable
}
} else {
log.Printf("[Auth] Database error creating oauth user: %v", err)
return nil, nil, ErrServiceUnavailable
}
} else {
user = newUser
}
} else {
log.Printf("[Auth] Database error during oauth login: %v", err)
return nil, nil, ErrServiceUnavailable
}
}
if !user.IsActive() {
return nil, nil, ErrUserNotActive
}
if user.Username == "" && username != "" {
user.Username = username
if err := s.userRepo.Update(ctx, user); err != nil {
log.Printf("[Auth] Failed to update username after oauth login: %v", err)
}
}
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
if err != nil {
return nil, nil, fmt.Errorf("generate token pair: %w", err)
}
return tokenPair, user, nil
}
// ValidateToken 验证JWT token并返回用户声明 // ValidateToken 验证JWT token并返回用户声明
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
// 先做长度校验,尽早拒绝异常超长 token降低 DoS 风险。 // 先做长度校验,尽早拒绝异常超长 token降低 DoS 风险。
@@ -539,10 +644,17 @@ func isReservedEmail(email string) bool {
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain)
} }
// GenerateToken 生成JWT token // GenerateToken 生成JWT access token
// 使用新的access_token_expire_minutes配置项如果配置了否则回退到expire_hour
func (s *AuthService) GenerateToken(user *User) (string, error) { func (s *AuthService) GenerateToken(user *User) (string, error) {
now := time.Now() now := time.Now()
expiresAt := now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour) var expiresAt time.Time
if s.cfg.JWT.AccessTokenExpireMinutes > 0 {
expiresAt = now.Add(time.Duration(s.cfg.JWT.AccessTokenExpireMinutes) * time.Minute)
} else {
// 向后兼容使用旧的expire_hour配置
expiresAt = now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour)
}
claims := &JWTClaims{ claims := &JWTClaims{
UserID: user.ID, UserID: user.ID,
@@ -565,6 +677,15 @@ func (s *AuthService) GenerateToken(user *User) (string, error) {
return tokenString, nil return tokenString, nil
} }
// GetAccessTokenExpiresIn 返回Access Token的有效期
// 用于前端设置刷新定时器
func (s *AuthService) GetAccessTokenExpiresIn() int {
if s.cfg.JWT.AccessTokenExpireMinutes > 0 {
return s.cfg.JWT.AccessTokenExpireMinutes * 60
}
return s.cfg.JWT.ExpireHour * 3600
}
// HashPassword 使用bcrypt加密密码 // HashPassword 使用bcrypt加密密码
func (s *AuthService) HashPassword(password string) (string, error) { func (s *AuthService) HashPassword(password string) (string, error) {
hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
@@ -755,6 +876,198 @@ func (s *AuthService) ResetPassword(ctx context.Context, email, token, newPasswo
return ErrServiceUnavailable return ErrServiceUnavailable
} }
// Also revoke all refresh tokens for this user
if err := s.RevokeAllUserSessions(ctx, user.ID); err != nil {
log.Printf("[Auth] Failed to revoke refresh tokens for user %d: %v", user.ID, err)
// Don't return error - password was already changed successfully
}
log.Printf("[Auth] Password reset successful for user: %s", email) log.Printf("[Auth] Password reset successful for user: %s", email)
return nil return nil
} }
// ==================== Refresh Token Methods ====================
// TokenPair 包含Access Token和Refresh Token
type TokenPair struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"` // Access Token有效期
}
// GenerateTokenPair 生成Access Token和Refresh Token对
// familyID: 可选的Token家族ID用于Token轮转时保持家族关系
func (s *AuthService) GenerateTokenPair(ctx context.Context, user *User, familyID string) (*TokenPair, error) {
// 检查 refreshTokenCache 是否可用
if s.refreshTokenCache == nil {
return nil, errors.New("refresh token cache not configured")
}
// 生成Access Token
accessToken, err := s.GenerateToken(user)
if err != nil {
return nil, fmt.Errorf("generate access token: %w", err)
}
// 生成Refresh Token
refreshToken, err := s.generateRefreshToken(ctx, user, familyID)
if err != nil {
return nil, fmt.Errorf("generate refresh token: %w", err)
}
return &TokenPair{
AccessToken: accessToken,
RefreshToken: refreshToken,
ExpiresIn: s.GetAccessTokenExpiresIn(),
}, nil
}
// generateRefreshToken 生成并存储Refresh Token
func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, familyID string) (string, error) {
// 生成随机Token
tokenBytes := make([]byte, 32)
if _, err := rand.Read(tokenBytes); err != nil {
return "", fmt.Errorf("generate random bytes: %w", err)
}
rawToken := refreshTokenPrefix + hex.EncodeToString(tokenBytes)
// 计算Token哈希存储哈希而非原始Token
tokenHash := hashToken(rawToken)
// 如果没有提供familyID生成新的
if familyID == "" {
familyBytes := make([]byte, 16)
if _, err := rand.Read(familyBytes); err != nil {
return "", fmt.Errorf("generate family id: %w", err)
}
familyID = hex.EncodeToString(familyBytes)
}
now := time.Now()
ttl := time.Duration(s.cfg.JWT.RefreshTokenExpireDays) * 24 * time.Hour
data := &RefreshTokenData{
UserID: user.ID,
TokenVersion: user.TokenVersion,
FamilyID: familyID,
CreatedAt: now,
ExpiresAt: now.Add(ttl),
}
// 存储Token数据
if err := s.refreshTokenCache.StoreRefreshToken(ctx, tokenHash, data, ttl); err != nil {
return "", fmt.Errorf("store refresh token: %w", err)
}
// 添加到用户Token集合
if err := s.refreshTokenCache.AddToUserTokenSet(ctx, user.ID, tokenHash, ttl); err != nil {
log.Printf("[Auth] Failed to add token to user set: %v", err)
// 不影响主流程
}
// 添加到家族Token集合
if err := s.refreshTokenCache.AddToFamilyTokenSet(ctx, familyID, tokenHash, ttl); err != nil {
log.Printf("[Auth] Failed to add token to family set: %v", err)
// 不影响主流程
}
return rawToken, nil
}
// RefreshTokenPair 使用Refresh Token刷新Token对
// 实现Token轮转每次刷新都会生成新的Refresh Token旧Token立即失效
func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) (*TokenPair, error) {
// 检查 refreshTokenCache 是否可用
if s.refreshTokenCache == nil {
return nil, ErrRefreshTokenInvalid
}
// 验证Token格式
if !strings.HasPrefix(refreshToken, refreshTokenPrefix) {
return nil, ErrRefreshTokenInvalid
}
tokenHash := hashToken(refreshToken)
// 获取Token数据
data, err := s.refreshTokenCache.GetRefreshToken(ctx, tokenHash)
if err != nil {
if errors.Is(err, ErrRefreshTokenNotFound) {
// Token不存在可能是已被使用Token轮转或已过期
log.Printf("[Auth] Refresh token not found, possible reuse attack")
return nil, ErrRefreshTokenInvalid
}
log.Printf("[Auth] Error getting refresh token: %v", err)
return nil, ErrServiceUnavailable
}
// 检查Token是否过期
if time.Now().After(data.ExpiresAt) {
// 删除过期Token
_ = s.refreshTokenCache.DeleteRefreshToken(ctx, tokenHash)
return nil, ErrRefreshTokenExpired
}
// 获取用户信息
user, err := s.userRepo.GetByID(ctx, data.UserID)
if err != nil {
if errors.Is(err, ErrUserNotFound) {
// 用户已删除撤销整个Token家族
_ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID)
return nil, ErrRefreshTokenInvalid
}
log.Printf("[Auth] Database error getting user for token refresh: %v", err)
return nil, ErrServiceUnavailable
}
// 检查用户状态
if !user.IsActive() {
// 用户被禁用撤销整个Token家族
_ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID)
return nil, ErrUserNotActive
}
// 检查TokenVersion密码更改后所有Token失效
if data.TokenVersion != user.TokenVersion {
// TokenVersion不匹配撤销整个Token家族
_ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID)
return nil, ErrTokenRevoked
}
// Token轮转立即使旧Token失效
if err := s.refreshTokenCache.DeleteRefreshToken(ctx, tokenHash); err != nil {
log.Printf("[Auth] Failed to delete old refresh token: %v", err)
// 继续处理,不影响主流程
}
// 生成新的Token对保持同一个家族ID
return s.GenerateTokenPair(ctx, user, data.FamilyID)
}
// RevokeRefreshToken 撤销单个Refresh Token
func (s *AuthService) RevokeRefreshToken(ctx context.Context, refreshToken string) error {
if s.refreshTokenCache == nil {
return nil // No-op if cache not configured
}
if !strings.HasPrefix(refreshToken, refreshTokenPrefix) {
return ErrRefreshTokenInvalid
}
tokenHash := hashToken(refreshToken)
return s.refreshTokenCache.DeleteRefreshToken(ctx, tokenHash)
}
// RevokeAllUserSessions 撤销用户的所有会话所有Refresh Token
// 用于密码更改或用户主动登出所有设备
func (s *AuthService) RevokeAllUserSessions(ctx context.Context, userID int64) error {
if s.refreshTokenCache == nil {
return nil // No-op if cache not configured
}
return s.refreshTokenCache.DeleteUserRefreshTokens(ctx, userID)
}
// hashToken 计算Token的SHA256哈希
func hashToken(token string) string {
hash := sha256.Sum256([]byte(token))
return hex.EncodeToString(hash[:])
}

View File

@@ -116,6 +116,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
return NewAuthService( return NewAuthService(
repo, repo,
nil, // redeemRepo nil, // redeemRepo
nil, // refreshTokenCache
cfg, cfg,
settingService, settingService,
emailService, emailService,

View File

@@ -0,0 +1,300 @@
package service
import (
"context"
"log"
"sort"
"strings"
"sync"
"github.com/Wei-Shaw/sub2api/internal/model"
)
// ErrorPassthroughRepository 定义错误透传规则的数据访问接口
type ErrorPassthroughRepository interface {
// List 获取所有规则
List(ctx context.Context) ([]*model.ErrorPassthroughRule, error)
// GetByID 根据 ID 获取规则
GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error)
// Create 创建规则
Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error)
// Update 更新规则
Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error)
// Delete 删除规则
Delete(ctx context.Context, id int64) error
}
// ErrorPassthroughCache 定义错误透传规则的缓存接口
type ErrorPassthroughCache interface {
// Get 从缓存获取规则列表
Get(ctx context.Context) ([]*model.ErrorPassthroughRule, bool)
// Set 设置缓存
Set(ctx context.Context, rules []*model.ErrorPassthroughRule) error
// Invalidate 使缓存失效
Invalidate(ctx context.Context) error
// NotifyUpdate 通知其他实例刷新缓存
NotifyUpdate(ctx context.Context) error
// SubscribeUpdates 订阅缓存更新通知
SubscribeUpdates(ctx context.Context, handler func())
}
// ErrorPassthroughService 错误透传规则服务
type ErrorPassthroughService struct {
repo ErrorPassthroughRepository
cache ErrorPassthroughCache
// 本地内存缓存,用于快速匹配
localCache []*model.ErrorPassthroughRule
localCacheMu sync.RWMutex
}
// NewErrorPassthroughService 创建错误透传规则服务
func NewErrorPassthroughService(
repo ErrorPassthroughRepository,
cache ErrorPassthroughCache,
) *ErrorPassthroughService {
svc := &ErrorPassthroughService{
repo: repo,
cache: cache,
}
// 启动时加载规则到本地缓存
ctx := context.Background()
if err := svc.refreshLocalCache(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to load rules on startup: %v", err)
}
// 订阅缓存更新通知
if cache != nil {
cache.SubscribeUpdates(ctx, func() {
if err := svc.refreshLocalCache(context.Background()); err != nil {
log.Printf("[ErrorPassthroughService] Failed to refresh cache on notification: %v", err)
}
})
}
return svc
}
// List 获取所有规则
func (s *ErrorPassthroughService) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) {
return s.repo.List(ctx)
}
// GetByID 根据 ID 获取规则
func (s *ErrorPassthroughService) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) {
return s.repo.GetByID(ctx, id)
}
// Create 创建规则
func (s *ErrorPassthroughService) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
if err := rule.Validate(); err != nil {
return nil, err
}
created, err := s.repo.Create(ctx, rule)
if err != nil {
return nil, err
}
// 刷新缓存
s.invalidateAndNotify(ctx)
return created, nil
}
// Update 更新规则
func (s *ErrorPassthroughService) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
if err := rule.Validate(); err != nil {
return nil, err
}
updated, err := s.repo.Update(ctx, rule)
if err != nil {
return nil, err
}
// 刷新缓存
s.invalidateAndNotify(ctx)
return updated, nil
}
// Delete 删除规则
func (s *ErrorPassthroughService) Delete(ctx context.Context, id int64) error {
if err := s.repo.Delete(ctx, id); err != nil {
return err
}
// 刷新缓存
s.invalidateAndNotify(ctx)
return nil
}
// MatchRule 匹配透传规则
// 返回第一个匹配的规则,如果没有匹配则返回 nil
func (s *ErrorPassthroughService) MatchRule(platform string, statusCode int, body []byte) *model.ErrorPassthroughRule {
rules := s.getCachedRules()
if len(rules) == 0 {
return nil
}
bodyStr := strings.ToLower(string(body))
for _, rule := range rules {
if !rule.Enabled {
continue
}
if !s.platformMatches(rule, platform) {
continue
}
if s.ruleMatches(rule, statusCode, bodyStr) {
return rule
}
}
return nil
}
// getCachedRules 获取缓存的规则列表(按优先级排序)
func (s *ErrorPassthroughService) getCachedRules() []*model.ErrorPassthroughRule {
s.localCacheMu.RLock()
rules := s.localCache
s.localCacheMu.RUnlock()
if rules != nil {
return rules
}
// 如果本地缓存为空,尝试刷新
ctx := context.Background()
if err := s.refreshLocalCache(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to refresh cache: %v", err)
return nil
}
s.localCacheMu.RLock()
defer s.localCacheMu.RUnlock()
return s.localCache
}
// refreshLocalCache 刷新本地缓存
func (s *ErrorPassthroughService) refreshLocalCache(ctx context.Context) error {
// 先尝试从 Redis 缓存获取
if s.cache != nil {
if rules, ok := s.cache.Get(ctx); ok {
s.setLocalCache(rules)
return nil
}
}
// 从数据库加载repo.List 已按 priority 排序)
rules, err := s.repo.List(ctx)
if err != nil {
return err
}
// 更新 Redis 缓存
if s.cache != nil {
if err := s.cache.Set(ctx, rules); err != nil {
log.Printf("[ErrorPassthroughService] Failed to set cache: %v", err)
}
}
// 更新本地缓存setLocalCache 内部会确保排序)
s.setLocalCache(rules)
return nil
}
// setLocalCache 设置本地缓存
func (s *ErrorPassthroughService) setLocalCache(rules []*model.ErrorPassthroughRule) {
// 按优先级排序
sorted := make([]*model.ErrorPassthroughRule, len(rules))
copy(sorted, rules)
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].Priority < sorted[j].Priority
})
s.localCacheMu.Lock()
s.localCache = sorted
s.localCacheMu.Unlock()
}
// invalidateAndNotify 使缓存失效并通知其他实例
func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) {
// 刷新本地缓存
if err := s.refreshLocalCache(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to refresh local cache: %v", err)
}
// 通知其他实例
if s.cache != nil {
if err := s.cache.NotifyUpdate(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to notify cache update: %v", err)
}
}
}
// platformMatches 检查平台是否匹配
func (s *ErrorPassthroughService) platformMatches(rule *model.ErrorPassthroughRule, platform string) bool {
// 如果没有配置平台限制,则匹配所有平台
if len(rule.Platforms) == 0 {
return true
}
platform = strings.ToLower(platform)
for _, p := range rule.Platforms {
if strings.ToLower(p) == platform {
return true
}
}
return false
}
// ruleMatches 检查规则是否匹配
func (s *ErrorPassthroughService) ruleMatches(rule *model.ErrorPassthroughRule, statusCode int, bodyLower string) bool {
hasErrorCodes := len(rule.ErrorCodes) > 0
hasKeywords := len(rule.Keywords) > 0
// 如果没有配置任何条件,不匹配
if !hasErrorCodes && !hasKeywords {
return false
}
codeMatch := !hasErrorCodes || s.containsInt(rule.ErrorCodes, statusCode)
keywordMatch := !hasKeywords || s.containsAnyKeyword(bodyLower, rule.Keywords)
if rule.MatchMode == model.MatchModeAll {
// "all" 模式:所有配置的条件都必须满足
return codeMatch && keywordMatch
}
// "any" 模式:任一条件满足即可
if hasErrorCodes && hasKeywords {
return codeMatch || keywordMatch
}
return codeMatch && keywordMatch
}
// containsInt 检查切片是否包含指定整数
func (s *ErrorPassthroughService) containsInt(slice []int, val int) bool {
for _, v := range slice {
if v == val {
return true
}
}
return false
}
// containsAnyKeyword 检查字符串是否包含任一关键词(不区分大小写)
func (s *ErrorPassthroughService) containsAnyKeyword(bodyLower string, keywords []string) bool {
for _, kw := range keywords {
if strings.Contains(bodyLower, strings.ToLower(kw)) {
return true
}
}
return false
}

View File

@@ -0,0 +1,755 @@
//go:build unit
package service
import (
"context"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// mockErrorPassthroughRepo 用于测试的 mock repository
type mockErrorPassthroughRepo struct {
rules []*model.ErrorPassthroughRule
}
func (m *mockErrorPassthroughRepo) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) {
return m.rules, nil
}
func (m *mockErrorPassthroughRepo) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) {
for _, r := range m.rules {
if r.ID == id {
return r, nil
}
}
return nil, nil
}
func (m *mockErrorPassthroughRepo) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
rule.ID = int64(len(m.rules) + 1)
m.rules = append(m.rules, rule)
return rule, nil
}
func (m *mockErrorPassthroughRepo) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
for i, r := range m.rules {
if r.ID == rule.ID {
m.rules[i] = rule
return rule, nil
}
}
return rule, nil
}
func (m *mockErrorPassthroughRepo) Delete(ctx context.Context, id int64) error {
for i, r := range m.rules {
if r.ID == id {
m.rules = append(m.rules[:i], m.rules[i+1:]...)
return nil
}
}
return nil
}
// newTestService 创建测试用的服务实例
func newTestService(rules []*model.ErrorPassthroughRule) *ErrorPassthroughService {
repo := &mockErrorPassthroughRepo{rules: rules}
svc := &ErrorPassthroughService{
repo: repo,
cache: nil, // 不使用缓存
}
// 直接设置本地缓存,避免调用 refreshLocalCache
svc.setLocalCache(rules)
return svc
}
// =============================================================================
// 测试 ruleMatches 核心匹配逻辑
// =============================================================================
func TestRuleMatches_NoConditions(t *testing.T) {
// 没有配置任何条件时,不应该匹配
svc := newTestService(nil)
rule := &model.ErrorPassthroughRule{
Enabled: true,
ErrorCodes: []int{},
Keywords: []string{},
MatchMode: model.MatchModeAny,
}
assert.False(t, svc.ruleMatches(rule, 422, "some error message"),
"没有配置条件时不应该匹配")
}
func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) {
svc := newTestService(nil)
rule := &model.ErrorPassthroughRule{
Enabled: true,
ErrorCodes: []int{422, 400},
Keywords: []string{},
MatchMode: model.MatchModeAny,
}
tests := []struct {
name string
statusCode int
body string
expected bool
}{
{"状态码匹配 422", 422, "any message", true},
{"状态码匹配 400", 400, "any message", true},
{"状态码不匹配 500", 500, "any message", false},
{"状态码不匹配 429", 429, "any message", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
assert.Equal(t, tt.expected, result)
})
}
}
func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) {
svc := newTestService(nil)
rule := &model.ErrorPassthroughRule{
Enabled: true,
ErrorCodes: []int{},
Keywords: []string{"context limit", "model not supported"},
MatchMode: model.MatchModeAny,
}
tests := []struct {
name string
statusCode int
body string
expected bool
}{
{"关键词匹配 context limit", 500, "error: context limit reached", true},
{"关键词匹配 model not supported", 400, "the model not supported here", true},
{"关键词不匹配", 422, "some other error", false},
// 注意ruleMatches 接收的 body 参数应该是已经转换为小写的
// 实际使用时MatchRule 会先将 body 转换为小写再传给 ruleMatches
{"关键词大小写 - 输入已小写", 500, "context limit exceeded", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 模拟 MatchRule 的行为:先转换为小写
bodyLower := strings.ToLower(tt.body)
result := svc.ruleMatches(rule, tt.statusCode, bodyLower)
assert.Equal(t, tt.expected, result)
})
}
}
func TestRuleMatches_BothConditions_AnyMode(t *testing.T) {
// any 模式:错误码 OR 关键词
svc := newTestService(nil)
rule := &model.ErrorPassthroughRule{
Enabled: true,
ErrorCodes: []int{422, 400},
Keywords: []string{"context limit"},
MatchMode: model.MatchModeAny,
}
tests := []struct {
name string
statusCode int
body string
expected bool
reason string
}{
{
name: "状态码和关键词都匹配",
statusCode: 422,
body: "context limit reached",
expected: true,
reason: "both match",
},
{
name: "只有状态码匹配",
statusCode: 422,
body: "some other error",
expected: true,
reason: "code matches, keyword doesn't - OR mode should match",
},
{
name: "只有关键词匹配",
statusCode: 500,
body: "context limit exceeded",
expected: true,
reason: "keyword matches, code doesn't - OR mode should match",
},
{
name: "都不匹配",
statusCode: 500,
body: "some other error",
expected: false,
reason: "neither matches",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
assert.Equal(t, tt.expected, result, tt.reason)
})
}
}
func TestRuleMatches_BothConditions_AllMode(t *testing.T) {
// all 模式:错误码 AND 关键词
svc := newTestService(nil)
rule := &model.ErrorPassthroughRule{
Enabled: true,
ErrorCodes: []int{422, 400},
Keywords: []string{"context limit"},
MatchMode: model.MatchModeAll,
}
tests := []struct {
name string
statusCode int
body string
expected bool
reason string
}{
{
name: "状态码和关键词都匹配",
statusCode: 422,
body: "context limit reached",
expected: true,
reason: "both match - AND mode should match",
},
{
name: "只有状态码匹配",
statusCode: 422,
body: "some other error",
expected: false,
reason: "code matches but keyword doesn't - AND mode should NOT match",
},
{
name: "只有关键词匹配",
statusCode: 500,
body: "context limit exceeded",
expected: false,
reason: "keyword matches but code doesn't - AND mode should NOT match",
},
{
name: "都不匹配",
statusCode: 500,
body: "some other error",
expected: false,
reason: "neither matches",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
assert.Equal(t, tt.expected, result, tt.reason)
})
}
}
// =============================================================================
// 测试 platformMatches 平台匹配逻辑
// =============================================================================
func TestPlatformMatches(t *testing.T) {
svc := newTestService(nil)
tests := []struct {
name string
rulePlatforms []string
requestPlatform string
expected bool
}{
{
name: "空平台列表匹配所有",
rulePlatforms: []string{},
requestPlatform: "anthropic",
expected: true,
},
{
name: "nil平台列表匹配所有",
rulePlatforms: nil,
requestPlatform: "openai",
expected: true,
},
{
name: "精确匹配 anthropic",
rulePlatforms: []string{"anthropic", "openai"},
requestPlatform: "anthropic",
expected: true,
},
{
name: "精确匹配 openai",
rulePlatforms: []string{"anthropic", "openai"},
requestPlatform: "openai",
expected: true,
},
{
name: "不匹配 gemini",
rulePlatforms: []string{"anthropic", "openai"},
requestPlatform: "gemini",
expected: false,
},
{
name: "大小写不敏感",
rulePlatforms: []string{"Anthropic", "OpenAI"},
requestPlatform: "anthropic",
expected: true,
},
{
name: "匹配 antigravity",
rulePlatforms: []string{"antigravity"},
requestPlatform: "antigravity",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rule := &model.ErrorPassthroughRule{
Platforms: tt.rulePlatforms,
}
result := svc.platformMatches(rule, tt.requestPlatform)
assert.Equal(t, tt.expected, result)
})
}
}
// =============================================================================
// 测试 MatchRule 完整匹配流程
// =============================================================================
func TestMatchRule_Priority(t *testing.T) {
// 测试规则按优先级排序,优先级小的先匹配
rules := []*model.ErrorPassthroughRule{
{
ID: 1,
Name: "Low Priority",
Enabled: true,
Priority: 10,
ErrorCodes: []int{422},
MatchMode: model.MatchModeAny,
},
{
ID: 2,
Name: "High Priority",
Enabled: true,
Priority: 1,
ErrorCodes: []int{422},
MatchMode: model.MatchModeAny,
},
}
svc := newTestService(rules)
matched := svc.MatchRule("anthropic", 422, []byte("error"))
require.NotNil(t, matched)
assert.Equal(t, int64(2), matched.ID, "应该匹配优先级更高(数值更小)的规则")
assert.Equal(t, "High Priority", matched.Name)
}
func TestMatchRule_DisabledRule(t *testing.T) {
rules := []*model.ErrorPassthroughRule{
{
ID: 1,
Name: "Disabled Rule",
Enabled: false,
Priority: 1,
ErrorCodes: []int{422},
MatchMode: model.MatchModeAny,
},
{
ID: 2,
Name: "Enabled Rule",
Enabled: true,
Priority: 10,
ErrorCodes: []int{422},
MatchMode: model.MatchModeAny,
},
}
svc := newTestService(rules)
matched := svc.MatchRule("anthropic", 422, []byte("error"))
require.NotNil(t, matched)
assert.Equal(t, int64(2), matched.ID, "应该跳过禁用的规则")
}
func TestMatchRule_PlatformFilter(t *testing.T) {
rules := []*model.ErrorPassthroughRule{
{
ID: 1,
Name: "Anthropic Only",
Enabled: true,
Priority: 1,
ErrorCodes: []int{422},
Platforms: []string{"anthropic"},
MatchMode: model.MatchModeAny,
},
{
ID: 2,
Name: "OpenAI Only",
Enabled: true,
Priority: 2,
ErrorCodes: []int{422},
Platforms: []string{"openai"},
MatchMode: model.MatchModeAny,
},
{
ID: 3,
Name: "All Platforms",
Enabled: true,
Priority: 3,
ErrorCodes: []int{422},
Platforms: []string{},
MatchMode: model.MatchModeAny,
},
}
svc := newTestService(rules)
t.Run("Anthropic 请求匹配 Anthropic 规则", func(t *testing.T) {
matched := svc.MatchRule("anthropic", 422, []byte("error"))
require.NotNil(t, matched)
assert.Equal(t, int64(1), matched.ID)
})
t.Run("OpenAI 请求匹配 OpenAI 规则", func(t *testing.T) {
matched := svc.MatchRule("openai", 422, []byte("error"))
require.NotNil(t, matched)
assert.Equal(t, int64(2), matched.ID)
})
t.Run("Gemini 请求匹配全平台规则", func(t *testing.T) {
matched := svc.MatchRule("gemini", 422, []byte("error"))
require.NotNil(t, matched)
assert.Equal(t, int64(3), matched.ID)
})
t.Run("Antigravity 请求匹配全平台规则", func(t *testing.T) {
matched := svc.MatchRule("antigravity", 422, []byte("error"))
require.NotNil(t, matched)
assert.Equal(t, int64(3), matched.ID)
})
}
func TestMatchRule_NoMatch(t *testing.T) {
rules := []*model.ErrorPassthroughRule{
{
ID: 1,
Name: "Rule for 422",
Enabled: true,
Priority: 1,
ErrorCodes: []int{422},
MatchMode: model.MatchModeAny,
},
}
svc := newTestService(rules)
matched := svc.MatchRule("anthropic", 500, []byte("error"))
assert.Nil(t, matched, "不匹配任何规则时应返回 nil")
}
func TestMatchRule_EmptyRules(t *testing.T) {
svc := newTestService([]*model.ErrorPassthroughRule{})
matched := svc.MatchRule("anthropic", 422, []byte("error"))
assert.Nil(t, matched, "没有规则时应返回 nil")
}
func TestMatchRule_CaseInsensitiveKeyword(t *testing.T) {
rules := []*model.ErrorPassthroughRule{
{
ID: 1,
Name: "Context Limit",
Enabled: true,
Priority: 1,
Keywords: []string{"Context Limit"},
MatchMode: model.MatchModeAny,
},
}
svc := newTestService(rules)
tests := []struct {
name string
body string
expected bool
}{
{"完全匹配", "Context Limit reached", true},
{"小写匹配", "context limit reached", true},
{"大写匹配", "CONTEXT LIMIT REACHED", true},
{"混合大小写", "ConTeXt LiMiT error", true},
{"不匹配", "some other error", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
matched := svc.MatchRule("anthropic", 500, []byte(tt.body))
if tt.expected {
assert.NotNil(t, matched)
} else {
assert.Nil(t, matched)
}
})
}
}
// =============================================================================
// 测试真实场景
// =============================================================================
func TestMatchRule_RealWorldScenario_ContextLimitPassthrough(t *testing.T) {
// 场景:上游返回 422 + "context limit has been reached",需要透传给客户端
rules := []*model.ErrorPassthroughRule{
{
ID: 1,
Name: "Context Limit Passthrough",
Enabled: true,
Priority: 1,
ErrorCodes: []int{422},
Keywords: []string{"context limit"},
MatchMode: model.MatchModeAll, // 必须同时满足
Platforms: []string{"anthropic", "antigravity"},
PassthroughCode: true,
PassthroughBody: true,
},
}
svc := newTestService(rules)
// 测试 Anthropic 平台
t.Run("Anthropic 422 with context limit", func(t *testing.T) {
body := []byte(`{"type":"error","error":{"type":"invalid_request","message":"The context limit has been reached"}}`)
matched := svc.MatchRule("anthropic", 422, body)
require.NotNil(t, matched)
assert.True(t, matched.PassthroughCode)
assert.True(t, matched.PassthroughBody)
})
// 测试 Antigravity 平台
t.Run("Antigravity 422 with context limit", func(t *testing.T) {
body := []byte(`{"error":"context limit exceeded"}`)
matched := svc.MatchRule("antigravity", 422, body)
require.NotNil(t, matched)
})
// 测试 OpenAI 平台(不在规则的平台列表中)
t.Run("OpenAI should not match", func(t *testing.T) {
body := []byte(`{"error":"context limit exceeded"}`)
matched := svc.MatchRule("openai", 422, body)
assert.Nil(t, matched, "OpenAI 不在规则的平台列表中")
})
// 测试状态码不匹配
t.Run("Wrong status code", func(t *testing.T) {
body := []byte(`{"error":"context limit exceeded"}`)
matched := svc.MatchRule("anthropic", 400, body)
assert.Nil(t, matched, "状态码不匹配")
})
// 测试关键词不匹配
t.Run("Wrong keyword", func(t *testing.T) {
body := []byte(`{"error":"rate limit exceeded"}`)
matched := svc.MatchRule("anthropic", 422, body)
assert.Nil(t, matched, "关键词不匹配")
})
}
func TestMatchRule_RealWorldScenario_CustomErrorMessage(t *testing.T) {
// 场景:某些错误需要返回自定义消息,隐藏上游详细信息
customMsg := "Service temporarily unavailable, please try again later"
responseCode := 503
rules := []*model.ErrorPassthroughRule{
{
ID: 1,
Name: "Hide Internal Errors",
Enabled: true,
Priority: 1,
ErrorCodes: []int{500, 502, 503},
MatchMode: model.MatchModeAny,
PassthroughCode: false,
ResponseCode: &responseCode,
PassthroughBody: false,
CustomMessage: &customMsg,
},
}
svc := newTestService(rules)
matched := svc.MatchRule("anthropic", 500, []byte("internal server error"))
require.NotNil(t, matched)
assert.False(t, matched.PassthroughCode)
assert.Equal(t, 503, *matched.ResponseCode)
assert.False(t, matched.PassthroughBody)
assert.Equal(t, customMsg, *matched.CustomMessage)
}
// =============================================================================
// 测试 model.Validate
// =============================================================================
func TestErrorPassthroughRule_Validate(t *testing.T) {
tests := []struct {
name string
rule *model.ErrorPassthroughRule
expectError bool
errorField string
}{
{
name: "有效规则 - 透传模式(含错误码)",
rule: &model.ErrorPassthroughRule{
Name: "Valid Rule",
MatchMode: model.MatchModeAny,
ErrorCodes: []int{422},
PassthroughCode: true,
PassthroughBody: true,
},
expectError: false,
},
{
name: "有效规则 - 透传模式(含关键词)",
rule: &model.ErrorPassthroughRule{
Name: "Valid Rule",
MatchMode: model.MatchModeAny,
Keywords: []string{"context limit"},
PassthroughCode: true,
PassthroughBody: true,
},
expectError: false,
},
{
name: "有效规则 - 自定义响应",
rule: &model.ErrorPassthroughRule{
Name: "Valid Rule",
MatchMode: model.MatchModeAll,
ErrorCodes: []int{500},
Keywords: []string{"internal error"},
PassthroughCode: false,
ResponseCode: testIntPtr(503),
PassthroughBody: false,
CustomMessage: testStrPtr("Custom error"),
},
expectError: false,
},
{
name: "缺少名称",
rule: &model.ErrorPassthroughRule{
Name: "",
MatchMode: model.MatchModeAny,
ErrorCodes: []int{422},
PassthroughCode: true,
PassthroughBody: true,
},
expectError: true,
errorField: "name",
},
{
name: "无效的匹配模式",
rule: &model.ErrorPassthroughRule{
Name: "Invalid Mode",
MatchMode: "invalid",
ErrorCodes: []int{422},
PassthroughCode: true,
PassthroughBody: true,
},
expectError: true,
errorField: "match_mode",
},
{
name: "缺少匹配条件(错误码和关键词都为空)",
rule: &model.ErrorPassthroughRule{
Name: "No Conditions",
MatchMode: model.MatchModeAny,
ErrorCodes: []int{},
Keywords: []string{},
PassthroughCode: true,
PassthroughBody: true,
},
expectError: true,
errorField: "conditions",
},
{
name: "缺少匹配条件nil切片",
rule: &model.ErrorPassthroughRule{
Name: "Nil Conditions",
MatchMode: model.MatchModeAny,
ErrorCodes: nil,
Keywords: nil,
PassthroughCode: true,
PassthroughBody: true,
},
expectError: true,
errorField: "conditions",
},
{
name: "自定义状态码但未提供值",
rule: &model.ErrorPassthroughRule{
Name: "Missing Code",
MatchMode: model.MatchModeAny,
ErrorCodes: []int{422},
PassthroughCode: false,
ResponseCode: nil,
PassthroughBody: true,
},
expectError: true,
errorField: "response_code",
},
{
name: "自定义消息但未提供值",
rule: &model.ErrorPassthroughRule{
Name: "Missing Message",
MatchMode: model.MatchModeAny,
ErrorCodes: []int{422},
PassthroughCode: true,
PassthroughBody: false,
CustomMessage: nil,
},
expectError: true,
errorField: "custom_message",
},
{
name: "自定义消息为空字符串",
rule: &model.ErrorPassthroughRule{
Name: "Empty Message",
MatchMode: model.MatchModeAny,
ErrorCodes: []int{422},
PassthroughCode: true,
PassthroughBody: false,
CustomMessage: testStrPtr(""),
},
expectError: true,
errorField: "custom_message",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.rule.Validate()
if tt.expectError {
require.Error(t, err)
validationErr, ok := err.(*model.ValidationError)
require.True(t, ok, "应该返回 ValidationError")
assert.Equal(t, tt.errorField, validationErr.Field)
} else {
assert.NoError(t, err)
}
})
}
}
// Helper functions
func testIntPtr(i int) *int { return &i }
func testStrPtr(s string) *string { return &s }

View File

@@ -20,7 +20,6 @@ import (
"strings" "strings"
"sync/atomic" "sync/atomic"
"time" "time"
"unicode"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
@@ -375,7 +374,8 @@ type ForwardResult struct {
// UpstreamFailoverError indicates an upstream error that should trigger account failover. // UpstreamFailoverError indicates an upstream error that should trigger account failover.
type UpstreamFailoverError struct { type UpstreamFailoverError struct {
StatusCode int StatusCode int
ResponseBody []byte // 上游响应体,用于错误透传规则匹配
} }
func (e *UpstreamFailoverError) Error() string { func (e *UpstreamFailoverError) Error() string {
@@ -389,6 +389,7 @@ type GatewayService struct {
usageLogRepo UsageLogRepository usageLogRepo UsageLogRepository
userRepo UserRepository userRepo UserRepository
userSubRepo UserSubscriptionRepository userSubRepo UserSubscriptionRepository
userGroupRateRepo UserGroupRateRepository
cache GatewayCache cache GatewayCache
cfg *config.Config cfg *config.Config
schedulerSnapshot *SchedulerSnapshotService schedulerSnapshot *SchedulerSnapshotService
@@ -410,6 +411,7 @@ func NewGatewayService(
usageLogRepo UsageLogRepository, usageLogRepo UsageLogRepository,
userRepo UserRepository, userRepo UserRepository,
userSubRepo UserSubscriptionRepository, userSubRepo UserSubscriptionRepository,
userGroupRateRepo UserGroupRateRepository,
cache GatewayCache, cache GatewayCache,
cfg *config.Config, cfg *config.Config,
schedulerSnapshot *SchedulerSnapshotService, schedulerSnapshot *SchedulerSnapshotService,
@@ -429,6 +431,7 @@ func NewGatewayService(
usageLogRepo: usageLogRepo, usageLogRepo: usageLogRepo,
userRepo: userRepo, userRepo: userRepo,
userSubRepo: userSubRepo, userSubRepo: userSubRepo,
userGroupRateRepo: userGroupRateRepo,
cache: cache, cache: cache,
cfg: cfg, cfg: cfg,
schedulerSnapshot: schedulerSnapshot, schedulerSnapshot: schedulerSnapshot,
@@ -624,35 +627,6 @@ func stripToolPrefix(value string) string {
return toolPrefixRe.ReplaceAllString(value, "") return toolPrefixRe.ReplaceAllString(value, "")
} }
func toPascalCase(value string) string {
if value == "" {
return value
}
normalized := toolNameBoundaryRe.ReplaceAllString(value, " ")
tokens := make([]string, 0)
for _, token := range strings.Fields(normalized) {
expanded := toolNameCamelRe.ReplaceAllString(token, "$1 $2")
parts := strings.Fields(expanded)
if len(parts) > 0 {
tokens = append(tokens, parts...)
}
}
if len(tokens) == 0 {
return value
}
var builder strings.Builder
for _, token := range tokens {
lower := strings.ToLower(token)
if lower == "" {
continue
}
runes := []rune(lower)
runes[0] = unicode.ToUpper(runes[0])
_, _ = builder.WriteString(string(runes))
}
return builder.String()
}
func toSnakeCase(value string) string { func toSnakeCase(value string) string {
if value == "" { if value == "" {
return value return value
@@ -668,16 +642,15 @@ func normalizeToolNameForClaude(name string, cache map[string]string) string {
return name return name
} }
stripped := stripToolPrefix(name) stripped := stripToolPrefix(name)
// 只对已知的工具名进行映射,未知工具名保持原样
// 避免破坏 Anthropic 特殊工具(如 text_editor_20250728
mapped, ok := claudeToolNameOverrides[strings.ToLower(stripped)] mapped, ok := claudeToolNameOverrides[strings.ToLower(stripped)]
if !ok { if !ok {
mapped = toPascalCase(stripped)
}
if mapped != "" && cache != nil && mapped != stripped {
cache[mapped] = stripped
}
if mapped == "" {
return stripped return stripped
} }
if cache != nil && mapped != stripped {
cache[mapped] = stripped
}
return mapped return mapped
} }
@@ -686,15 +659,18 @@ func normalizeToolNameForOpenCode(name string, cache map[string]string) string {
return name return name
} }
stripped := stripToolPrefix(name) stripped := stripToolPrefix(name)
// 优先从请求时建立的映射中查找
if cache != nil { if cache != nil {
if mapped, ok := cache[stripped]; ok { if mapped, ok := cache[stripped]; ok {
return mapped return mapped
} }
} }
// 已知工具名的硬编码映射
if mapped, ok := openCodeToolOverrides[stripped]; ok { if mapped, ok := openCodeToolOverrides[stripped]; ok {
return mapped return mapped
} }
return toSnakeCase(stripped) // 未知工具名保持原样,避免破坏 Anthropic 特殊工具
return stripped
} }
func normalizeParamNameForOpenCode(name string, cache map[string]string) string { func normalizeParamNameForOpenCode(name string, cache map[string]string) string {
@@ -3313,7 +3289,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return "" return ""
}(), }(),
}) })
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
} }
return s.handleRetryExhaustedError(ctx, resp, c, account) return s.handleRetryExhaustedError(ctx, resp, c, account)
} }
@@ -3343,10 +3319,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return "" return ""
}(), }(),
}) })
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
} }
// 处理错误响应(不可重试的错误)
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
// 可选:对部分 400 触发 failover默认关闭以保持语义 // 可选:对部分 400 触发 failover默认关闭以保持语义
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 { if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
@@ -3390,7 +3364,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
log.Printf("Account %d: 400 error, attempting failover", account.ID) log.Printf("Account %d: 400 error, attempting failover", account.ID)
} }
s.handleFailoverSideEffects(ctx, resp, account) s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
} }
} }
return s.handleErrorResponse(ctx, resp, c, account) return s.handleErrorResponse(ctx, resp, c, account)
@@ -3787,6 +3761,12 @@ func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
return false return false
} }
// ExtractUpstreamErrorMessage 从上游响应体中提取错误消息
// 支持 Claude 风格的错误格式:{"type":"error","error":{"type":"...","message":"..."}}
func ExtractUpstreamErrorMessage(body []byte) string {
return extractUpstreamErrorMessage(body)
}
func extractUpstreamErrorMessage(body []byte) string { func extractUpstreamErrorMessage(body []byte) string {
// Claude 风格:{"type":"error","error":{"type":"...","message":"..."}} // Claude 风格:{"type":"error","error":{"type":"...","message":"..."}}
if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" { if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" {
@@ -3854,7 +3834,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
} }
if shouldDisable { if shouldDisable {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body}
} }
// 记录上游错误响应体摘要便于排障(可选:由配置控制;不回显到客户端) // 记录上游错误响应体摘要便于排障(可选:由配置控制;不回显到客户端)
@@ -4641,10 +4621,17 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
account := input.Account account := input.Account
subscription := input.Subscription subscription := input.Subscription
// 获取费率倍数 // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
multiplier := s.cfg.Default.RateMultiplier multiplier := s.cfg.Default.RateMultiplier
if apiKey.GroupID != nil && apiKey.Group != nil { if apiKey.GroupID != nil && apiKey.Group != nil {
multiplier = apiKey.Group.RateMultiplier multiplier = apiKey.Group.RateMultiplier
// 检查用户专属倍率
if s.userGroupRateRepo != nil {
if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil {
multiplier = *userRate
}
}
} }
var cost *CostBreakdown var cost *CostBreakdown
@@ -4827,10 +4814,17 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
account := input.Account account := input.Account
subscription := input.Subscription subscription := input.Subscription
// 获取费率倍数 // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
multiplier := s.cfg.Default.RateMultiplier multiplier := s.cfg.Default.RateMultiplier
if apiKey.GroupID != nil && apiKey.Group != nil { if apiKey.GroupID != nil && apiKey.Group != nil {
multiplier = apiKey.Group.RateMultiplier multiplier = apiKey.Group.RateMultiplier
// 检查用户专属倍率
if s.userGroupRateRepo != nil {
if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil {
multiplier = *userRate
}
}
} }
var cost *CostBreakdown var cost *CostBreakdown

View File

@@ -864,7 +864,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
Message: upstreamMsg, Message: upstreamMsg,
Detail: upstreamDetail, Detail: upstreamDetail,
}) })
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
} }
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) { if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
upstreamReqID := resp.Header.Get(requestIDHeader) upstreamReqID := resp.Header.Get(requestIDHeader)
@@ -891,7 +891,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
Message: upstreamMsg, Message: upstreamMsg,
Detail: upstreamDetail, Detail: upstreamDetail,
}) })
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
} }
upstreamReqID := resp.Header.Get(requestIDHeader) upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" { if upstreamReqID == "" {
@@ -1301,7 +1301,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
Message: upstreamMsg, Message: upstreamMsg,
Detail: upstreamDetail, Detail: upstreamDetail,
}) })
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
} }
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) { if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
evBody := unwrapIfNeeded(isOAuth, respBody) evBody := unwrapIfNeeded(isOAuth, respBody)
@@ -1325,7 +1325,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
Message: upstreamMsg, Message: upstreamMsg,
Detail: upstreamDetail, Detail: upstreamDetail,
}) })
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: evBody}
} }
respBody = unwrapIfNeeded(isOAuth, respBody) respBody = unwrapIfNeeded(isOAuth, respBody)

View File

@@ -944,6 +944,32 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
return strings.TrimSpace(loadResp.CloudAICompanionProject), tierID, nil return strings.TrimSpace(loadResp.CloudAICompanionProject), tierID, nil
} }
// 关键逻辑:对齐 Gemini CLI 对“已注册用户”的处理方式。
// 当 LoadCodeAssist 返回了 currentTier / paidTier表示账号已注册但没有返回 cloudaicompanionProject 时:
// - 不要再调用 onboardUser通常不会再分配 project_id且可能触发 INVALID_ARGUMENT
// - 先尝试从 Cloud Resource Manager 获取可用项目;仍失败则提示用户手动填写 project_id
if loadResp != nil {
registeredTierID := strings.TrimSpace(loadResp.GetTier())
if registeredTierID != "" {
// 已注册但未返回 cloudaicompanionProject这在 Google One 用户中较常见:需要用户自行提供 project_id。
log.Printf("[GeminiOAuth] User has tier (%s) but no cloudaicompanionProject, trying Cloud Resource Manager...", registeredTierID)
// Try to get project from Cloud Resource Manager
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
if fbErr == nil && strings.TrimSpace(fallback) != "" {
log.Printf("[GeminiOAuth] Found project from Cloud Resource Manager: %s", fallback)
return strings.TrimSpace(fallback), tierID, nil
}
// No project found - user must provide project_id manually
log.Printf("[GeminiOAuth] No project found from Cloud Resource Manager, user must provide project_id manually")
return "", tierID, fmt.Errorf("user is registered (tier: %s) but no project_id available. Please provide Project ID manually in the authorization form, or create a project at https://console.cloud.google.com", registeredTierID)
}
}
// 未检测到 currentTier/paidTier视为新用户继续调用 onboardUser
log.Printf("[GeminiOAuth] No currentTier/paidTier found, proceeding with onboardUser (tierID: %s)", tierID)
req := &geminicli.OnboardUserRequest{ req := &geminicli.OnboardUserRequest{
TierID: tierID, TierID: tierID,
Metadata: geminicli.LoadCodeAssistMetadata{ Metadata: geminicli.LoadCodeAssistMetadata{

View File

@@ -846,10 +846,12 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
} }
} }
// Remove prompt_cache_retention (not supported by upstream OpenAI API) // Remove unsupported fields (not supported by upstream OpenAI API)
if _, has := reqBody["prompt_cache_retention"]; has { for _, unsupportedField := range []string{"prompt_cache_retention", "safety_identifier", "previous_response_id"} {
delete(reqBody, "prompt_cache_retention") if _, has := reqBody[unsupportedField]; has {
bodyModified = true delete(reqBody, unsupportedField)
bodyModified = true
}
} }
} }
@@ -938,7 +940,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}) })
s.handleFailoverSideEffects(ctx, resp, account) s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
} }
return s.handleErrorResponse(ctx, resp, c, account) return s.handleErrorResponse(ctx, resp, c, account)
} }
@@ -1129,7 +1131,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
Detail: upstreamDetail, Detail: upstreamDetail,
}) })
if shouldDisable { if shouldDisable {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body}
} }
// Return appropriate error response // Return appropriate error response

View File

@@ -0,0 +1,73 @@
package service
import (
"context"
"errors"
"time"
)
// ErrRefreshTokenNotFound is returned when a refresh token is not found in cache.
// This is used to abstract away the underlying cache implementation (e.g., redis.Nil).
var ErrRefreshTokenNotFound = errors.New("refresh token not found")
// RefreshTokenData 存储在Redis中的Refresh Token数据
type RefreshTokenData struct {
UserID int64 `json:"user_id"`
TokenVersion int64 `json:"token_version"` // 用于检测密码更改后的Token失效
FamilyID string `json:"family_id"` // Token家族ID用于防重放攻击
CreatedAt time.Time `json:"created_at"`
ExpiresAt time.Time `json:"expires_at"`
}
// RefreshTokenCache 管理Refresh Token的Redis缓存
// 用于JWT Token刷新机制支持Token轮转和防重放攻击
//
// Key 格式:
// - refresh_token:{token_hash} -> RefreshTokenData (JSON)
// - user_refresh_tokens:{user_id} -> Set<token_hash>
// - token_family:{family_id} -> Set<token_hash>
type RefreshTokenCache interface {
// StoreRefreshToken 存储Refresh Token
// tokenHash: Token的SHA256哈希值不存储原始Token
// data: Token关联的数据
// ttl: Token过期时间
StoreRefreshToken(ctx context.Context, tokenHash string, data *RefreshTokenData, ttl time.Duration) error
// GetRefreshToken 获取Refresh Token数据
// 返回 (data, nil) 如果Token存在
// 返回 (nil, ErrRefreshTokenNotFound) 如果Token不存在
// 返回 (nil, err) 如果发生其他错误
GetRefreshToken(ctx context.Context, tokenHash string) (*RefreshTokenData, error)
// DeleteRefreshToken 删除单个Refresh Token
// 用于Token轮转时使旧Token失效
DeleteRefreshToken(ctx context.Context, tokenHash string) error
// DeleteUserRefreshTokens 删除用户的所有Refresh Token
// 用于密码更改或用户主动登出所有设备
DeleteUserRefreshTokens(ctx context.Context, userID int64) error
// DeleteTokenFamily 删除整个Token家族
// 用于检测到Token重放攻击时撤销整个会话链
DeleteTokenFamily(ctx context.Context, familyID string) error
// AddToUserTokenSet 将Token添加到用户的Token集合
// 用于跟踪用户的所有活跃Refresh Token
AddToUserTokenSet(ctx context.Context, userID int64, tokenHash string, ttl time.Duration) error
// AddToFamilyTokenSet 将Token添加到家族Token集合
// 用于跟踪同一登录会话的所有Token
AddToFamilyTokenSet(ctx context.Context, familyID string, tokenHash string, ttl time.Duration) error
// GetUserTokenHashes 获取用户的所有Token哈希
// 用于批量删除用户Token
GetUserTokenHashes(ctx context.Context, userID int64) ([]string, error)
// GetFamilyTokenHashes 获取家族的所有Token哈希
// 用于批量删除家族Token
GetFamilyTokenHashes(ctx context.Context, familyID string) ([]string, error)
// IsTokenInFamily 检查Token是否属于指定家族
// 用于验证Token家族关系
IsTokenInFamily(ctx context.Context, familyID string, tokenHash string) (bool, error)
}

View File

@@ -288,6 +288,15 @@ func (s *UsageService) GetUserDashboardStats(ctx context.Context, userID int64)
return stats, nil return stats, nil
} }
// GetAPIKeyDashboardStats returns dashboard summary stats filtered by API Key.
func (s *UsageService) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*usagestats.UserDashboardStats, error) {
stats, err := s.usageRepo.GetAPIKeyDashboardStats(ctx, apiKeyID)
if err != nil {
return nil, fmt.Errorf("get api key dashboard stats: %w", err)
}
return stats, nil
}
// GetUserUsageTrendByUserID returns per-user usage trend. // GetUserUsageTrendByUserID returns per-user usage trend.
func (s *UsageService) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) { func (s *UsageService) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) {
trend, err := s.usageRepo.GetUserUsageTrendByUserID(ctx, userID, startTime, endTime, granularity) trend, err := s.usageRepo.GetUserUsageTrendByUserID(ctx, userID, startTime, endTime, granularity)

View File

@@ -21,6 +21,10 @@ type User struct {
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
// GroupRates 用户专属分组倍率配置
// map[groupID]rateMultiplier
GroupRates map[int64]float64
// TOTP 双因素认证字段 // TOTP 双因素认证字段
TotpSecretEncrypted *string // AES-256-GCM 加密的 TOTP 密钥 TotpSecretEncrypted *string // AES-256-GCM 加密的 TOTP 密钥
TotpEnabled bool // 是否启用 TOTP TotpEnabled bool // 是否启用 TOTP
@@ -40,18 +44,20 @@ func (u *User) IsActive() bool {
// CanBindGroup checks whether a user can bind to a given group. // CanBindGroup checks whether a user can bind to a given group.
// For standard groups: // For standard groups:
// - If AllowedGroups is non-empty, only allow binding to IDs in that list. // - Public groups (non-exclusive): all users can bind
// - If AllowedGroups is empty (nil or length 0), allow binding to any non-exclusive group. // - Exclusive groups: only users with the group in AllowedGroups can bind
func (u *User) CanBindGroup(groupID int64, isExclusive bool) bool { func (u *User) CanBindGroup(groupID int64, isExclusive bool) bool {
if len(u.AllowedGroups) > 0 { // 公开分组(非专属):所有用户都可以绑定
for _, id := range u.AllowedGroups { if !isExclusive {
if id == groupID { return true
return true
}
}
return false
} }
return !isExclusive // 专属分组:需要在 AllowedGroups 中
for _, id := range u.AllowedGroups {
if id == groupID {
return true
}
}
return false
} }
func (u *User) SetPassword(password string) error { func (u *User) SetPassword(password string) error {

View File

@@ -0,0 +1,25 @@
package service
import "context"
// UserGroupRateRepository 用户专属分组倍率仓储接口
// 允许管理员为特定用户设置分组的专属计费倍率,覆盖分组默认倍率
type UserGroupRateRepository interface {
// GetByUserID 获取用户的所有专属分组倍率
// 返回 map[groupID]rateMultiplier
GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error)
// GetByUserAndGroup 获取用户在特定分组的专属倍率
// 如果未设置专属倍率,返回 nil
GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error)
// SyncUserGroupRates 同步用户的分组专属倍率
// rates: map[groupID]*rateMultipliernil 表示删除该分组的专属倍率
SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error
// DeleteByGroupID 删除指定分组的所有用户专属倍率(分组删除时调用)
DeleteByGroupID(ctx context.Context, groupID int64) error
// DeleteByUserID 删除指定用户的所有专属倍率(用户删除时调用)
DeleteByUserID(ctx context.Context, userID int64) error
}

View File

@@ -294,4 +294,5 @@ var ProviderSet = wire.NewSet(
NewUserAttributeService, NewUserAttributeService,
NewUsageCache, NewUsageCache,
NewTotpService, NewTotpService,
NewErrorPassthroughService,
) )

View File

@@ -0,0 +1,19 @@
-- 用户专属分组倍率表
-- 允许管理员为特定用户设置分组的专属计费倍率,覆盖分组默认倍率
CREATE TABLE IF NOT EXISTS user_group_rate_multipliers (
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
group_id BIGINT NOT NULL REFERENCES groups(id) ON DELETE CASCADE,
rate_multiplier DECIMAL(10,4) NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
PRIMARY KEY (user_id, group_id)
);
-- 按 group_id 查询索引(删除分组时清理关联记录)
CREATE INDEX IF NOT EXISTS idx_user_group_rate_multipliers_group_id
ON user_group_rate_multipliers(group_id);
COMMENT ON TABLE user_group_rate_multipliers IS '用户专属分组倍率配置';
COMMENT ON COLUMN user_group_rate_multipliers.user_id IS '用户ID';
COMMENT ON COLUMN user_group_rate_multipliers.group_id IS '分组ID';
COMMENT ON COLUMN user_group_rate_multipliers.rate_multiplier IS '专属计费倍率(覆盖分组默认倍率)';

View File

@@ -0,0 +1,24 @@
-- Error Passthrough Rules table
-- Allows administrators to configure how upstream errors are passed through to clients
CREATE TABLE IF NOT EXISTS error_passthrough_rules (
id BIGSERIAL PRIMARY KEY,
name VARCHAR(100) NOT NULL,
enabled BOOLEAN NOT NULL DEFAULT true,
priority INTEGER NOT NULL DEFAULT 0,
error_codes JSONB DEFAULT '[]',
keywords JSONB DEFAULT '[]',
match_mode VARCHAR(10) NOT NULL DEFAULT 'any',
platforms JSONB DEFAULT '[]',
passthrough_code BOOLEAN NOT NULL DEFAULT true,
response_code INTEGER,
passthrough_body BOOLEAN NOT NULL DEFAULT true,
custom_message TEXT,
description TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
-- Indexes for efficient queries
CREATE INDEX IF NOT EXISTS idx_error_passthrough_rules_enabled ON error_passthrough_rules (enabled);
CREATE INDEX IF NOT EXISTS idx_error_passthrough_rules_priority ON error_passthrough_rules (priority);

View File

@@ -20,6 +20,31 @@ SERVER_PORT=8080
# Server mode: release or debug # Server mode: release or debug
SERVER_MODE=release SERVER_MODE=release
# Global max request body size in bytes (default: 100MB)
# 全局最大请求体大小(字节,默认 100MB
# Applies to all requests, especially important for h2c first request memory protection
# 适用于所有请求,对 h2c 第一请求的内存保护尤为重要
SERVER_MAX_REQUEST_BODY_SIZE=104857600
# Enable HTTP/2 Cleartext (h2c) for client connections
# 启用 HTTP/2 Cleartext (h2c) 客户端连接
SERVER_H2C_ENABLED=true
# H2C max concurrent streams (default: 50)
# H2C 最大并发流数量(默认 50
SERVER_H2C_MAX_CONCURRENT_STREAMS=50
# H2C idle timeout in seconds (default: 75)
# H2C 空闲超时时间(秒,默认 75
SERVER_H2C_IDLE_TIMEOUT=75
# H2C max read frame size in bytes (default: 1048576 = 1MB)
# H2C 最大帧大小(字节,默认 1048576 = 1MB
SERVER_H2C_MAX_READ_FRAME_SIZE=1048576
# H2C max upload buffer per connection in bytes (default: 2097152 = 2MB)
# H2C 每个连接的最大上传缓冲区(字节,默认 2097152 = 2MB
SERVER_H2C_MAX_UPLOAD_BUFFER_PER_CONNECTION=2097152
# H2C max upload buffer per stream in bytes (default: 524288 = 512KB)
# H2C 每个流的最大上传缓冲区(字节,默认 524288 = 512KB
SERVER_H2C_MAX_UPLOAD_BUFFER_PER_STREAM=524288
# 运行模式: standard (默认) 或 simple (内部自用) # 运行模式: standard (默认) 或 simple (内部自用)
# standard: 完整 SaaS 功能,包含计费/余额校验simple: 隐藏 SaaS 功能并跳过计费/余额校验 # standard: 完整 SaaS 功能,包含计费/余额校验simple: 隐藏 SaaS 功能并跳过计费/余额校验
RUN_MODE=standard RUN_MODE=standard

View File

@@ -23,6 +23,32 @@ server:
# Trusted proxies for X-Forwarded-For parsing (CIDR/IP). Empty disables trusted proxies. # Trusted proxies for X-Forwarded-For parsing (CIDR/IP). Empty disables trusted proxies.
# 信任的代理地址CIDR/IP 格式),用于解析 X-Forwarded-For 头。留空则禁用代理信任。 # 信任的代理地址CIDR/IP 格式),用于解析 X-Forwarded-For 头。留空则禁用代理信任。
trusted_proxies: [] trusted_proxies: []
# Global max request body size in bytes (default: 100MB)
# 全局最大请求体大小(字节,默认 100MB
# Applies to all requests, especially important for h2c first request memory protection
# 适用于所有请求,对 h2c 第一请求的内存保护尤为重要
max_request_body_size: 104857600
# HTTP/2 Cleartext (h2c) configuration
# HTTP/2 Cleartext (h2c) 配置
h2c:
# Enable HTTP/2 Cleartext for client connections
# 启用 HTTP/2 Cleartext 客户端连接
enabled: true
# Max concurrent streams per connection
# 每个连接的最大并发流数量
max_concurrent_streams: 50
# Idle timeout for connections (seconds)
# 连接空闲超时时间(秒)
idle_timeout: 75
# Max frame size in bytes (default: 1MB)
# 最大帧大小(字节,默认 1MB
max_read_frame_size: 1048576
# Max upload buffer per connection in bytes (default: 2MB)
# 每个连接的最大上传缓冲区(字节,默认 2MB
max_upload_buffer_per_connection: 2097152
# Max upload buffer per stream in bytes (default: 512KB)
# 每个流的最大上传缓冲区(字节,默认 512KB
max_upload_buffer_per_stream: 524288
# ============================================================================= # =============================================================================
# Run Mode Configuration # Run Mode Configuration

View File

@@ -0,0 +1,134 @@
/**
* Admin Error Passthrough Rules API endpoints
* Handles error passthrough rule management for administrators
*/
import { apiClient } from '../client'
/**
* Error passthrough rule interface
*/
export interface ErrorPassthroughRule {
id: number
name: string
enabled: boolean
priority: number
error_codes: number[]
keywords: string[]
match_mode: 'any' | 'all'
platforms: string[]
passthrough_code: boolean
response_code: number | null
passthrough_body: boolean
custom_message: string | null
description: string | null
created_at: string
updated_at: string
}
/**
* Create rule request
*/
export interface CreateRuleRequest {
name: string
enabled?: boolean
priority?: number
error_codes?: number[]
keywords?: string[]
match_mode?: 'any' | 'all'
platforms?: string[]
passthrough_code?: boolean
response_code?: number | null
passthrough_body?: boolean
custom_message?: string | null
description?: string | null
}
/**
* Update rule request
*/
export interface UpdateRuleRequest {
name?: string
enabled?: boolean
priority?: number
error_codes?: number[]
keywords?: string[]
match_mode?: 'any' | 'all'
platforms?: string[]
passthrough_code?: boolean
response_code?: number | null
passthrough_body?: boolean
custom_message?: string | null
description?: string | null
}
/**
* List all error passthrough rules
* @returns List of all rules sorted by priority
*/
export async function list(): Promise<ErrorPassthroughRule[]> {
const { data } = await apiClient.get<ErrorPassthroughRule[]>('/admin/error-passthrough-rules')
return data
}
/**
* Get rule by ID
* @param id - Rule ID
* @returns Rule details
*/
export async function getById(id: number): Promise<ErrorPassthroughRule> {
const { data } = await apiClient.get<ErrorPassthroughRule>(`/admin/error-passthrough-rules/${id}`)
return data
}
/**
* Create new rule
* @param ruleData - Rule data
* @returns Created rule
*/
export async function create(ruleData: CreateRuleRequest): Promise<ErrorPassthroughRule> {
const { data } = await apiClient.post<ErrorPassthroughRule>('/admin/error-passthrough-rules', ruleData)
return data
}
/**
* Update rule
* @param id - Rule ID
* @param updates - Fields to update
* @returns Updated rule
*/
export async function update(id: number, updates: UpdateRuleRequest): Promise<ErrorPassthroughRule> {
const { data } = await apiClient.put<ErrorPassthroughRule>(`/admin/error-passthrough-rules/${id}`, updates)
return data
}
/**
* Delete rule
* @param id - Rule ID
* @returns Success confirmation
*/
export async function deleteRule(id: number): Promise<{ message: string }> {
const { data } = await apiClient.delete<{ message: string }>(`/admin/error-passthrough-rules/${id}`)
return data
}
/**
* Toggle rule enabled status
* @param id - Rule ID
* @param enabled - New enabled status
* @returns Updated rule
*/
export async function toggleEnabled(id: number, enabled: boolean): Promise<ErrorPassthroughRule> {
return update(id, { enabled })
}
export const errorPassthroughAPI = {
list,
getById,
create,
update,
delete: deleteRule,
toggleEnabled
}
export default errorPassthroughAPI

View File

@@ -19,6 +19,7 @@ import geminiAPI from './gemini'
import antigravityAPI from './antigravity' import antigravityAPI from './antigravity'
import userAttributesAPI from './userAttributes' import userAttributesAPI from './userAttributes'
import opsAPI from './ops' import opsAPI from './ops'
import errorPassthroughAPI from './errorPassthrough'
/** /**
* Unified admin API object for convenient access * Unified admin API object for convenient access
@@ -39,7 +40,8 @@ export const adminAPI = {
gemini: geminiAPI, gemini: geminiAPI,
antigravity: antigravityAPI, antigravity: antigravityAPI,
userAttributes: userAttributesAPI, userAttributes: userAttributesAPI,
ops: opsAPI ops: opsAPI,
errorPassthrough: errorPassthroughAPI
} }
export { export {
@@ -58,10 +60,12 @@ export {
geminiAPI, geminiAPI,
antigravityAPI, antigravityAPI,
userAttributesAPI, userAttributesAPI,
opsAPI opsAPI,
errorPassthroughAPI
} }
export default adminAPI export default adminAPI
// Re-export types used by components // Re-export types used by components
export type { BalanceHistoryItem } from './users' export type { BalanceHistoryItem } from './users'
export type { ErrorPassthroughRule, CreateRuleRequest, UpdateRuleRequest } from './errorPassthrough'

View File

@@ -35,6 +35,22 @@ export function setAuthToken(token: string): void {
localStorage.setItem('auth_token', token) localStorage.setItem('auth_token', token)
} }
/**
* Store refresh token in localStorage
*/
export function setRefreshToken(token: string): void {
localStorage.setItem('refresh_token', token)
}
/**
* Store token expiration timestamp in localStorage
* Converts expires_in (seconds) to absolute timestamp (milliseconds)
*/
export function setTokenExpiresAt(expiresIn: number): void {
const expiresAt = Date.now() + expiresIn * 1000
localStorage.setItem('token_expires_at', String(expiresAt))
}
/** /**
* Get authentication token from localStorage * Get authentication token from localStorage
*/ */
@@ -42,12 +58,29 @@ export function getAuthToken(): string | null {
return localStorage.getItem('auth_token') return localStorage.getItem('auth_token')
} }
/**
* Get refresh token from localStorage
*/
export function getRefreshToken(): string | null {
return localStorage.getItem('refresh_token')
}
/**
* Get token expiration timestamp from localStorage
*/
export function getTokenExpiresAt(): number | null {
const value = localStorage.getItem('token_expires_at')
return value ? parseInt(value, 10) : null
}
/** /**
* Clear authentication token from localStorage * Clear authentication token from localStorage
*/ */
export function clearAuthToken(): void { export function clearAuthToken(): void {
localStorage.removeItem('auth_token') localStorage.removeItem('auth_token')
localStorage.removeItem('refresh_token')
localStorage.removeItem('auth_user') localStorage.removeItem('auth_user')
localStorage.removeItem('token_expires_at')
} }
/** /**
@@ -61,6 +94,12 @@ export async function login(credentials: LoginRequest): Promise<LoginResponse> {
// Only store token if 2FA is not required // Only store token if 2FA is not required
if (!isTotp2FARequired(data)) { if (!isTotp2FARequired(data)) {
setAuthToken(data.access_token) setAuthToken(data.access_token)
if (data.refresh_token) {
setRefreshToken(data.refresh_token)
}
if (data.expires_in) {
setTokenExpiresAt(data.expires_in)
}
localStorage.setItem('auth_user', JSON.stringify(data.user)) localStorage.setItem('auth_user', JSON.stringify(data.user))
} }
@@ -77,6 +116,12 @@ export async function login2FA(request: TotpLogin2FARequest): Promise<AuthRespon
// Store token and user data // Store token and user data
setAuthToken(data.access_token) setAuthToken(data.access_token)
if (data.refresh_token) {
setRefreshToken(data.refresh_token)
}
if (data.expires_in) {
setTokenExpiresAt(data.expires_in)
}
localStorage.setItem('auth_user', JSON.stringify(data.user)) localStorage.setItem('auth_user', JSON.stringify(data.user))
return data return data
@@ -92,6 +137,12 @@ export async function register(userData: RegisterRequest): Promise<AuthResponse>
// Store token and user data // Store token and user data
setAuthToken(data.access_token) setAuthToken(data.access_token)
if (data.refresh_token) {
setRefreshToken(data.refresh_token)
}
if (data.expires_in) {
setTokenExpiresAt(data.expires_in)
}
localStorage.setItem('auth_user', JSON.stringify(data.user)) localStorage.setItem('auth_user', JSON.stringify(data.user))
return data return data
@@ -108,11 +159,62 @@ export async function getCurrentUser() {
/** /**
* User logout * User logout
* Clears authentication token and user data from localStorage * Clears authentication token and user data from localStorage
* Optionally revokes the refresh token on the server
*/ */
export function logout(): void { export async function logout(): Promise<void> {
const refreshToken = getRefreshToken()
// Try to revoke the refresh token on the server
if (refreshToken) {
try {
await apiClient.post('/auth/logout', { refresh_token: refreshToken })
} catch {
// Ignore errors - we still want to clear local state
}
}
clearAuthToken() clearAuthToken()
// Optionally redirect to login page }
// window.location.href = '/login';
/**
* Refresh token response
*/
export interface RefreshTokenResponse {
access_token: string
refresh_token: string
expires_in: number
token_type: string
}
/**
* Refresh the access token using the refresh token
* @returns New token pair
*/
export async function refreshToken(): Promise<RefreshTokenResponse> {
const currentRefreshToken = getRefreshToken()
if (!currentRefreshToken) {
throw new Error('No refresh token available')
}
const { data } = await apiClient.post<RefreshTokenResponse>('/auth/refresh', {
refresh_token: currentRefreshToken
})
// Update tokens in localStorage
setAuthToken(data.access_token)
setRefreshToken(data.refresh_token)
setTokenExpiresAt(data.expires_in)
return data
}
/**
* Revoke all sessions for the current user
* @returns Response with message
*/
export async function revokeAllSessions(): Promise<{ message: string }> {
const { data } = await apiClient.post<{ message: string }>('/auth/revoke-all-sessions')
return data
} }
/** /**
@@ -242,14 +344,20 @@ export const authAPI = {
logout, logout,
isAuthenticated, isAuthenticated,
setAuthToken, setAuthToken,
setRefreshToken,
setTokenExpiresAt,
getAuthToken, getAuthToken,
getRefreshToken,
getTokenExpiresAt,
clearAuthToken, clearAuthToken,
getPublicSettings, getPublicSettings,
sendVerifyCode, sendVerifyCode,
validatePromoCode, validatePromoCode,
validateInvitationCode, validateInvitationCode,
forgotPassword, forgotPassword,
resetPassword resetPassword,
refreshToken,
revokeAllSessions
} }
export default authAPI export default authAPI

View File

@@ -1,9 +1,9 @@
/** /**
* Axios HTTP Client Configuration * Axios HTTP Client Configuration
* Base client with interceptors for authentication and error handling * Base client with interceptors for authentication, token refresh, and error handling
*/ */
import axios, { AxiosInstance, AxiosError, InternalAxiosRequestConfig } from 'axios' import axios, { AxiosInstance, AxiosError, InternalAxiosRequestConfig, AxiosResponse } from 'axios'
import type { ApiResponse } from '@/types' import type { ApiResponse } from '@/types'
import { getLocale } from '@/i18n' import { getLocale } from '@/i18n'
@@ -19,6 +19,28 @@ export const apiClient: AxiosInstance = axios.create({
} }
}) })
// ==================== Token Refresh State ====================
// Track if a token refresh is in progress to prevent multiple simultaneous refresh requests
let isRefreshing = false
// Queue of requests waiting for token refresh
let refreshSubscribers: Array<(token: string) => void> = []
/**
* Subscribe to token refresh completion
*/
function subscribeTokenRefresh(callback: (token: string) => void): void {
refreshSubscribers.push(callback)
}
/**
* Notify all subscribers that token has been refreshed
*/
function onTokenRefreshed(token: string): void {
refreshSubscribers.forEach((callback) => callback(token))
refreshSubscribers = []
}
// ==================== Request Interceptor ==================== // ==================== Request Interceptor ====================
// Get user's timezone // Get user's timezone
@@ -61,7 +83,7 @@ apiClient.interceptors.request.use(
// ==================== Response Interceptor ==================== // ==================== Response Interceptor ====================
apiClient.interceptors.response.use( apiClient.interceptors.response.use(
(response) => { (response: AxiosResponse) => {
// Unwrap standard API response format { code, message, data } // Unwrap standard API response format { code, message, data }
const apiResponse = response.data as ApiResponse<unknown> const apiResponse = response.data as ApiResponse<unknown>
if (apiResponse && typeof apiResponse === 'object' && 'code' in apiResponse) { if (apiResponse && typeof apiResponse === 'object' && 'code' in apiResponse) {
@@ -79,13 +101,15 @@ apiClient.interceptors.response.use(
} }
return response return response
}, },
(error: AxiosError<ApiResponse<unknown>>) => { async (error: AxiosError<ApiResponse<unknown>>) => {
// Request cancellation: keep the original axios cancellation error so callers can ignore it. // Request cancellation: keep the original axios cancellation error so callers can ignore it.
// Otherwise we'd misclassify it as a generic "network error". // Otherwise we'd misclassify it as a generic "network error".
if (error.code === 'ERR_CANCELED' || axios.isCancel(error)) { if (error.code === 'ERR_CANCELED' || axios.isCancel(error)) {
return Promise.reject(error) return Promise.reject(error)
} }
const originalRequest = error.config as InternalAxiosRequestConfig & { _retry?: boolean }
// Handle common errors // Handle common errors
if (error.response) { if (error.response) {
const { status, data } = error.response const { status, data } = error.response
@@ -120,23 +144,116 @@ apiClient.interceptors.response.use(
}) })
} }
// 401: Unauthorized - clear token and redirect to login // 401: Try to refresh the token if we have a refresh token
if (status === 401) { // This handles TOKEN_EXPIRED, INVALID_TOKEN, TOKEN_REVOKED, etc.
const hasToken = !!localStorage.getItem('auth_token') if (status === 401 && !originalRequest._retry) {
const url = error.config?.url || '' const refreshToken = localStorage.getItem('refresh_token')
const isAuthEndpoint = const isAuthEndpoint =
url.includes('/auth/login') || url.includes('/auth/register') || url.includes('/auth/refresh') url.includes('/auth/login') || url.includes('/auth/register') || url.includes('/auth/refresh')
// If we have a refresh token and this is not an auth endpoint, try to refresh
if (refreshToken && !isAuthEndpoint) {
if (isRefreshing) {
// Wait for the ongoing refresh to complete
return new Promise((resolve, reject) => {
subscribeTokenRefresh((newToken: string) => {
if (newToken) {
// Mark as retried to prevent infinite loop if retry also returns 401
originalRequest._retry = true
if (originalRequest.headers) {
originalRequest.headers.Authorization = `Bearer ${newToken}`
}
resolve(apiClient(originalRequest))
} else {
// Refresh failed, reject with original error
reject({
status,
code: apiData.code,
message: apiData.message || apiData.detail || error.message
})
}
})
})
}
originalRequest._retry = true
isRefreshing = true
try {
// Call refresh endpoint directly to avoid circular dependency
const refreshResponse = await axios.post(
`${API_BASE_URL}/auth/refresh`,
{ refresh_token: refreshToken },
{ headers: { 'Content-Type': 'application/json' } }
)
const refreshData = refreshResponse.data as ApiResponse<{
access_token: string
refresh_token: string
expires_in: number
}>
if (refreshData.code === 0 && refreshData.data) {
const { access_token, refresh_token: newRefreshToken, expires_in } = refreshData.data
// Update tokens in localStorage (convert expires_in to timestamp)
localStorage.setItem('auth_token', access_token)
localStorage.setItem('refresh_token', newRefreshToken)
localStorage.setItem('token_expires_at', String(Date.now() + expires_in * 1000))
// Notify subscribers with new token
onTokenRefreshed(access_token)
// Retry the original request with new token
if (originalRequest.headers) {
originalRequest.headers.Authorization = `Bearer ${access_token}`
}
isRefreshing = false
return apiClient(originalRequest)
}
// Refresh response was not successful, fall through to clear auth
throw new Error('Token refresh failed')
} catch (refreshError) {
// Refresh failed - notify subscribers with empty token
onTokenRefreshed('')
isRefreshing = false
// Clear tokens and redirect to login
localStorage.removeItem('auth_token')
localStorage.removeItem('refresh_token')
localStorage.removeItem('auth_user')
localStorage.removeItem('token_expires_at')
sessionStorage.setItem('auth_expired', '1')
if (!window.location.pathname.includes('/login')) {
window.location.href = '/login'
}
return Promise.reject({
status: 401,
code: 'TOKEN_REFRESH_FAILED',
message: 'Session expired. Please log in again.'
})
}
}
// No refresh token or is auth endpoint - clear auth and redirect
const hasToken = !!localStorage.getItem('auth_token')
const headers = error.config?.headers as Record<string, unknown> | undefined const headers = error.config?.headers as Record<string, unknown> | undefined
const authHeader = headers?.Authorization ?? headers?.authorization const authHeader = headers?.Authorization ?? headers?.authorization
const sentAuth = const sentAuth =
typeof authHeader === 'string' typeof authHeader === 'string'
? authHeader.trim() !== '' ? authHeader.trim() !== ''
: Array.isArray(authHeader) : Array.isArray(authHeader)
? authHeader.length > 0 ? authHeader.length > 0
: !!authHeader : !!authHeader
localStorage.removeItem('auth_token') localStorage.removeItem('auth_token')
localStorage.removeItem('refresh_token')
localStorage.removeItem('auth_user') localStorage.removeItem('auth_user')
localStorage.removeItem('token_expires_at')
if ((hasToken || sentAuth) && !isAuthEndpoint) { if ((hasToken || sentAuth) && !isAuthEndpoint) {
sessionStorage.setItem('auth_expired', '1') sessionStorage.setItem('auth_expired', '1')
} }

View File

@@ -18,8 +18,18 @@ export async function getAvailable(): Promise<Group[]> {
return data return data
} }
/**
* Get current user's custom group rate multipliers
* @returns Map of group_id to custom rate_multiplier
*/
export async function getUserGroupRates(): Promise<Record<number, number>> {
const { data } = await apiClient.get<Record<number, number> | null>('/groups/rates')
return data || {}
}
export const userGroupsAPI = { export const userGroupsAPI = {
getAvailable getAvailable,
getUserGroupRates
} }
export default userGroupsAPI export default userGroupsAPI

View File

@@ -0,0 +1,623 @@
<template>
<BaseDialog
:show="show"
:title="t('admin.errorPassthrough.title')"
width="extra-wide"
@close="$emit('close')"
>
<div class="space-y-4">
<!-- Header -->
<div class="flex items-center justify-between">
<p class="text-sm text-gray-500 dark:text-gray-400">
{{ t('admin.errorPassthrough.description') }}
</p>
<button @click="showCreateModal = true" class="btn btn-primary btn-sm">
<Icon name="plus" size="sm" class="mr-1" />
{{ t('admin.errorPassthrough.createRule') }}
</button>
</div>
<!-- Rules Table -->
<div v-if="loading" class="flex items-center justify-center py-8">
<Icon name="refresh" size="lg" class="animate-spin text-gray-400" />
</div>
<div v-else-if="rules.length === 0" class="py-8 text-center">
<div class="mx-auto mb-4 flex h-12 w-12 items-center justify-center rounded-full bg-gray-100 dark:bg-dark-700">
<Icon name="shield" size="lg" class="text-gray-400" />
</div>
<h4 class="mb-1 text-sm font-medium text-gray-900 dark:text-white">
{{ t('admin.errorPassthrough.noRules') }}
</h4>
<p class="text-sm text-gray-500 dark:text-gray-400">
{{ t('admin.errorPassthrough.createFirstRule') }}
</p>
</div>
<div v-else class="max-h-96 overflow-auto rounded-lg border border-gray-200 dark:border-dark-600">
<table class="min-w-full divide-y divide-gray-200 dark:divide-dark-700">
<thead class="sticky top-0 bg-gray-50 dark:bg-dark-700">
<tr>
<th class="px-3 py-2 text-left text-xs font-medium uppercase text-gray-500 dark:text-gray-400">
{{ t('admin.errorPassthrough.columns.priority') }}
</th>
<th class="px-3 py-2 text-left text-xs font-medium uppercase text-gray-500 dark:text-gray-400">
{{ t('admin.errorPassthrough.columns.name') }}
</th>
<th class="px-3 py-2 text-left text-xs font-medium uppercase text-gray-500 dark:text-gray-400">
{{ t('admin.errorPassthrough.columns.conditions') }}
</th>
<th class="px-3 py-2 text-left text-xs font-medium uppercase text-gray-500 dark:text-gray-400">
{{ t('admin.errorPassthrough.columns.platforms') }}
</th>
<th class="px-3 py-2 text-left text-xs font-medium uppercase text-gray-500 dark:text-gray-400">
{{ t('admin.errorPassthrough.columns.behavior') }}
</th>
<th class="px-3 py-2 text-left text-xs font-medium uppercase text-gray-500 dark:text-gray-400">
{{ t('admin.errorPassthrough.columns.status') }}
</th>
<th class="px-3 py-2 text-left text-xs font-medium uppercase text-gray-500 dark:text-gray-400">
{{ t('admin.errorPassthrough.columns.actions') }}
</th>
</tr>
</thead>
<tbody class="divide-y divide-gray-200 bg-white dark:divide-dark-700 dark:bg-dark-800">
<tr v-for="rule in rules" :key="rule.id" class="hover:bg-gray-50 dark:hover:bg-dark-700">
<td class="whitespace-nowrap px-3 py-2">
<span class="inline-flex h-5 w-5 items-center justify-center rounded bg-gray-100 text-xs font-medium text-gray-700 dark:bg-dark-600 dark:text-gray-300">
{{ rule.priority }}
</span>
</td>
<td class="px-3 py-2">
<div class="font-medium text-gray-900 dark:text-white text-sm">{{ rule.name }}</div>
<div v-if="rule.description" class="mt-0.5 text-xs text-gray-500 dark:text-gray-400 max-w-xs truncate">
{{ rule.description }}
</div>
</td>
<td class="px-3 py-2">
<div class="flex flex-wrap gap-1 max-w-48">
<span
v-for="code in rule.error_codes.slice(0, 3)"
:key="code"
class="badge badge-danger text-xs"
>
{{ code }}
</span>
<span
v-if="rule.error_codes.length > 3"
class="text-xs text-gray-500"
>
+{{ rule.error_codes.length - 3 }}
</span>
<span
v-for="keyword in rule.keywords.slice(0, 1)"
:key="keyword"
class="badge badge-gray text-xs"
>
"{{ keyword.length > 10 ? keyword.substring(0, 10) + '...' : keyword }}"
</span>
<span
v-if="rule.keywords.length > 1"
class="text-xs text-gray-500"
>
+{{ rule.keywords.length - 1 }}
</span>
</div>
<div class="mt-0.5 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.errorPassthrough.matchMode.' + rule.match_mode) }}
</div>
</td>
<td class="px-3 py-2">
<div v-if="rule.platforms.length === 0" class="text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.errorPassthrough.allPlatforms') }}
</div>
<div v-else class="flex flex-wrap gap-1">
<span
v-for="platform in rule.platforms.slice(0, 2)"
:key="platform"
class="badge badge-primary text-xs"
>
{{ platform }}
</span>
<span v-if="rule.platforms.length > 2" class="text-xs text-gray-500">
+{{ rule.platforms.length - 2 }}
</span>
</div>
</td>
<td class="px-3 py-2">
<div class="text-xs space-y-0.5">
<div class="flex items-center gap-1">
<Icon
:name="rule.passthrough_code ? 'checkCircle' : 'xCircle'"
size="xs"
:class="rule.passthrough_code ? 'text-green-500' : 'text-gray-400'"
/>
<span class="text-gray-600 dark:text-gray-400">
{{ t('admin.errorPassthrough.code') }}:
{{ rule.passthrough_code ? t('admin.errorPassthrough.passthrough') : (rule.response_code || '-') }}
</span>
</div>
<div class="flex items-center gap-1">
<Icon
:name="rule.passthrough_body ? 'checkCircle' : 'xCircle'"
size="xs"
:class="rule.passthrough_body ? 'text-green-500' : 'text-gray-400'"
/>
<span class="text-gray-600 dark:text-gray-400">
{{ t('admin.errorPassthrough.body') }}:
{{ rule.passthrough_body ? t('admin.errorPassthrough.passthrough') : t('admin.errorPassthrough.custom') }}
</span>
</div>
</div>
</td>
<td class="px-3 py-2">
<button
@click="toggleEnabled(rule)"
:class="[
'relative inline-flex h-4 w-7 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
rule.enabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
]"
>
<span
:class="[
'pointer-events-none inline-block h-3 w-3 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
rule.enabled ? 'translate-x-3' : 'translate-x-0'
]"
/>
</button>
</td>
<td class="px-3 py-2">
<div class="flex items-center gap-1">
<button
@click="handleEdit(rule)"
class="p-1 text-gray-500 hover:text-primary-600 dark:hover:text-primary-400"
:title="t('common.edit')"
>
<Icon name="edit" size="sm" />
</button>
<button
@click="handleDelete(rule)"
class="p-1 text-gray-500 hover:text-red-600 dark:hover:text-red-400"
:title="t('common.delete')"
>
<Icon name="trash" size="sm" />
</button>
</div>
</td>
</tr>
</tbody>
</table>
</div>
</div>
<template #footer>
<div class="flex justify-end">
<button @click="$emit('close')" class="btn btn-secondary">
{{ t('common.close') }}
</button>
</div>
</template>
<!-- Create/Edit Modal -->
<BaseDialog
:show="showCreateModal || showEditModal"
:title="showEditModal ? t('admin.errorPassthrough.editRule') : t('admin.errorPassthrough.createRule')"
width="wide"
@close="closeFormModal"
>
<form @submit.prevent="handleSubmit" class="space-y-4">
<!-- Basic Info -->
<div class="grid grid-cols-2 gap-4">
<div>
<label class="input-label">{{ t('admin.errorPassthrough.form.name') }}</label>
<input
v-model="form.name"
type="text"
required
class="input"
:placeholder="t('admin.errorPassthrough.form.namePlaceholder')"
/>
</div>
<div>
<label class="input-label">{{ t('admin.errorPassthrough.form.priority') }}</label>
<input
v-model.number="form.priority"
type="number"
min="0"
class="input"
/>
<p class="input-hint">{{ t('admin.errorPassthrough.form.priorityHint') }}</p>
</div>
</div>
<div>
<label class="input-label">{{ t('admin.errorPassthrough.form.description') }}</label>
<input
v-model="form.description"
type="text"
class="input"
:placeholder="t('admin.errorPassthrough.form.descriptionPlaceholder')"
/>
</div>
<!-- Match Conditions -->
<div class="rounded-lg border border-gray-200 p-3 dark:border-dark-600">
<h4 class="mb-2 text-sm font-medium text-gray-900 dark:text-white">
{{ t('admin.errorPassthrough.form.matchConditions') }}
</h4>
<div class="grid grid-cols-2 gap-3">
<div>
<label class="input-label text-xs">{{ t('admin.errorPassthrough.form.errorCodes') }}</label>
<input
v-model="errorCodesInput"
type="text"
class="input text-sm"
:placeholder="t('admin.errorPassthrough.form.errorCodesPlaceholder')"
/>
<p class="input-hint text-xs">{{ t('admin.errorPassthrough.form.errorCodesHint') }}</p>
</div>
<div>
<label class="input-label text-xs">{{ t('admin.errorPassthrough.form.keywords') }}</label>
<textarea
v-model="keywordsInput"
rows="2"
class="input font-mono text-xs"
:placeholder="t('admin.errorPassthrough.form.keywordsPlaceholder')"
/>
<p class="input-hint text-xs">{{ t('admin.errorPassthrough.form.keywordsHint') }}</p>
</div>
</div>
<div class="mt-3">
<label class="input-label text-xs">{{ t('admin.errorPassthrough.form.matchMode') }}</label>
<div class="mt-1 space-y-2">
<label
v-for="option in matchModeOptions"
:key="option.value"
class="flex items-start gap-2 cursor-pointer"
>
<input
type="radio"
:value="option.value"
v-model="form.match_mode"
class="mt-0.5 h-3.5 w-3.5 border-gray-300 text-primary-600 focus:ring-primary-500"
/>
<div class="flex-1">
<span class="text-xs font-medium text-gray-700 dark:text-gray-300">{{ option.label }}</span>
<p class="text-xs text-gray-500 dark:text-gray-400">{{ option.description }}</p>
</div>
</label>
</div>
</div>
<div class="mt-3">
<label class="input-label text-xs">{{ t('admin.errorPassthrough.form.platforms') }}</label>
<div class="flex flex-wrap gap-3">
<label
v-for="platform in platformOptions"
:key="platform.value"
class="inline-flex items-center gap-1.5"
>
<input
type="checkbox"
:value="platform.value"
v-model="form.platforms"
class="h-3.5 w-3.5 rounded border-gray-300 text-primary-600 focus:ring-primary-500"
/>
<span class="text-xs text-gray-700 dark:text-gray-300">{{ platform.label }}</span>
</label>
</div>
<p class="input-hint text-xs mt-1">{{ t('admin.errorPassthrough.form.platformsHint') }}</p>
</div>
</div>
<!-- Response Behavior -->
<div class="rounded-lg border border-gray-200 p-3 dark:border-dark-600">
<h4 class="mb-2 text-sm font-medium text-gray-900 dark:text-white">
{{ t('admin.errorPassthrough.form.responseBehavior') }}
</h4>
<div class="grid grid-cols-2 gap-3">
<div>
<label class="flex items-center gap-1.5">
<input
type="checkbox"
v-model="form.passthrough_code"
class="h-3.5 w-3.5 rounded border-gray-300 text-primary-600 focus:ring-primary-500"
/>
<span class="text-xs font-medium text-gray-700 dark:text-gray-300">
{{ t('admin.errorPassthrough.form.passthroughCode') }}
</span>
</label>
<div v-if="!form.passthrough_code" class="mt-2">
<label class="input-label text-xs">{{ t('admin.errorPassthrough.form.responseCode') }}</label>
<input
v-model.number="form.response_code"
type="number"
min="100"
max="599"
class="input text-sm"
placeholder="422"
/>
</div>
</div>
<div>
<label class="flex items-center gap-1.5">
<input
type="checkbox"
v-model="form.passthrough_body"
class="h-3.5 w-3.5 rounded border-gray-300 text-primary-600 focus:ring-primary-500"
/>
<span class="text-xs font-medium text-gray-700 dark:text-gray-300">
{{ t('admin.errorPassthrough.form.passthroughBody') }}
</span>
</label>
<div v-if="!form.passthrough_body" class="mt-2">
<label class="input-label text-xs">{{ t('admin.errorPassthrough.form.customMessage') }}</label>
<input
v-model="form.custom_message"
type="text"
class="input text-sm"
:placeholder="t('admin.errorPassthrough.form.customMessagePlaceholder')"
/>
</div>
</div>
</div>
</div>
<!-- Enabled -->
<div class="flex items-center gap-1.5">
<input
type="checkbox"
v-model="form.enabled"
class="h-3.5 w-3.5 rounded border-gray-300 text-primary-600 focus:ring-primary-500"
/>
<span class="text-xs font-medium text-gray-700 dark:text-gray-300">
{{ t('admin.errorPassthrough.form.enabled') }}
</span>
</div>
</form>
<template #footer>
<div class="flex justify-end gap-3">
<button @click="closeFormModal" type="button" class="btn btn-secondary">
{{ t('common.cancel') }}
</button>
<button @click="handleSubmit" :disabled="submitting" class="btn btn-primary">
<Icon v-if="submitting" name="refresh" size="sm" class="mr-1 animate-spin" />
{{ showEditModal ? t('common.update') : t('common.create') }}
</button>
</div>
</template>
</BaseDialog>
<!-- Delete Confirmation -->
<ConfirmDialog
:show="showDeleteDialog"
:title="t('admin.errorPassthrough.deleteRule')"
:message="t('admin.errorPassthrough.deleteConfirm', { name: deletingRule?.name })"
:confirm-text="t('common.delete')"
:cancel-text="t('common.cancel')"
:danger="true"
@confirm="confirmDelete"
@cancel="showDeleteDialog = false"
/>
</BaseDialog>
</template>
<script setup lang="ts">
import { ref, reactive, computed, watch } from 'vue'
import { useI18n } from 'vue-i18n'
import { useAppStore } from '@/stores/app'
import { adminAPI } from '@/api/admin'
import type { ErrorPassthroughRule } from '@/api/admin/errorPassthrough'
import BaseDialog from '@/components/common/BaseDialog.vue'
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
import Icon from '@/components/icons/Icon.vue'
const props = defineProps<{
show: boolean
}>()
const emit = defineEmits<{
close: []
}>()
// eslint-disable-next-line @typescript-eslint/no-unused-vars
void emit // suppress unused warning - emit is used via $emit in template
const { t } = useI18n()
const appStore = useAppStore()
const rules = ref<ErrorPassthroughRule[]>([])
const loading = ref(false)
const submitting = ref(false)
const showCreateModal = ref(false)
const showEditModal = ref(false)
const showDeleteDialog = ref(false)
const editingRule = ref<ErrorPassthroughRule | null>(null)
const deletingRule = ref<ErrorPassthroughRule | null>(null)
// Form inputs for arrays
const errorCodesInput = ref('')
const keywordsInput = ref('')
const form = reactive({
name: '',
enabled: true,
priority: 0,
match_mode: 'any' as 'any' | 'all',
platforms: [] as string[],
passthrough_code: true,
response_code: null as number | null,
passthrough_body: true,
custom_message: null as string | null,
description: null as string | null
})
const matchModeOptions = computed(() => [
{ value: 'any', label: t('admin.errorPassthrough.matchMode.any'), description: t('admin.errorPassthrough.matchMode.anyHint') },
{ value: 'all', label: t('admin.errorPassthrough.matchMode.all'), description: t('admin.errorPassthrough.matchMode.allHint') }
])
const platformOptions = [
{ value: 'anthropic', label: 'Anthropic' },
{ value: 'openai', label: 'OpenAI' },
{ value: 'gemini', label: 'Gemini' },
{ value: 'antigravity', label: 'Antigravity' }
]
// Load rules when dialog opens
watch(() => props.show, (newVal) => {
if (newVal) {
loadRules()
}
})
const loadRules = async () => {
loading.value = true
try {
rules.value = await adminAPI.errorPassthrough.list()
} catch (error) {
appStore.showError(t('admin.errorPassthrough.failedToLoad'))
console.error('Error loading rules:', error)
} finally {
loading.value = false
}
}
const resetForm = () => {
form.name = ''
form.enabled = true
form.priority = 0
form.match_mode = 'any'
form.platforms = []
form.passthrough_code = true
form.response_code = null
form.passthrough_body = true
form.custom_message = null
form.description = null
errorCodesInput.value = ''
keywordsInput.value = ''
}
const closeFormModal = () => {
showCreateModal.value = false
showEditModal.value = false
editingRule.value = null
resetForm()
}
const handleEdit = (rule: ErrorPassthroughRule) => {
editingRule.value = rule
form.name = rule.name
form.enabled = rule.enabled
form.priority = rule.priority
form.match_mode = rule.match_mode
form.platforms = [...rule.platforms]
form.passthrough_code = rule.passthrough_code
form.response_code = rule.response_code
form.passthrough_body = rule.passthrough_body
form.custom_message = rule.custom_message
form.description = rule.description
errorCodesInput.value = rule.error_codes.join(', ')
keywordsInput.value = rule.keywords.join('\n')
showEditModal.value = true
}
const handleDelete = (rule: ErrorPassthroughRule) => {
deletingRule.value = rule
showDeleteDialog.value = true
}
const parseErrorCodes = (): number[] => {
if (!errorCodesInput.value.trim()) return []
return errorCodesInput.value
.split(/[,\s]+/)
.map(s => parseInt(s.trim(), 10))
.filter(n => !isNaN(n) && n > 0)
}
const parseKeywords = (): string[] => {
if (!keywordsInput.value.trim()) return []
return keywordsInput.value
.split('\n')
.map(s => s.trim())
.filter(s => s.length > 0)
}
const handleSubmit = async () => {
if (!form.name.trim()) {
appStore.showError(t('admin.errorPassthrough.nameRequired'))
return
}
const errorCodes = parseErrorCodes()
const keywords = parseKeywords()
if (errorCodes.length === 0 && keywords.length === 0) {
appStore.showError(t('admin.errorPassthrough.conditionsRequired'))
return
}
submitting.value = true
try {
const data = {
name: form.name.trim(),
enabled: form.enabled,
priority: form.priority,
error_codes: errorCodes,
keywords: keywords,
match_mode: form.match_mode,
platforms: form.platforms,
passthrough_code: form.passthrough_code,
response_code: form.passthrough_code ? null : form.response_code,
passthrough_body: form.passthrough_body,
custom_message: form.passthrough_body ? null : form.custom_message,
description: form.description?.trim() || null
}
if (showEditModal.value && editingRule.value) {
await adminAPI.errorPassthrough.update(editingRule.value.id, data)
appStore.showSuccess(t('admin.errorPassthrough.ruleUpdated'))
} else {
await adminAPI.errorPassthrough.create(data)
appStore.showSuccess(t('admin.errorPassthrough.ruleCreated'))
}
closeFormModal()
loadRules()
} catch (error: any) {
appStore.showError(error.response?.data?.detail || t('admin.errorPassthrough.failedToSave'))
console.error('Error saving rule:', error)
} finally {
submitting.value = false
}
}
const toggleEnabled = async (rule: ErrorPassthroughRule) => {
try {
await adminAPI.errorPassthrough.toggleEnabled(rule.id, !rule.enabled)
rule.enabled = !rule.enabled
} catch (error: any) {
appStore.showError(error.response?.data?.detail || t('admin.errorPassthrough.failedToToggle'))
console.error('Error toggling rule:', error)
}
}
const confirmDelete = async () => {
if (!deletingRule.value) return
try {
await adminAPI.errorPassthrough.delete(deletingRule.value.id)
appStore.showSuccess(t('admin.errorPassthrough.ruleDeleted'))
showDeleteDialog.value = false
deletingRule.value = null
loadRules()
} catch (error: any) {
appStore.showError(error.response?.data?.detail || t('admin.errorPassthrough.failedToDelete'))
console.error('Error deleting rule:', error)
}
}
</script>

View File

@@ -1,59 +1,328 @@
<template> <template>
<BaseDialog :show="show" :title="t('admin.users.setAllowedGroups')" width="normal" @close="$emit('close')"> <BaseDialog :show="show" :title="t('admin.users.groupConfig')" width="wide" @close="$emit('close')">
<div v-if="user" class="space-y-4"> <div v-if="user" class="space-y-6">
<div class="flex items-center gap-3 rounded-xl bg-gray-50 p-4 dark:bg-dark-700"> <!-- 用户信息头部 -->
<div class="flex h-10 w-10 items-center justify-center rounded-full bg-primary-100"> <div class="flex items-center gap-4 rounded-2xl bg-gradient-to-r from-primary-50 to-primary-100 p-5 dark:from-primary-900/30 dark:to-primary-800/20">
<span class="text-lg font-medium text-primary-700">{{ user.email.charAt(0).toUpperCase() }}</span> <div class="flex h-14 w-14 items-center justify-center rounded-full bg-white shadow-sm dark:bg-dark-700">
<span class="text-2xl font-semibold text-primary-600 dark:text-primary-400">{{ user.email.charAt(0).toUpperCase() }}</span>
</div>
<div class="flex-1">
<p class="text-lg font-semibold text-gray-900 dark:text-white">{{ user.email }}</p>
<p class="mt-1 text-sm text-gray-600 dark:text-gray-400">{{ t('admin.users.groupConfigHint', { email: user.email }) }}</p>
</div> </div>
<p class="font-medium text-gray-900 dark:text-white">{{ user.email }}</p>
</div> </div>
<div v-if="loading" class="flex justify-center py-8"><svg class="h-8 w-8 animate-spin text-primary-500" fill="none" viewBox="0 0 24 24"><circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle><path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg></div>
<div v-else> <!-- 加载状态 -->
<p class="mb-3 text-sm text-gray-600">{{ t('admin.users.allowedGroupsHint') }}</p> <div v-if="loading" class="flex justify-center py-12">
<div class="max-h-64 space-y-2 overflow-y-auto"> <svg class="h-10 w-10 animate-spin text-primary-500" fill="none" viewBox="0 0 24 24">
<label v-for="group in groups" :key="group.id" class="flex cursor-pointer items-center gap-3 rounded-lg border border-gray-200 p-3 hover:bg-gray-50" :class="{'border-primary-300 bg-primary-50': selectedIds.includes(group.id)}"> <circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle>
<input type="checkbox" :value="group.id" v-model="selectedIds" class="h-4 w-4 rounded border-gray-300 text-primary-600" /> <path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path>
<div class="flex-1"><p class="font-medium text-gray-900">{{ group.name }}</p><p v-if="group.description" class="truncate text-sm text-gray-500">{{ group.description }}</p></div> </svg>
<div class="flex items-center gap-2"><span class="badge badge-gray text-xs">{{ group.platform }}</span><span v-if="group.is_exclusive" class="badge badge-purple text-xs">{{ t('admin.groups.exclusive') }}</span></div> </div>
</label>
<div v-else class="space-y-6">
<!-- 专属分组区域 -->
<div v-if="exclusiveGroups.length > 0">
<div class="mb-3 flex items-center gap-2">
<div class="h-1.5 w-1.5 rounded-full bg-purple-500"></div>
<h4 class="text-sm font-semibold text-gray-700 dark:text-gray-300">{{ t('admin.users.exclusiveGroups') }}</h4>
<span class="text-xs text-gray-400">({{ exclusiveGroupConfigs.filter(c => c.isSelected).length }}/{{ exclusiveGroupConfigs.length }})</span>
</div>
<div class="grid gap-3">
<div
v-for="config in exclusiveGroupConfigs"
:key="config.groupId"
class="group relative overflow-hidden rounded-xl border-2 p-4 transition-all duration-200"
:class="config.isSelected
? 'border-primary-400 bg-primary-50/50 shadow-sm dark:border-primary-500 dark:bg-primary-900/20'
: 'border-gray-200 bg-white hover:border-gray-300 dark:border-dark-600 dark:bg-dark-800 dark:hover:border-dark-500'"
>
<div class="flex items-center gap-4">
<!-- 复选框 -->
<div class="flex-shrink-0">
<label class="relative flex h-6 w-6 cursor-pointer items-center justify-center">
<input
type="checkbox"
:checked="config.isSelected"
@change="toggleExclusiveGroup(config.groupId)"
class="peer sr-only"
/>
<div class="h-5 w-5 rounded-md border-2 border-gray-300 transition-all peer-checked:border-primary-500 peer-checked:bg-primary-500 dark:border-dark-500 peer-checked:dark:border-primary-500">
<svg v-if="config.isSelected" class="h-full w-full text-white" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="3">
<path stroke-linecap="round" stroke-linejoin="round" d="M5 13l4 4L19 7" />
</svg>
</div>
</label>
</div>
<!-- 分组信息 -->
<div class="min-w-0 flex-1">
<div class="flex items-center gap-2">
<span class="text-base font-semibold text-gray-900 dark:text-white">{{ config.groupName }}</span>
<span class="inline-flex items-center rounded-full bg-purple-100 px-2 py-0.5 text-xs font-medium text-purple-700 dark:bg-purple-900/40 dark:text-purple-300">
{{ t('admin.groups.exclusive') }}
</span>
</div>
<div class="mt-1.5 flex items-center gap-3 text-sm">
<span class="inline-flex items-center gap-1 text-gray-500 dark:text-gray-400">
<PlatformIcon :platform="config.platform" size="xs" />
<span>{{ config.platform }}</span>
</span>
<span class="text-gray-300 dark:text-dark-500"></span>
<span class="text-gray-500 dark:text-gray-400">
{{ t('admin.users.defaultRate') }}: <span class="font-medium text-gray-700 dark:text-gray-300">{{ config.defaultRate }}x</span>
</span>
</div>
</div>
<!-- 专属倍率输入 -->
<div class="flex flex-shrink-0 items-center gap-3">
<label class="text-sm font-medium text-gray-600 dark:text-gray-400">{{ t('admin.users.customRate') }}</label>
<input
type="number"
step="0.001"
min="0"
:value="config.customRate ?? ''"
@input="updateCustomRate(config.groupId, ($event.target as HTMLInputElement).value)"
:placeholder="String(config.defaultRate)"
class="hide-spinner w-24 rounded-lg border border-gray-300 bg-white px-3 py-2 text-sm font-medium transition-colors focus:border-primary-500 focus:outline-none focus:ring-2 focus:ring-primary-500/20 dark:border-dark-500 dark:bg-dark-700 dark:focus:border-primary-500"
/>
</div>
</div>
</div>
</div>
</div> </div>
<div class="mt-4 border-t border-gray-200 pt-4">
<label class="flex cursor-pointer items-center gap-3 rounded-lg border border-gray-200 p-3 hover:bg-gray-50" :class="{'border-green-300 bg-green-50': selectedIds.length === 0}"> <!-- 公开分组区域 -->
<input type="radio" :checked="selectedIds.length === 0" @change="selectedIds = []" class="h-4 w-4 border-gray-300 text-green-600" /> <div v-if="publicGroups.length > 0">
<div class="flex-1"><p class="font-medium text-gray-900">{{ t('admin.users.allowAllGroups') }}</p><p class="text-sm text-gray-500">{{ t('admin.users.allowAllGroupsHint') }}</p></div> <div class="mb-3 flex items-center gap-2">
</label> <div class="h-1.5 w-1.5 rounded-full bg-green-500"></div>
<h4 class="text-sm font-semibold text-gray-700 dark:text-gray-300">{{ t('admin.users.publicGroups') }}</h4>
<span class="text-xs text-gray-400">({{ publicGroupConfigs.length }})</span>
</div>
<div class="grid gap-3">
<div
v-for="config in publicGroupConfigs"
:key="config.groupId"
class="relative overflow-hidden rounded-xl border-2 border-green-200 bg-green-50/50 p-4 dark:border-green-800/50 dark:bg-green-900/10"
>
<div class="flex items-center gap-4">
<!-- 复选框禁用状态 -->
<div class="flex-shrink-0">
<div class="flex h-5 w-5 items-center justify-center rounded-md border-2 border-green-400 bg-green-500 dark:border-green-600 dark:bg-green-600">
<svg class="h-full w-full text-white" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="3">
<path stroke-linecap="round" stroke-linejoin="round" d="M5 13l4 4L19 7" />
</svg>
</div>
</div>
<!-- 分组信息 -->
<div class="min-w-0 flex-1">
<div class="flex items-center gap-2">
<span class="text-base font-semibold text-gray-900 dark:text-white">{{ config.groupName }}</span>
</div>
<div class="mt-1.5 flex items-center gap-3 text-sm">
<span class="inline-flex items-center gap-1 text-gray-500 dark:text-gray-400">
<PlatformIcon :platform="config.platform" size="xs" />
<span>{{ config.platform }}</span>
</span>
<span class="text-gray-300 dark:text-dark-500"></span>
<span class="text-gray-500 dark:text-gray-400">
{{ t('admin.users.defaultRate') }}: <span class="font-medium text-gray-700 dark:text-gray-300">{{ config.defaultRate }}x</span>
</span>
</div>
</div>
<!-- 专属倍率输入 -->
<div class="flex flex-shrink-0 items-center gap-3">
<label class="text-sm font-medium text-gray-600 dark:text-gray-400">{{ t('admin.users.customRate') }}</label>
<input
type="number"
step="0.001"
min="0"
:value="config.customRate ?? ''"
@input="updateCustomRate(config.groupId, ($event.target as HTMLInputElement).value)"
:placeholder="String(config.defaultRate)"
class="hide-spinner w-24 rounded-lg border border-gray-300 bg-white px-3 py-2 text-sm font-medium transition-colors focus:border-primary-500 focus:outline-none focus:ring-2 focus:ring-primary-500/20 dark:border-dark-500 dark:bg-dark-700 dark:focus:border-primary-500"
/>
</div>
</div>
</div>
</div>
</div>
<!-- 无分组提示 -->
<div v-if="groups.length === 0" class="flex flex-col items-center justify-center py-12 text-center">
<div class="mb-4 flex h-16 w-16 items-center justify-center rounded-full bg-gray-100 dark:bg-dark-700">
<svg class="h-8 w-8 text-gray-400" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 11H5m14 0a2 2 0 012 2v6a2 2 0 01-2 2H5a2 2 0 01-2-2v-6a2 2 0 012-2m14 0V9a2 2 0 00-2-2M5 11V9a2 2 0 012-2m0 0V5a2 2 0 012-2h6a2 2 0 012 2v2M7 7h10" />
</svg>
</div>
<p class="text-gray-500 dark:text-gray-400">{{ t('common.noGroupsAvailable') }}</p>
</div> </div>
</div> </div>
</div> </div>
<template #footer> <template #footer>
<div class="flex justify-end gap-3"> <div class="flex justify-end gap-3">
<button @click="$emit('close')" class="btn btn-secondary">{{ t('common.cancel') }}</button> <button @click="$emit('close')" class="btn btn-secondary px-5">{{ t('common.cancel') }}</button>
<button @click="handleSave" :disabled="submitting" class="btn btn-primary">{{ submitting ? t('common.saving') : t('common.save') }}</button> <button @click="handleSave" :disabled="submitting" class="btn btn-primary px-6">
<svg v-if="submitting" class="-ml-1 mr-2 h-4 w-4 animate-spin" fill="none" viewBox="0 0 24 24">
<circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle>
<path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path>
</svg>
{{ submitting ? t('common.saving') : t('common.save') }}
</button>
</div> </div>
</template> </template>
</BaseDialog> </BaseDialog>
</template> </template>
<script setup lang="ts"> <script setup lang="ts">
import { ref, watch } from 'vue' import { ref, watch, computed } from 'vue'
import { useI18n } from 'vue-i18n' import { useI18n } from 'vue-i18n'
import { useAppStore } from '@/stores/app' import { useAppStore } from '@/stores/app'
import { adminAPI } from '@/api/admin' import { adminAPI } from '@/api/admin'
import type { AdminUser, Group } from '@/types' import type { AdminUser, Group, GroupPlatform } from '@/types'
import BaseDialog from '@/components/common/BaseDialog.vue' import BaseDialog from '@/components/common/BaseDialog.vue'
import PlatformIcon from '@/components/common/PlatformIcon.vue'
const props = defineProps<{ show: boolean, user: AdminUser | null }>() interface GroupRateConfig {
const emit = defineEmits(['close', 'success']); const { t } = useI18n(); const appStore = useAppStore() groupId: number
groupName: string
platform: GroupPlatform
isExclusive: boolean
defaultRate: number
customRate: number | null
isSelected: boolean
}
const groups = ref<Group[]>([]); const selectedIds = ref<number[]>([]); const loading = ref(false); const submitting = ref(false) const props = defineProps<{ show: boolean; user: AdminUser | null }>()
const emit = defineEmits(['close', 'success'])
const { t } = useI18n()
const appStore = useAppStore()
watch(() => props.show, (v) => { if(v && props.user) { selectedIds.value = props.user.allowed_groups || []; load() } }) const groups = ref<Group[]>([])
const load = async () => { loading.value = true; try { const res = await adminAPI.groups.list(1, 1000); groups.value = res.items.filter(g => g.subscription_type === 'standard' && g.status === 'active') } catch (error) { console.error('Failed to load groups:', error) } finally { loading.value = false } } const groupConfigs = ref<GroupRateConfig[]>([])
const handleSave = async () => { const originalGroupRates = ref<Record<number, number>>({}) // 记录原始专属倍率,用于检测删除
if (!props.user) return; submitting.value = true const loading = ref(false)
const submitting = ref(false)
// 分离专属分组和公开分组
const exclusiveGroups = computed(() => groups.value.filter((g) => g.is_exclusive))
const publicGroups = computed(() => groups.value.filter((g) => !g.is_exclusive))
const exclusiveGroupConfigs = computed(() => groupConfigs.value.filter((c) => c.isExclusive))
const publicGroupConfigs = computed(() => groupConfigs.value.filter((c) => !c.isExclusive))
watch(
() => props.show,
(v) => {
if (v && props.user) {
load()
}
}
)
const load = async () => {
loading.value = true
try { try {
await adminAPI.users.update(props.user.id, { allowed_groups: selectedIds.value }) const res = await adminAPI.groups.list(1, 1000)
appStore.showSuccess(t('admin.users.allowedGroupsUpdated')); emit('success'); emit('close') // 只显示标准类型且活跃的分组
} catch (error) { console.error('Failed to update allowed groups:', error) } finally { submitting.value = false } groups.value = res.items.filter((g) => g.subscription_type === 'standard' && g.status === 'active')
// 初始化配置
const userAllowedGroups = props.user?.allowed_groups || []
const userGroupRates = props.user?.group_rates || {}
// 保存原始专属倍率,用于检测删除操作
originalGroupRates.value = { ...userGroupRates }
groupConfigs.value = groups.value.map((g) => ({
groupId: g.id,
groupName: g.name,
platform: g.platform,
isExclusive: g.is_exclusive,
defaultRate: g.rate_multiplier,
customRate: userGroupRates[g.id] ?? null,
// 专属分组:检查是否在 allowed_groups 中
// 公开分组:始终选中
isSelected: g.is_exclusive ? userAllowedGroups.includes(g.id) : true,
}))
} catch (error) {
console.error('Failed to load groups:', error)
} finally {
loading.value = false
}
}
const toggleExclusiveGroup = (groupId: number) => {
const config = groupConfigs.value.find((c) => c.groupId === groupId)
if (config && config.isExclusive) {
config.isSelected = !config.isSelected
}
}
const updateCustomRate = (groupId: number, value: string) => {
const config = groupConfigs.value.find((c) => c.groupId === groupId)
if (config) {
if (value === '' || value === null || value === undefined) {
config.customRate = null
} else {
const numValue = parseFloat(value)
config.customRate = isNaN(numValue) ? null : numValue
}
}
}
const handleSave = async () => {
if (!props.user) return
submitting.value = true
try {
// 构建 allowed_groups仅包含专属分组中被勾选的
const allowedGroups = groupConfigs.value.filter((c) => c.isExclusive && c.isSelected).map((c) => c.groupId)
// 构建 group_rates
// - 有新专属倍率: 设置为该值
// - 原本有专属倍率但现在被清空: 设置为 null表示删除
const groupRates: Record<number, number | null> = {}
for (const c of groupConfigs.value) {
const hadOriginalRate = originalGroupRates.value[c.groupId] !== undefined
if (c.customRate !== null) {
// 有专属倍率
groupRates[c.groupId] = c.customRate
} else if (hadOriginalRate) {
// 原本有专属倍率,现在被清空,需要显式删除
groupRates[c.groupId] = null
}
}
await adminAPI.users.update(props.user.id, {
allowed_groups: allowedGroups,
group_rates: Object.keys(groupRates).length > 0 ? groupRates : undefined,
})
appStore.showSuccess(t('admin.users.groupConfigUpdated'))
emit('success')
emit('close')
} catch (error) {
console.error('Failed to update user group config:', error)
} finally {
submitting.value = false
}
} }
</script> </script>
<style scoped>
/* 隐藏数字输入框的箭头按钮 */
.hide-spinner::-webkit-outer-spin-button,
.hide-spinner::-webkit-inner-spin-button {
-webkit-appearance: none;
margin: 0;
}
.hide-spinner {
-moz-appearance: textfield;
}
</style>

View File

@@ -11,7 +11,14 @@
<span class="truncate">{{ name }}</span> <span class="truncate">{{ name }}</span>
<!-- Right side label --> <!-- Right side label -->
<span v-if="showLabel" :class="labelClass"> <span v-if="showLabel" :class="labelClass">
{{ labelText }} <template v-if="hasCustomRate">
<!-- 原倍率删除线 + 专属倍率高亮 -->
<span class="line-through opacity-50 mr-0.5">{{ rateMultiplier }}x</span>
<span class="font-bold">{{ userRateMultiplier }}x</span>
</template>
<template v-else>
{{ labelText }}
</template>
</span> </span>
</span> </span>
</template> </template>
@@ -27,6 +34,7 @@ interface Props {
platform?: GroupPlatform platform?: GroupPlatform
subscriptionType?: SubscriptionType subscriptionType?: SubscriptionType
rateMultiplier?: number rateMultiplier?: number
userRateMultiplier?: number | null // 用户专属倍率
showRate?: boolean showRate?: boolean
daysRemaining?: number | null // 剩余天数(订阅类型时使用) daysRemaining?: number | null // 剩余天数(订阅类型时使用)
} }
@@ -34,20 +42,31 @@ interface Props {
const props = withDefaults(defineProps<Props>(), { const props = withDefaults(defineProps<Props>(), {
subscriptionType: 'standard', subscriptionType: 'standard',
showRate: true, showRate: true,
daysRemaining: null daysRemaining: null,
userRateMultiplier: null
}) })
const { t } = useI18n() const { t } = useI18n()
const isSubscription = computed(() => props.subscriptionType === 'subscription') const isSubscription = computed(() => props.subscriptionType === 'subscription')
// 是否有专属倍率(且与默认倍率不同)
const hasCustomRate = computed(() => {
return (
props.userRateMultiplier !== null &&
props.userRateMultiplier !== undefined &&
props.rateMultiplier !== undefined &&
props.userRateMultiplier !== props.rateMultiplier
)
})
// 是否显示右侧标签 // 是否显示右侧标签
const showLabel = computed(() => { const showLabel = computed(() => {
if (!props.showRate) return false if (!props.showRate) return false
// 订阅类型:显示天数或"订阅" // 订阅类型:显示天数或"订阅"
if (isSubscription.value) return true if (isSubscription.value) return true
// 标准类型:显示倍率 // 标准类型:显示倍率(包括专属倍率)
return props.rateMultiplier !== undefined return props.rateMultiplier !== undefined || hasCustomRate.value
}) })
// Label text // Label text
@@ -71,7 +90,7 @@ const labelClass = computed(() => {
const base = 'px-1.5 py-0.5 rounded text-[10px] font-semibold' const base = 'px-1.5 py-0.5 rounded text-[10px] font-semibold'
if (!isSubscription.value) { if (!isSubscription.value) {
// Standard: subtle background // Standard: subtle background (不再为专属倍率使用不同的背景色)
return `${base} bg-black/10 dark:bg-white/10` return `${base} bg-black/10 dark:bg-white/10`
} }

View File

@@ -9,6 +9,7 @@
:platform="platform" :platform="platform"
:subscription-type="subscriptionType" :subscription-type="subscriptionType"
:rate-multiplier="rateMultiplier" :rate-multiplier="rateMultiplier"
:user-rate-multiplier="userRateMultiplier"
/> />
<span <span
v-if="description" v-if="description"
@@ -39,6 +40,7 @@ interface Props {
platform: GroupPlatform platform: GroupPlatform
subscriptionType?: SubscriptionType subscriptionType?: SubscriptionType
rateMultiplier?: number rateMultiplier?: number
userRateMultiplier?: number | null
description?: string | null description?: string | null
selected?: boolean selected?: boolean
showCheckmark?: boolean showCheckmark?: boolean
@@ -47,6 +49,7 @@ interface Props {
withDefaults(defineProps<Props>(), { withDefaults(defineProps<Props>(), {
subscriptionType: 'standard', subscriptionType: 'standard',
selected: false, selected: false,
showCheckmark: true showCheckmark: true,
userRateMultiplier: null
}) })
</script> </script>

View File

@@ -283,7 +283,12 @@ function closeDropdown() {
async function handleLogout() { async function handleLogout() {
closeDropdown() closeDropdown()
authStore.logout() try {
await authStore.logout()
} catch (error) {
// Ignore logout errors - still redirect to login
console.error('Logout error:', error)
}
await router.push('/login') await router.push('/login')
} }

View File

@@ -849,6 +849,16 @@ export default {
allowedGroupsUpdated: 'Allowed groups updated successfully', allowedGroupsUpdated: 'Allowed groups updated successfully',
failedToLoadGroups: 'Failed to load groups', failedToLoadGroups: 'Failed to load groups',
failedToUpdateAllowedGroups: 'Failed to update allowed groups', failedToUpdateAllowedGroups: 'Failed to update allowed groups',
// User Group Configuration
groupConfig: 'User Group Configuration',
groupConfigHint: 'Configure custom rate multipliers for user {email} (overrides group defaults)',
exclusiveGroups: 'Exclusive Groups',
publicGroups: 'Public Groups (Default Available)',
defaultRate: 'Default Rate',
customRate: 'Custom Rate',
useDefaultRate: 'Use Default',
customRatePlaceholder: 'Leave empty for default',
groupConfigUpdated: 'Group configuration updated successfully',
deposit: 'Deposit', deposit: 'Deposit',
withdraw: 'Withdraw', withdraw: 'Withdraw',
depositAmount: 'Deposit Amount', depositAmount: 'Deposit Amount',
@@ -3198,6 +3208,80 @@ export default {
failedToSave: 'Failed to save settings', failedToSave: 'Failed to save settings',
failedToTestSmtp: 'SMTP connection test failed', failedToTestSmtp: 'SMTP connection test failed',
failedToSendTestEmail: 'Failed to send test email' failedToSendTestEmail: 'Failed to send test email'
},
// Error Passthrough Rules
errorPassthrough: {
title: 'Error Passthrough Rules',
description: 'Configure how upstream errors are returned to clients',
createRule: 'Create Rule',
editRule: 'Edit Rule',
deleteRule: 'Delete Rule',
noRules: 'No rules configured',
createFirstRule: 'Create your first error passthrough rule',
allPlatforms: 'All Platforms',
passthrough: 'Passthrough',
custom: 'Custom',
code: 'Code',
body: 'Body',
// Columns
columns: {
priority: 'Priority',
name: 'Name',
conditions: 'Conditions',
platforms: 'Platforms',
behavior: 'Behavior',
status: 'Status',
actions: 'Actions'
},
// Match Mode
matchMode: {
any: 'Code OR Keyword',
all: 'Code AND Keyword',
anyHint: 'Status code matches any error code, OR message contains any keyword',
allHint: 'Status code matches any error code, AND message contains any keyword'
},
// Form
form: {
name: 'Rule Name',
namePlaceholder: 'e.g., Context Limit Passthrough',
priority: 'Priority',
priorityHint: 'Lower values have higher priority',
description: 'Description',
descriptionPlaceholder: 'Describe the purpose of this rule...',
matchConditions: 'Match Conditions',
errorCodes: 'Error Codes',
errorCodesPlaceholder: '422, 400, 429',
errorCodesHint: 'Separate multiple codes with commas',
keywords: 'Keywords',
keywordsPlaceholder: 'One keyword per line\ncontext limit\nmodel not supported',
keywordsHint: 'One keyword per line, case-insensitive',
matchMode: 'Match Mode',
platforms: 'Platforms',
platformsHint: 'Leave empty to apply to all platforms',
responseBehavior: 'Response Behavior',
passthroughCode: 'Passthrough upstream status code',
responseCode: 'Custom status code',
passthroughBody: 'Passthrough upstream error message',
customMessage: 'Custom error message',
customMessagePlaceholder: 'Error message to return to client...',
enabled: 'Enable this rule'
},
// Messages
nameRequired: 'Please enter rule name',
conditionsRequired: 'Please configure at least one error code or keyword',
ruleCreated: 'Rule created successfully',
ruleUpdated: 'Rule updated successfully',
ruleDeleted: 'Rule deleted successfully',
deleteConfirm: 'Are you sure you want to delete rule "{name}"?',
failedToLoad: 'Failed to load rules',
failedToSave: 'Failed to save rule',
failedToDelete: 'Failed to delete rule',
failedToToggle: 'Failed to toggle status'
} }
}, },

View File

@@ -910,6 +910,16 @@ export default {
allowedGroupsUpdated: '允许分组更新成功', allowedGroupsUpdated: '允许分组更新成功',
failedToLoadGroups: '加载分组列表失败', failedToLoadGroups: '加载分组列表失败',
failedToUpdateAllowedGroups: '更新允许分组失败', failedToUpdateAllowedGroups: '更新允许分组失败',
// 用户分组配置
groupConfig: '用户分组配置',
groupConfigHint: '为用户 {email} 配置专属分组倍率(覆盖分组默认倍率)',
exclusiveGroups: '专属分组',
publicGroups: '公开分组(默认可用)',
defaultRate: '默认倍率',
customRate: '专属倍率',
useDefaultRate: '使用默认',
customRatePlaceholder: '留空使用默认',
groupConfigUpdated: '分组配置更新成功',
deposit: '充值', deposit: '充值',
withdraw: '退款', withdraw: '退款',
depositAmount: '充值金额', depositAmount: '充值金额',
@@ -3369,6 +3379,80 @@ export default {
failedToSave: '保存设置失败', failedToSave: '保存设置失败',
failedToTestSmtp: 'SMTP 连接测试失败', failedToTestSmtp: 'SMTP 连接测试失败',
failedToSendTestEmail: '发送测试邮件失败' failedToSendTestEmail: '发送测试邮件失败'
},
// Error Passthrough Rules
errorPassthrough: {
title: '错误透传规则',
description: '配置上游错误如何返回给客户端',
createRule: '创建规则',
editRule: '编辑规则',
deleteRule: '删除规则',
noRules: '暂无规则',
createFirstRule: '创建第一条错误透传规则',
allPlatforms: '所有平台',
passthrough: '透传',
custom: '自定义',
code: '状态码',
body: '消息体',
// Columns
columns: {
priority: '优先级',
name: '名称',
conditions: '匹配条件',
platforms: '平台',
behavior: '响应行为',
status: '状态',
actions: '操作'
},
// Match Mode
matchMode: {
any: '错误码 或 关键词',
all: '错误码 且 关键词',
anyHint: '状态码匹配任一错误码,或消息包含任一关键词',
allHint: '状态码匹配任一错误码,且消息包含任一关键词'
},
// Form
form: {
name: '规则名称',
namePlaceholder: '例如:上下文超限透传',
priority: '优先级',
priorityHint: '数值越小优先级越高,优先匹配',
description: '规则描述',
descriptionPlaceholder: '描述此规则的用途...',
matchConditions: '匹配条件',
errorCodes: '错误码',
errorCodesPlaceholder: '422, 400, 429',
errorCodesHint: '多个错误码用逗号分隔',
keywords: '关键词',
keywordsPlaceholder: '每行一个关键词\ncontext limit\nmodel not supported',
keywordsHint: '每行一个关键词,不区分大小写',
matchMode: '匹配模式',
platforms: '适用平台',
platformsHint: '不选择表示适用于所有平台',
responseBehavior: '响应行为',
passthroughCode: '透传上游状态码',
responseCode: '自定义状态码',
passthroughBody: '透传上游错误信息',
customMessage: '自定义错误信息',
customMessagePlaceholder: '返回给客户端的错误信息...',
enabled: '启用此规则'
},
// Messages
nameRequired: '请输入规则名称',
conditionsRequired: '请至少配置一个错误码或关键词',
ruleCreated: '规则创建成功',
ruleUpdated: '规则更新成功',
ruleDeleted: '规则删除成功',
deleteConfirm: '确定要删除规则 "{name}" 吗?',
failedToLoad: '加载规则失败',
failedToSave: '保存规则失败',
failedToDelete: '删除规则失败',
failedToToggle: '切换状态失败'
} }
}, },

View File

@@ -1,6 +1,6 @@
/** /**
* Authentication Store * Authentication Store
* Manages user authentication state, login/logout, and token persistence * Manages user authentication state, login/logout, token refresh, and token persistence
*/ */
import { defineStore } from 'pinia' import { defineStore } from 'pinia'
@@ -10,15 +10,21 @@ import type { User, LoginRequest, RegisterRequest, AuthResponse } from '@/types'
const AUTH_TOKEN_KEY = 'auth_token' const AUTH_TOKEN_KEY = 'auth_token'
const AUTH_USER_KEY = 'auth_user' const AUTH_USER_KEY = 'auth_user'
const AUTO_REFRESH_INTERVAL = 60 * 1000 // 60 seconds const REFRESH_TOKEN_KEY = 'refresh_token'
const TOKEN_EXPIRES_AT_KEY = 'token_expires_at' // 存储过期时间戳而非有效期
const AUTO_REFRESH_INTERVAL = 60 * 1000 // 60 seconds for user data refresh
const TOKEN_REFRESH_BUFFER = 120 * 1000 // 120 seconds before expiry to refresh token
export const useAuthStore = defineStore('auth', () => { export const useAuthStore = defineStore('auth', () => {
// ==================== State ==================== // ==================== State ====================
const user = ref<User | null>(null) const user = ref<User | null>(null)
const token = ref<string | null>(null) const token = ref<string | null>(null)
const refreshTokenValue = ref<string | null>(null)
const tokenExpiresAt = ref<number | null>(null) // 过期时间戳(毫秒)
const runMode = ref<'standard' | 'simple'>('standard') const runMode = ref<'standard' | 'simple'>('standard')
let refreshIntervalId: ReturnType<typeof setInterval> | null = null let refreshIntervalId: ReturnType<typeof setInterval> | null = null
let tokenRefreshTimeoutId: ReturnType<typeof setTimeout> | null = null
// ==================== Computed ==================== // ==================== Computed ====================
@@ -42,19 +48,29 @@ export const useAuthStore = defineStore('auth', () => {
function checkAuth(): void { function checkAuth(): void {
const savedToken = localStorage.getItem(AUTH_TOKEN_KEY) const savedToken = localStorage.getItem(AUTH_TOKEN_KEY)
const savedUser = localStorage.getItem(AUTH_USER_KEY) const savedUser = localStorage.getItem(AUTH_USER_KEY)
const savedRefreshToken = localStorage.getItem(REFRESH_TOKEN_KEY)
const savedExpiresAt = localStorage.getItem(TOKEN_EXPIRES_AT_KEY)
if (savedToken && savedUser) { if (savedToken && savedUser) {
try { try {
token.value = savedToken token.value = savedToken
user.value = JSON.parse(savedUser) user.value = JSON.parse(savedUser)
refreshTokenValue.value = savedRefreshToken
tokenExpiresAt.value = savedExpiresAt ? parseInt(savedExpiresAt, 10) : null
// Immediately refresh user data from backend (async, don't block) // Immediately refresh user data from backend (async, don't block)
refreshUser().catch((error) => { refreshUser().catch((error) => {
console.error('Failed to refresh user on init:', error) console.error('Failed to refresh user on init:', error)
}) })
// Start auto-refresh interval // Start auto-refresh interval for user data
startAutoRefresh() startAutoRefresh()
// Start proactive token refresh if we have refresh token and expiry info
// Note: use !== null to handle case when tokenExpiresAt.value is 0 (expired)
if (savedRefreshToken && tokenExpiresAt.value !== null) {
scheduleTokenRefreshAt(tokenExpiresAt.value)
}
} catch (error) { } catch (error) {
console.error('Failed to parse saved user data:', error) console.error('Failed to parse saved user data:', error)
clearAuth() clearAuth()
@@ -89,6 +105,76 @@ export const useAuthStore = defineStore('auth', () => {
} }
} }
/**
* Schedule proactive token refresh before expiry (based on expiry timestamp)
* @param expiresAtMs - Token expiry timestamp in milliseconds
*/
function scheduleTokenRefreshAt(expiresAtMs: number): void {
// Clear any existing timeout
if (tokenRefreshTimeoutId) {
clearTimeout(tokenRefreshTimeoutId)
tokenRefreshTimeoutId = null
}
// Calculate remaining time until refresh (buffer time before expiry)
const now = Date.now()
const refreshInMs = Math.max(0, expiresAtMs - now - TOKEN_REFRESH_BUFFER)
if (refreshInMs <= 0) {
// Token is about to expire or already expired, refresh immediately
performTokenRefresh()
return
}
tokenRefreshTimeoutId = setTimeout(() => {
performTokenRefresh()
}, refreshInMs)
}
/**
* Schedule proactive token refresh before expiry (based on expires_in seconds)
* @param expiresInSeconds - Token expiry time in seconds from now
*/
function scheduleTokenRefresh(expiresInSeconds: number): void {
const expiresAtMs = Date.now() + expiresInSeconds * 1000
tokenExpiresAt.value = expiresAtMs
localStorage.setItem(TOKEN_EXPIRES_AT_KEY, String(expiresAtMs))
scheduleTokenRefreshAt(expiresAtMs)
}
/**
* Perform the actual token refresh
*/
async function performTokenRefresh(): Promise<void> {
if (!refreshTokenValue.value) {
return
}
try {
const response = await authAPI.refreshToken()
// Update state
token.value = response.access_token
refreshTokenValue.value = response.refresh_token
// Schedule next refresh (this also updates tokenExpiresAt and localStorage)
scheduleTokenRefresh(response.expires_in)
} catch (error) {
console.error('Token refresh failed:', error)
// Don't clear auth here - the interceptor will handle 401 errors
}
}
/**
* Stop token refresh timeout
*/
function stopTokenRefresh(): void {
if (tokenRefreshTimeoutId) {
clearTimeout(tokenRefreshTimeoutId)
tokenRefreshTimeoutId = null
}
}
/** /**
* User login * User login
* @param credentials - Login credentials (email and password) * @param credentials - Login credentials (email and password)
@@ -141,6 +227,12 @@ export const useAuthStore = defineStore('auth', () => {
// Store token and user // Store token and user
token.value = response.access_token token.value = response.access_token
// Store refresh token if present
if (response.refresh_token) {
refreshTokenValue.value = response.refresh_token
localStorage.setItem(REFRESH_TOKEN_KEY, response.refresh_token)
}
// Extract run_mode if present // Extract run_mode if present
if (response.user.run_mode) { if (response.user.run_mode) {
runMode.value = response.user.run_mode runMode.value = response.user.run_mode
@@ -152,8 +244,14 @@ export const useAuthStore = defineStore('auth', () => {
localStorage.setItem(AUTH_TOKEN_KEY, response.access_token) localStorage.setItem(AUTH_TOKEN_KEY, response.access_token)
localStorage.setItem(AUTH_USER_KEY, JSON.stringify(userData)) localStorage.setItem(AUTH_USER_KEY, JSON.stringify(userData))
// Start auto-refresh interval // Start auto-refresh interval for user data
startAutoRefresh() startAutoRefresh()
// Start proactive token refresh if we have refresh token and expiry info
// scheduleTokenRefresh will also store the expiry timestamp
if (response.refresh_token && response.expires_in) {
scheduleTokenRefresh(response.expires_in)
}
} }
/** /**
@@ -166,24 +264,10 @@ export const useAuthStore = defineStore('auth', () => {
try { try {
const response = await authAPI.register(userData) const response = await authAPI.register(userData)
// Store token and user // Use the common helper to set auth state
token.value = response.access_token setAuthFromResponse(response)
// Extract run_mode if present return user.value!
if (response.user.run_mode) {
runMode.value = response.user.run_mode
}
const { run_mode: _run_mode, ...userDataWithoutRunMode } = response.user
user.value = userDataWithoutRunMode
// Persist to localStorage
localStorage.setItem(AUTH_TOKEN_KEY, response.access_token)
localStorage.setItem(AUTH_USER_KEY, JSON.stringify(userDataWithoutRunMode))
// Start auto-refresh interval
startAutoRefresh()
return userDataWithoutRunMode
} catch (error) { } catch (error) {
// Clear any partial state on error // Clear any partial state on error
clearAuth() clearAuth()
@@ -193,18 +277,41 @@ export const useAuthStore = defineStore('auth', () => {
/** /**
* 直接设置 token用于 OAuth/SSO 回调),并加载当前用户信息。 * 直接设置 token用于 OAuth/SSO 回调),并加载当前用户信息。
* 会自动读取 localStorage 中已设置的 refresh_token 和 token_expires_in
* @param newToken - 后端签发的 JWT access token * @param newToken - 后端签发的 JWT access token
*/ */
async function setToken(newToken: string): Promise<User> { async function setToken(newToken: string): Promise<User> {
// Clear any previous state first (avoid mixing sessions) // Clear any previous state first (avoid mixing sessions)
clearAuth() // Note: Don't clear localStorage here as OAuth callback may have set refresh_token
stopAutoRefresh()
stopTokenRefresh()
token.value = null
user.value = null
token.value = newToken token.value = newToken
localStorage.setItem(AUTH_TOKEN_KEY, newToken) localStorage.setItem(AUTH_TOKEN_KEY, newToken)
// Read refresh token and expires_at from localStorage if set by OAuth callback
const savedRefreshToken = localStorage.getItem(REFRESH_TOKEN_KEY)
const savedExpiresAt = localStorage.getItem(TOKEN_EXPIRES_AT_KEY)
if (savedRefreshToken) {
refreshTokenValue.value = savedRefreshToken
}
if (savedExpiresAt) {
tokenExpiresAt.value = parseInt(savedExpiresAt, 10)
}
try { try {
const userData = await refreshUser() const userData = await refreshUser()
startAutoRefresh() startAutoRefresh()
// Start proactive token refresh if we have refresh token and expiry info
// Note: use !== null to handle case when tokenExpiresAt.value is 0 (expired)
if (savedRefreshToken && tokenExpiresAt.value !== null) {
scheduleTokenRefreshAt(tokenExpiresAt.value)
}
return userData return userData
} catch (error) { } catch (error) {
clearAuth() clearAuth()
@@ -216,9 +323,9 @@ export const useAuthStore = defineStore('auth', () => {
* User logout * User logout
* Clears all authentication state and persisted data * Clears all authentication state and persisted data
*/ */
function logout(): void { async function logout(): Promise<void> {
// Call API logout (client-side cleanup) // Call API logout (revokes refresh token on server)
authAPI.logout() await authAPI.logout()
// Clear state // Clear state
clearAuth() clearAuth()
@@ -263,11 +370,17 @@ export const useAuthStore = defineStore('auth', () => {
function clearAuth(): void { function clearAuth(): void {
// Stop auto-refresh // Stop auto-refresh
stopAutoRefresh() stopAutoRefresh()
// Stop token refresh
stopTokenRefresh()
token.value = null token.value = null
refreshTokenValue.value = null
tokenExpiresAt.value = null
user.value = null user.value = null
localStorage.removeItem(AUTH_TOKEN_KEY) localStorage.removeItem(AUTH_TOKEN_KEY)
localStorage.removeItem(AUTH_USER_KEY) localStorage.removeItem(AUTH_USER_KEY)
localStorage.removeItem(REFRESH_TOKEN_KEY)
localStorage.removeItem(TOKEN_EXPIRES_AT_KEY)
} }
// ==================== Return Store API ==================== // ==================== Return Store API ====================

View File

@@ -41,6 +41,8 @@ export interface User {
export interface AdminUser extends User { export interface AdminUser extends User {
// 管理员备注(普通用户接口不返回) // 管理员备注(普通用户接口不返回)
notes: string notes: string
// 用户专属分组倍率配置 (group_id -> rate_multiplier)
group_rates?: Record<number, number>
} }
export interface LoginRequest { export interface LoginRequest {
@@ -92,6 +94,8 @@ export interface PublicSettings {
export interface AuthResponse { export interface AuthResponse {
access_token: string access_token: string
refresh_token?: string // New: Refresh Token for token renewal
expires_in?: number // New: Access Token expiry time in seconds
token_type: string token_type: string
user: User & { run_mode?: 'standard' | 'simple' } user: User & { run_mode?: 'standard' | 'simple' }
} }
@@ -977,6 +981,9 @@ export interface UpdateUserRequest {
concurrency?: number concurrency?: number
status?: 'active' | 'disabled' status?: 'active' | 'disabled'
allowed_groups?: number[] | null allowed_groups?: number[] | null
// 用户专属分组倍率配置 (group_id -> rate_multiplier | null)
// null 表示删除该分组的专属倍率
group_rates?: Record<number, number | null>
} }
export interface ChangePasswordRequest { export interface ChangePasswordRequest {

View File

@@ -16,6 +16,16 @@
@sync="showSync = true" @sync="showSync = true"
@create="showCreate = true" @create="showCreate = true"
> >
<template #before>
<button
@click="showErrorPassthrough = true"
class="btn btn-secondary"
:title="t('admin.errorPassthrough.title')"
>
<Icon name="shield" size="md" class="mr-1.5" />
<span class="hidden md:inline">{{ t('admin.errorPassthrough.title') }}</span>
</button>
</template>
<template #after> <template #after>
<!-- Auto Refresh Dropdown --> <!-- Auto Refresh Dropdown -->
<div class="relative" ref="autoRefreshDropdownRef"> <div class="relative" ref="autoRefreshDropdownRef">
@@ -221,6 +231,7 @@
<BulkEditAccountModal :show="showBulkEdit" :account-ids="selIds" :proxies="proxies" :groups="groups" @close="showBulkEdit = false" @updated="handleBulkUpdated" /> <BulkEditAccountModal :show="showBulkEdit" :account-ids="selIds" :proxies="proxies" :groups="groups" @close="showBulkEdit = false" @updated="handleBulkUpdated" />
<TempUnschedStatusModal :show="showTempUnsched" :account="tempUnschedAcc" @close="showTempUnsched = false" @reset="handleTempUnschedReset" /> <TempUnschedStatusModal :show="showTempUnsched" :account="tempUnschedAcc" @close="showTempUnsched = false" @reset="handleTempUnschedReset" />
<ConfirmDialog :show="showDeleteDialog" :title="t('admin.accounts.deleteAccount')" :message="t('admin.accounts.deleteConfirm', { name: deletingAcc?.name })" :confirm-text="t('common.delete')" :cancel-text="t('common.cancel')" :danger="true" @confirm="confirmDelete" @cancel="showDeleteDialog = false" /> <ConfirmDialog :show="showDeleteDialog" :title="t('admin.accounts.deleteAccount')" :message="t('admin.accounts.deleteConfirm', { name: deletingAcc?.name })" :confirm-text="t('common.delete')" :cancel-text="t('common.cancel')" :danger="true" @confirm="confirmDelete" @cancel="showDeleteDialog = false" />
<ErrorPassthroughRulesModal :show="showErrorPassthrough" @close="showErrorPassthrough = false" />
</AppLayout> </AppLayout>
</template> </template>
@@ -252,6 +263,7 @@ import AccountGroupsCell from '@/components/account/AccountGroupsCell.vue'
import AccountCapacityCell from '@/components/account/AccountCapacityCell.vue' import AccountCapacityCell from '@/components/account/AccountCapacityCell.vue'
import PlatformTypeBadge from '@/components/common/PlatformTypeBadge.vue' import PlatformTypeBadge from '@/components/common/PlatformTypeBadge.vue'
import Icon from '@/components/icons/Icon.vue' import Icon from '@/components/icons/Icon.vue'
import ErrorPassthroughRulesModal from '@/components/admin/ErrorPassthroughRulesModal.vue'
import { formatDateTime, formatRelativeTime } from '@/utils/format' import { formatDateTime, formatRelativeTime } from '@/utils/format'
import type { Account, Proxy, AdminGroup } from '@/types' import type { Account, Proxy, AdminGroup } from '@/types'
@@ -271,6 +283,7 @@ const showDeleteDialog = ref(false)
const showReAuth = ref(false) const showReAuth = ref(false)
const showTest = ref(false) const showTest = ref(false)
const showStats = ref(false) const showStats = ref(false)
const showErrorPassthrough = ref(false)
const edAcc = ref<Account | null>(null) const edAcc = ref<Account | null>(null)
const tempUnschedAcc = ref<Account | null>(null) const tempUnschedAcc = ref<Account | null>(null)
const deletingAcc = ref<Account | null>(null) const deletingAcc = ref<Account | null>(null)

View File

@@ -71,6 +71,8 @@ onMounted(async () => {
const params = parseFragmentParams() const params = parseFragmentParams()
const token = params.get('access_token') || '' const token = params.get('access_token') || ''
const refreshToken = params.get('refresh_token') || ''
const expiresInStr = params.get('expires_in') || ''
const redirect = sanitizeRedirectPath( const redirect = sanitizeRedirectPath(
params.get('redirect') || (route.query.redirect as string | undefined) || '/dashboard' params.get('redirect') || (route.query.redirect as string | undefined) || '/dashboard'
) )
@@ -92,6 +94,17 @@ onMounted(async () => {
} }
try { try {
// Store refresh token and expires_at (convert to timestamp) if provided
if (refreshToken) {
localStorage.setItem('refresh_token', refreshToken)
}
if (expiresInStr) {
const expiresIn = parseInt(expiresInStr, 10)
if (!isNaN(expiresIn)) {
localStorage.setItem('token_expires_at', String(Date.now() + expiresIn * 1000))
}
}
await authStore.setToken(token) await authStore.setToken(token)
appStore.showSuccess(t('auth.loginSuccess')) appStore.showSuccess(t('auth.loginSuccess'))
await router.replace(redirect) await router.replace(redirect)

View File

@@ -73,6 +73,7 @@
:platform="row.group.platform" :platform="row.group.platform"
:subscription-type="row.group.subscription_type" :subscription-type="row.group.subscription_type"
:rate-multiplier="row.group.rate_multiplier" :rate-multiplier="row.group.rate_multiplier"
:user-rate-multiplier="userGroupRates[row.group.id]"
/> />
<span v-else class="text-sm text-gray-400 dark:text-dark-500">{{ <span v-else class="text-sm text-gray-400 dark:text-dark-500">{{
t('keys.noGroup') t('keys.noGroup')
@@ -272,6 +273,7 @@
:platform="(option as unknown as GroupOption).platform" :platform="(option as unknown as GroupOption).platform"
:subscription-type="(option as unknown as GroupOption).subscriptionType" :subscription-type="(option as unknown as GroupOption).subscriptionType"
:rate-multiplier="(option as unknown as GroupOption).rate" :rate-multiplier="(option as unknown as GroupOption).rate"
:user-rate-multiplier="(option as unknown as GroupOption).userRate"
/> />
<span v-else class="text-gray-400">{{ t('keys.selectGroup') }}</span> <span v-else class="text-gray-400">{{ t('keys.selectGroup') }}</span>
</template> </template>
@@ -281,6 +283,7 @@
:platform="(option as unknown as GroupOption).platform" :platform="(option as unknown as GroupOption).platform"
:subscription-type="(option as unknown as GroupOption).subscriptionType" :subscription-type="(option as unknown as GroupOption).subscriptionType"
:rate-multiplier="(option as unknown as GroupOption).rate" :rate-multiplier="(option as unknown as GroupOption).rate"
:user-rate-multiplier="(option as unknown as GroupOption).userRate"
:description="(option as unknown as GroupOption).description" :description="(option as unknown as GroupOption).description"
:selected="selected" :selected="selected"
/> />
@@ -667,6 +670,7 @@
:platform="option.platform" :platform="option.platform"
:subscription-type="option.subscriptionType" :subscription-type="option.subscriptionType"
:rate-multiplier="option.rate" :rate-multiplier="option.rate"
:user-rate-multiplier="option.userRate"
:description="option.description" :description="option.description"
:selected=" :selected="
selectedKeyForGroup?.group_id === option.value || selectedKeyForGroup?.group_id === option.value ||
@@ -718,6 +722,7 @@ interface GroupOption {
label: string label: string
description: string | null description: string | null
rate: number rate: number
userRate: number | null
subscriptionType: SubscriptionType subscriptionType: SubscriptionType
platform: GroupPlatform platform: GroupPlatform
} }
@@ -742,6 +747,7 @@ const groups = ref<Group[]>([])
const loading = ref(false) const loading = ref(false)
const submitting = ref(false) const submitting = ref(false)
const usageStats = ref<Record<string, BatchApiKeyUsageStats>>({}) const usageStats = ref<Record<string, BatchApiKeyUsageStats>>({})
const userGroupRates = ref<Record<number, number>>({})
const pagination = ref({ const pagination = ref({
page: 1, page: 1,
@@ -825,6 +831,7 @@ const groupOptions = computed(() =>
label: group.name, label: group.name,
description: group.description, description: group.description,
rate: group.rate_multiplier, rate: group.rate_multiplier,
userRate: userGroupRates.value[group.id] ?? null,
subscriptionType: group.subscription_type, subscriptionType: group.subscription_type,
platform: group.platform platform: group.platform
})) }))
@@ -899,6 +906,14 @@ const loadGroups = async () => {
} }
} }
const loadUserGroupRates = async () => {
try {
userGroupRates.value = await userGroupsAPI.getUserGroupRates()
} catch (error) {
console.error('Failed to load user group rates:', error)
}
}
const loadPublicSettings = async () => { const loadPublicSettings = async () => {
try { try {
publicSettings.value = await authAPI.getPublicSettings() publicSettings.value = await authAPI.getPublicSettings()
@@ -1268,6 +1283,7 @@ const closeCcsClientSelect = () => {
onMounted(() => { onMounted(() => {
loadApiKeys() loadApiKeys()
loadGroups() loadGroups()
loadUserGroupRates()
loadPublicSettings() loadPublicSettings()
document.addEventListener('click', closeGroupSelector) document.addEventListener('click', closeGroupSelector)
}) })