mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-03 06:52:13 +08:00
Merge up/main
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -14,6 +14,9 @@ backend/server
|
||||
backend/sub2api
|
||||
backend/main
|
||||
|
||||
# Go 测试二进制
|
||||
*.test
|
||||
|
||||
# 测试覆盖率
|
||||
*.out
|
||||
coverage.html
|
||||
|
||||
368
Linux DO Connect.md
Normal file
368
Linux DO Connect.md
Normal file
@@ -0,0 +1,368 @@
|
||||
# Linux DO Connect
|
||||
|
||||
OAuth(Open Authorization)是一个开放的网络授权标准,目前最新版本为 OAuth 2.0。我们日常使用的第三方登录(如 Google 账号登录)就采用了该标准。OAuth 允许用户授权第三方应用访问存储在其他服务提供商(如 Google)上的信息,无需在不同平台上重复填写注册信息。用户授权后,平台可以直接访问用户的账户信息进行身份验证,而用户无需向第三方应用提供密码。
|
||||
|
||||
目前系统已实现完整的 OAuth2 授权码(code)方式鉴权,但界面等配套功能还在持续完善中。让我们一起打造一个更完善的共享方案。
|
||||
|
||||
## 基本介绍
|
||||
|
||||
这是一套标准的 OAuth2 鉴权系统,可以让开发者共享论坛的用户基本信息。
|
||||
|
||||
- 可获取字段:
|
||||
|
||||
| 参数 | 说明 |
|
||||
| ----------------- | ------------------------------- |
|
||||
| `id` | 用户唯一标识(不可变) |
|
||||
| `username` | 论坛用户名 |
|
||||
| `name` | 论坛用户昵称(可变) |
|
||||
| `avatar_template` | 用户头像模板URL(支持多种尺寸) |
|
||||
| `active` | 账号活跃状态 |
|
||||
| `trust_level` | 信任等级(0-4) |
|
||||
| `silenced` | 禁言状态 |
|
||||
| `external_ids` | 外部ID关联信息 |
|
||||
| `api_key` | API访问密钥 |
|
||||
|
||||
通过这些信息,公益网站/接口可以实现:
|
||||
|
||||
1. 基于 `id` 的服务频率限制
|
||||
2. 基于 `trust_level` 的服务额度分配
|
||||
3. 基于用户信息的滥用举报机制
|
||||
|
||||
## 相关端点
|
||||
|
||||
- Authorize 端点: `https://connect.linux.do/oauth2/authorize`
|
||||
- Token 端点:`https://connect.linux.do/oauth2/token`
|
||||
- 用户信息 端点:`https://connect.linux.do/api/user`
|
||||
|
||||
## 申请使用
|
||||
|
||||
- 访问 [Connect.Linux.Do](https://connect.linux.do/) 申请接入你的应用。
|
||||
|
||||

|
||||
|
||||
- 点击 **`我的应用接入`** - **`申请新接入`**,填写相关信息。其中 **`回调地址`** 是你的应用接收用户信息的地址。
|
||||
|
||||

|
||||
|
||||
- 申请成功后,你将获得 **`Client Id`** 和 **`Client Secret`**,这是你应用的唯一身份凭证。
|
||||
|
||||

|
||||
|
||||
## 接入 Linux Do
|
||||
|
||||
JavaScript
|
||||
```JavaScript
|
||||
// 安装第三方请求库(或使用原生的 Fetch API),本例中使用 axios
|
||||
// npm install axios
|
||||
|
||||
// 通过 OAuth2 获取 Linux Do 用户信息的参考流程
|
||||
const axios = require('axios');
|
||||
const readline = require('readline');
|
||||
|
||||
// 配置信息(建议通过环境变量配置,避免使用硬编码)
|
||||
const CLIENT_ID = '你的 Client ID';
|
||||
const CLIENT_SECRET = '你的 Client Secret';
|
||||
const REDIRECT_URI = '你的回调地址';
|
||||
const AUTH_URL = 'https://connect.linux.do/oauth2/authorize';
|
||||
const TOKEN_URL = 'https://connect.linux.do/oauth2/token';
|
||||
const USER_INFO_URL = 'https://connect.linux.do/api/user';
|
||||
|
||||
// 第一步:生成授权 URL
|
||||
function getAuthUrl() {
|
||||
const params = new URLSearchParams({
|
||||
client_id: CLIENT_ID,
|
||||
redirect_uri: REDIRECT_URI,
|
||||
response_type: 'code',
|
||||
scope: 'user'
|
||||
});
|
||||
|
||||
return `${AUTH_URL}?${params.toString()}`;
|
||||
}
|
||||
|
||||
// 第二步:获取 code 参数
|
||||
function getCode() {
|
||||
return new Promise((resolve) => {
|
||||
// 本例中使用终端输入来模拟流程,仅供本地测试
|
||||
// 请在实际应用中替换为真实的处理逻辑
|
||||
const rl = readline.createInterface({ input: process.stdin, output: process.stdout });
|
||||
rl.question('从回调 URL 中提取出 code,粘贴到此处并按回车:', (answer) => {
|
||||
rl.close();
|
||||
resolve(answer.trim());
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// 第三步:使用 code 参数获取访问令牌
|
||||
async function getAccessToken(code) {
|
||||
try {
|
||||
const form = new URLSearchParams({
|
||||
client_id: CLIENT_ID,
|
||||
client_secret: CLIENT_SECRET,
|
||||
code: code,
|
||||
redirect_uri: REDIRECT_URI,
|
||||
grant_type: 'authorization_code'
|
||||
}).toString();
|
||||
|
||||
const response = await axios.post(TOKEN_URL, form, {
|
||||
// 提醒:需正确配置请求头,否则无法正常获取访问令牌
|
||||
headers: {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
'Accept': 'application/json'
|
||||
}
|
||||
});
|
||||
|
||||
return response.data;
|
||||
} catch (error) {
|
||||
console.error(`获取访问令牌失败:${error.response ? JSON.stringify(error.response.data) : error.message}`);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
// 第四步:使用访问令牌获取用户信息
|
||||
async function getUserInfo(accessToken) {
|
||||
try {
|
||||
const response = await axios.get(USER_INFO_URL, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${accessToken}`
|
||||
}
|
||||
});
|
||||
|
||||
return response.data;
|
||||
} catch (error) {
|
||||
console.error(`获取用户信息失败:${error.response ? JSON.stringify(error.response.data) : error.message}`);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
// 主流程
|
||||
async function main() {
|
||||
// 1. 生成授权 URL,前端引导用户访问授权页
|
||||
const authUrl = getAuthUrl();
|
||||
console.log(`请访问此 URL 授权:${authUrl}
|
||||
`);
|
||||
|
||||
// 2. 用户授权后,从回调 URL 获取 code 参数
|
||||
const code = await getCode();
|
||||
|
||||
try {
|
||||
// 3. 使用 code 参数获取访问令牌
|
||||
const tokenData = await getAccessToken(code);
|
||||
const accessToken = tokenData.access_token;
|
||||
|
||||
// 4. 使用访问令牌获取用户信息
|
||||
if (accessToken) {
|
||||
const userInfo = await getUserInfo(accessToken);
|
||||
console.log(`
|
||||
获取用户信息成功:${JSON.stringify(userInfo, null, 2)}`);
|
||||
} else {
|
||||
console.log(`
|
||||
获取访问令牌失败:${JSON.stringify(tokenData)}`);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('发生错误:', error);
|
||||
}
|
||||
}
|
||||
```
|
||||
Python
|
||||
```python
|
||||
# 安装第三方请求库,本例中使用 requests
|
||||
# pip install requests
|
||||
|
||||
# 通过 OAuth2 获取 Linux Do 用户信息的参考流程
|
||||
import requests
|
||||
import json
|
||||
|
||||
# 配置信息(建议通过环境变量配置,避免使用硬编码)
|
||||
CLIENT_ID = '你的 Client ID'
|
||||
CLIENT_SECRET = '你的 Client Secret'
|
||||
REDIRECT_URI = '你的回调地址'
|
||||
AUTH_URL = 'https://connect.linux.do/oauth2/authorize'
|
||||
TOKEN_URL = 'https://connect.linux.do/oauth2/token'
|
||||
USER_INFO_URL = 'https://connect.linux.do/api/user'
|
||||
|
||||
# 第一步:生成授权 URL
|
||||
def get_auth_url():
|
||||
params = {
|
||||
'client_id': CLIENT_ID,
|
||||
'redirect_uri': REDIRECT_URI,
|
||||
'response_type': 'code',
|
||||
'scope': 'user'
|
||||
}
|
||||
auth_url = f"{AUTH_URL}?{'&'.join(f'{k}={v}' for k, v in params.items())}"
|
||||
return auth_url
|
||||
|
||||
# 第二步:获取 code 参数
|
||||
def get_code():
|
||||
# 本例中使用终端输入来模拟流程,仅供本地测试
|
||||
# 请在实际应用中替换为真实的处理逻辑
|
||||
return input('从回调 URL 中提取出 code,粘贴到此处并按回车:').strip()
|
||||
|
||||
# 第三步:使用 code 参数获取访问令牌
|
||||
def get_access_token(code):
|
||||
try:
|
||||
data = {
|
||||
'client_id': CLIENT_ID,
|
||||
'client_secret': CLIENT_SECRET,
|
||||
'code': code,
|
||||
'redirect_uri': REDIRECT_URI,
|
||||
'grant_type': 'authorization_code'
|
||||
}
|
||||
# 提醒:需正确配置请求头,否则无法正常获取访问令牌
|
||||
headers = {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
'Accept': 'application/json'
|
||||
}
|
||||
response = requests.post(TOKEN_URL, data=data, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except requests.exceptions.RequestException as e:
|
||||
print(f"获取访问令牌失败:{e}")
|
||||
return None
|
||||
|
||||
# 第四步:使用访问令牌获取用户信息
|
||||
def get_user_info(access_token):
|
||||
try:
|
||||
headers = {
|
||||
'Authorization': f'Bearer {access_token}'
|
||||
}
|
||||
response = requests.get(USER_INFO_URL, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except requests.exceptions.RequestException as e:
|
||||
print(f"获取用户信息失败:{e}")
|
||||
return None
|
||||
|
||||
# 主流程
|
||||
if __name__ == '__main__':
|
||||
# 1. 生成授权 URL,前端引导用户访问授权页
|
||||
auth_url = get_auth_url()
|
||||
print(f'请访问此 URL 授权:{auth_url}
|
||||
')
|
||||
|
||||
# 2. 用户授权后,从回调 URL 获取 code 参数
|
||||
code = get_code()
|
||||
|
||||
# 3. 使用 code 参数获取访问令牌
|
||||
token_data = get_access_token(code)
|
||||
if token_data:
|
||||
access_token = token_data.get('access_token')
|
||||
|
||||
# 4. 使用访问令牌获取用户信息
|
||||
if access_token:
|
||||
user_info = get_user_info(access_token)
|
||||
if user_info:
|
||||
print(f"
|
||||
获取用户信息成功:{json.dumps(user_info, indent=2)}")
|
||||
else:
|
||||
print("
|
||||
获取用户信息失败")
|
||||
else:
|
||||
print(f"
|
||||
获取访问令牌失败:{json.dumps(token_data, indent=2)}")
|
||||
else:
|
||||
print("
|
||||
获取访问令牌失败")
|
||||
```
|
||||
PHP
|
||||
```php
|
||||
// 通过 OAuth2 获取 Linux Do 用户信息的参考流程
|
||||
|
||||
// 配置信息
|
||||
$CLIENT_ID = '你的 Client ID';
|
||||
$CLIENT_SECRET = '你的 Client Secret';
|
||||
$REDIRECT_URI = '你的回调地址';
|
||||
$AUTH_URL = 'https://connect.linux.do/oauth2/authorize';
|
||||
$TOKEN_URL = 'https://connect.linux.do/oauth2/token';
|
||||
$USER_INFO_URL = 'https://connect.linux.do/api/user';
|
||||
|
||||
// 生成授权 URL
|
||||
function getAuthUrl($clientId, $redirectUri) {
|
||||
global $AUTH_URL;
|
||||
return $AUTH_URL . '?' . http_build_query([
|
||||
'client_id' => $clientId,
|
||||
'redirect_uri' => $redirectUri,
|
||||
'response_type' => 'code',
|
||||
'scope' => 'user'
|
||||
]);
|
||||
}
|
||||
|
||||
// 使用 code 参数获取用户信息(合并获取令牌和获取用户信息的步骤)
|
||||
function getUserInfoWithCode($code, $clientId, $clientSecret, $redirectUri) {
|
||||
global $TOKEN_URL, $USER_INFO_URL;
|
||||
|
||||
// 1. 获取访问令牌
|
||||
$ch = curl_init($TOKEN_URL);
|
||||
curl_setopt($ch, CURLOPT_RETURNTRANSFER, true);
|
||||
curl_setopt($ch, CURLOPT_POST, true);
|
||||
curl_setopt($ch, CURLOPT_POSTFIELDS, http_build_query([
|
||||
'client_id' => $clientId,
|
||||
'client_secret' => $clientSecret,
|
||||
'code' => $code,
|
||||
'redirect_uri' => $redirectUri,
|
||||
'grant_type' => 'authorization_code'
|
||||
]));
|
||||
curl_setopt($ch, CURLOPT_HTTPHEADER, [
|
||||
'Content-Type: application/x-www-form-urlencoded',
|
||||
'Accept: application/json'
|
||||
]);
|
||||
|
||||
$tokenResponse = curl_exec($ch);
|
||||
curl_close($ch);
|
||||
|
||||
$tokenData = json_decode($tokenResponse, true);
|
||||
if (!isset($tokenData['access_token'])) {
|
||||
return ['error' => '获取访问令牌失败', 'details' => $tokenData];
|
||||
}
|
||||
|
||||
// 2. 获取用户信息
|
||||
$ch = curl_init($USER_INFO_URL);
|
||||
curl_setopt($ch, CURLOPT_RETURNTRANSFER, true);
|
||||
curl_setopt($ch, CURLOPT_HTTPHEADER, [
|
||||
'Authorization: Bearer ' . $tokenData['access_token']
|
||||
]);
|
||||
|
||||
$userResponse = curl_exec($ch);
|
||||
curl_close($ch);
|
||||
|
||||
return json_decode($userResponse, true);
|
||||
}
|
||||
|
||||
// 主流程
|
||||
// 1. 生成授权 URL
|
||||
$authUrl = getAuthUrl($CLIENT_ID, $REDIRECT_URI);
|
||||
echo "<a href='$authUrl'>使用 Linux Do 登录</a>";
|
||||
|
||||
// 2. 处理回调并获取用户信息
|
||||
if (isset($_GET['code'])) {
|
||||
$userInfo = getUserInfoWithCode(
|
||||
$_GET['code'],
|
||||
$CLIENT_ID,
|
||||
$CLIENT_SECRET,
|
||||
$REDIRECT_URI
|
||||
);
|
||||
|
||||
if (isset($userInfo['error'])) {
|
||||
echo '错误: ' . $userInfo['error'];
|
||||
} else {
|
||||
echo '欢迎, ' . $userInfo['name'] . '!';
|
||||
// 处理用户登录逻辑...
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 使用说明
|
||||
|
||||
### 授权流程
|
||||
|
||||
1. 用户点击应用中的’使用 Linux Do 登录’按钮
|
||||
2. 系统将用户重定向至 Linux Do 的授权页面
|
||||
3. 用户完成授权后,系统自动重定向回应用并携带授权码
|
||||
4. 应用使用授权码获取访问令牌
|
||||
5. 使用访问令牌获取用户信息
|
||||
|
||||
### 安全建议
|
||||
|
||||
- 切勿在前端代码中暴露 Client Secret
|
||||
- 对所有用户输入数据进行严格验证
|
||||
- 确保使用 HTTPS 协议传输数据
|
||||
- 定期更新并妥善保管 Client Secret
|
||||
@@ -33,7 +33,7 @@ func main() {
|
||||
}()
|
||||
|
||||
userRepo := repository.NewUserRepository(client, sqlDB)
|
||||
authService := service.NewAuthService(userRepo, cfg, nil, nil, nil, nil)
|
||||
authService := service.NewAuthService(userRepo, cfg, nil, nil, nil, nil, nil)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -1 +1 @@
|
||||
0.1.1
|
||||
0.1.46
|
||||
|
||||
@@ -51,13 +51,17 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
turnstileVerifier := repository.NewTurnstileVerifier()
|
||||
turnstileService := service.NewTurnstileService(settingService, turnstileVerifier)
|
||||
emailQueueService := service.ProvideEmailQueueService(emailService)
|
||||
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
|
||||
promoCodeRepository := repository.NewPromoCodeRepository(client)
|
||||
billingCache := repository.NewBillingCache(redisClient)
|
||||
userSubscriptionRepository := repository.NewUserSubscriptionRepository(client)
|
||||
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig)
|
||||
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client)
|
||||
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
|
||||
userService := service.NewUserService(userRepository)
|
||||
authHandler := handler.NewAuthHandler(configConfig, authService, userService)
|
||||
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService)
|
||||
userHandler := handler.NewUserHandler(userService)
|
||||
apiKeyRepository := repository.NewAPIKeyRepository(client)
|
||||
groupRepository := repository.NewGroupRepository(client, db)
|
||||
userSubscriptionRepository := repository.NewUserSubscriptionRepository(client)
|
||||
apiKeyCache := repository.NewAPIKeyCache(redisClient)
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
@@ -65,8 +69,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
usageService := service.NewUsageService(usageLogRepository, userRepository, client)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
redeemCodeRepository := repository.NewRedeemCodeRepository(client)
|
||||
billingCache := repository.NewBillingCache(redisClient)
|
||||
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig)
|
||||
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
|
||||
redeemCache := repository.NewRedeemCache(redisClient)
|
||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client)
|
||||
@@ -112,6 +114,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
antigravityOAuthHandler := admin.NewAntigravityOAuthHandler(antigravityOAuthService)
|
||||
proxyHandler := admin.NewProxyHandler(adminService)
|
||||
adminRedeemHandler := admin.NewRedeemHandler(adminService)
|
||||
promoHandler := admin.NewPromoHandler(promoService)
|
||||
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService)
|
||||
updateCache := repository.NewUpdateCache(redisClient)
|
||||
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
|
||||
@@ -124,7 +127,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
|
||||
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
|
||||
userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler)
|
||||
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
|
||||
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
|
||||
if err != nil {
|
||||
@@ -145,7 +148,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
|
||||
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
|
||||
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
|
||||
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService)
|
||||
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService, redisClient)
|
||||
httpServer := server.ProvideHTTPServer(configConfig, engine)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig)
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"math"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/schema/field"
|
||||
@@ -31,6 +32,7 @@ type AccountQuery struct {
|
||||
withProxy *ProxyQuery
|
||||
withUsageLogs *UsageLogQuery
|
||||
withAccountGroups *AccountGroupQuery
|
||||
modifiers []func(*sql.Selector)
|
||||
// intermediate query (i.e. traversal path).
|
||||
sql *sql.Selector
|
||||
path func(context.Context) (*sql.Selector, error)
|
||||
@@ -495,6 +497,9 @@ func (_q *AccountQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Acco
|
||||
node.Edges.loadedTypes = loadedTypes
|
||||
return node.assignValues(columns, values)
|
||||
}
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
for i := range hooks {
|
||||
hooks[i](ctx, _spec)
|
||||
}
|
||||
@@ -690,6 +695,9 @@ func (_q *AccountQuery) loadAccountGroups(ctx context.Context, query *AccountGro
|
||||
|
||||
func (_q *AccountQuery) sqlCount(ctx context.Context) (int, error) {
|
||||
_spec := _q.querySpec()
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
_spec.Node.Columns = _q.ctx.Fields
|
||||
if len(_q.ctx.Fields) > 0 {
|
||||
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
|
||||
@@ -755,6 +763,9 @@ func (_q *AccountQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
if _q.ctx.Unique != nil && *_q.ctx.Unique {
|
||||
selector.Distinct()
|
||||
}
|
||||
for _, m := range _q.modifiers {
|
||||
m(selector)
|
||||
}
|
||||
for _, p := range _q.predicates {
|
||||
p(selector)
|
||||
}
|
||||
@@ -772,6 +783,32 @@ func (_q *AccountQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
return selector
|
||||
}
|
||||
|
||||
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
|
||||
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
|
||||
// either committed or rolled-back.
|
||||
func (_q *AccountQuery) ForUpdate(opts ...sql.LockOption) *AccountQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForUpdate(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
|
||||
// on any rows that are read. Other sessions can read the rows, but cannot modify them
|
||||
// until your transaction commits.
|
||||
func (_q *AccountQuery) ForShare(opts ...sql.LockOption) *AccountQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForShare(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// AccountGroupBy is the group-by builder for Account entities.
|
||||
type AccountGroupBy struct {
|
||||
selector
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"math"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"github.com/Wei-Shaw/sub2api/ent/account"
|
||||
@@ -25,6 +26,7 @@ type AccountGroupQuery struct {
|
||||
predicates []predicate.AccountGroup
|
||||
withAccount *AccountQuery
|
||||
withGroup *GroupQuery
|
||||
modifiers []func(*sql.Selector)
|
||||
// intermediate query (i.e. traversal path).
|
||||
sql *sql.Selector
|
||||
path func(context.Context) (*sql.Selector, error)
|
||||
@@ -347,6 +349,9 @@ func (_q *AccountGroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]
|
||||
node.Edges.loadedTypes = loadedTypes
|
||||
return node.assignValues(columns, values)
|
||||
}
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
for i := range hooks {
|
||||
hooks[i](ctx, _spec)
|
||||
}
|
||||
@@ -432,6 +437,9 @@ func (_q *AccountGroupQuery) loadGroup(ctx context.Context, query *GroupQuery, n
|
||||
|
||||
func (_q *AccountGroupQuery) sqlCount(ctx context.Context) (int, error) {
|
||||
_spec := _q.querySpec()
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
_spec.Unique = false
|
||||
_spec.Node.Columns = nil
|
||||
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
|
||||
@@ -495,6 +503,9 @@ func (_q *AccountGroupQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
if _q.ctx.Unique != nil && *_q.ctx.Unique {
|
||||
selector.Distinct()
|
||||
}
|
||||
for _, m := range _q.modifiers {
|
||||
m(selector)
|
||||
}
|
||||
for _, p := range _q.predicates {
|
||||
p(selector)
|
||||
}
|
||||
@@ -512,6 +523,32 @@ func (_q *AccountGroupQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
return selector
|
||||
}
|
||||
|
||||
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
|
||||
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
|
||||
// either committed or rolled-back.
|
||||
func (_q *AccountGroupQuery) ForUpdate(opts ...sql.LockOption) *AccountGroupQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForUpdate(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
|
||||
// on any rows that are read. Other sessions can read the rows, but cannot modify them
|
||||
// until your transaction commits.
|
||||
func (_q *AccountGroupQuery) ForShare(opts ...sql.LockOption) *AccountGroupQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForShare(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// AccountGroupGroupBy is the group-by builder for AccountGroup entities.
|
||||
type AccountGroupGroupBy struct {
|
||||
selector
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package ent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -35,6 +36,10 @@ type APIKey struct {
|
||||
GroupID *int64 `json:"group_id,omitempty"`
|
||||
// Status holds the value of the "status" field.
|
||||
Status string `json:"status,omitempty"`
|
||||
// Allowed IPs/CIDRs, e.g. ["192.168.1.100", "10.0.0.0/8"]
|
||||
IPWhitelist []string `json:"ip_whitelist,omitempty"`
|
||||
// Blocked IPs/CIDRs
|
||||
IPBlacklist []string `json:"ip_blacklist,omitempty"`
|
||||
// Edges holds the relations/edges for other nodes in the graph.
|
||||
// The values are being populated by the APIKeyQuery when eager-loading is set.
|
||||
Edges APIKeyEdges `json:"edges"`
|
||||
@@ -90,6 +95,8 @@ func (*APIKey) scanValues(columns []string) ([]any, error) {
|
||||
values := make([]any, len(columns))
|
||||
for i := range columns {
|
||||
switch columns[i] {
|
||||
case apikey.FieldIPWhitelist, apikey.FieldIPBlacklist:
|
||||
values[i] = new([]byte)
|
||||
case apikey.FieldID, apikey.FieldUserID, apikey.FieldGroupID:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case apikey.FieldKey, apikey.FieldName, apikey.FieldStatus:
|
||||
@@ -167,6 +174,22 @@ func (_m *APIKey) assignValues(columns []string, values []any) error {
|
||||
} else if value.Valid {
|
||||
_m.Status = value.String
|
||||
}
|
||||
case apikey.FieldIPWhitelist:
|
||||
if value, ok := values[i].(*[]byte); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field ip_whitelist", values[i])
|
||||
} else if value != nil && len(*value) > 0 {
|
||||
if err := json.Unmarshal(*value, &_m.IPWhitelist); err != nil {
|
||||
return fmt.Errorf("unmarshal field ip_whitelist: %w", err)
|
||||
}
|
||||
}
|
||||
case apikey.FieldIPBlacklist:
|
||||
if value, ok := values[i].(*[]byte); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field ip_blacklist", values[i])
|
||||
} else if value != nil && len(*value) > 0 {
|
||||
if err := json.Unmarshal(*value, &_m.IPBlacklist); err != nil {
|
||||
return fmt.Errorf("unmarshal field ip_blacklist: %w", err)
|
||||
}
|
||||
}
|
||||
default:
|
||||
_m.selectValues.Set(columns[i], values[i])
|
||||
}
|
||||
@@ -245,6 +268,12 @@ func (_m *APIKey) String() string {
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("status=")
|
||||
builder.WriteString(_m.Status)
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("ip_whitelist=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.IPWhitelist))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("ip_blacklist=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.IPBlacklist))
|
||||
builder.WriteByte(')')
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
@@ -31,6 +31,10 @@ const (
|
||||
FieldGroupID = "group_id"
|
||||
// FieldStatus holds the string denoting the status field in the database.
|
||||
FieldStatus = "status"
|
||||
// FieldIPWhitelist holds the string denoting the ip_whitelist field in the database.
|
||||
FieldIPWhitelist = "ip_whitelist"
|
||||
// FieldIPBlacklist holds the string denoting the ip_blacklist field in the database.
|
||||
FieldIPBlacklist = "ip_blacklist"
|
||||
// EdgeUser holds the string denoting the user edge name in mutations.
|
||||
EdgeUser = "user"
|
||||
// EdgeGroup holds the string denoting the group edge name in mutations.
|
||||
@@ -73,6 +77,8 @@ var Columns = []string{
|
||||
FieldName,
|
||||
FieldGroupID,
|
||||
FieldStatus,
|
||||
FieldIPWhitelist,
|
||||
FieldIPBlacklist,
|
||||
}
|
||||
|
||||
// ValidColumn reports if the column name is valid (part of the table columns).
|
||||
|
||||
@@ -470,6 +470,26 @@ func StatusContainsFold(v string) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldContainsFold(FieldStatus, v))
|
||||
}
|
||||
|
||||
// IPWhitelistIsNil applies the IsNil predicate on the "ip_whitelist" field.
|
||||
func IPWhitelistIsNil() predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldIsNull(FieldIPWhitelist))
|
||||
}
|
||||
|
||||
// IPWhitelistNotNil applies the NotNil predicate on the "ip_whitelist" field.
|
||||
func IPWhitelistNotNil() predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldNotNull(FieldIPWhitelist))
|
||||
}
|
||||
|
||||
// IPBlacklistIsNil applies the IsNil predicate on the "ip_blacklist" field.
|
||||
func IPBlacklistIsNil() predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldIsNull(FieldIPBlacklist))
|
||||
}
|
||||
|
||||
// IPBlacklistNotNil applies the NotNil predicate on the "ip_blacklist" field.
|
||||
func IPBlacklistNotNil() predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldNotNull(FieldIPBlacklist))
|
||||
}
|
||||
|
||||
// HasUser applies the HasEdge predicate on the "user" edge.
|
||||
func HasUser() predicate.APIKey {
|
||||
return predicate.APIKey(func(s *sql.Selector) {
|
||||
|
||||
@@ -113,6 +113,18 @@ func (_c *APIKeyCreate) SetNillableStatus(v *string) *APIKeyCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetIPWhitelist sets the "ip_whitelist" field.
|
||||
func (_c *APIKeyCreate) SetIPWhitelist(v []string) *APIKeyCreate {
|
||||
_c.mutation.SetIPWhitelist(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetIPBlacklist sets the "ip_blacklist" field.
|
||||
func (_c *APIKeyCreate) SetIPBlacklist(v []string) *APIKeyCreate {
|
||||
_c.mutation.SetIPBlacklist(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetUser sets the "user" edge to the User entity.
|
||||
func (_c *APIKeyCreate) SetUser(v *User) *APIKeyCreate {
|
||||
return _c.SetUserID(v.ID)
|
||||
@@ -285,6 +297,14 @@ func (_c *APIKeyCreate) createSpec() (*APIKey, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(apikey.FieldStatus, field.TypeString, value)
|
||||
_node.Status = value
|
||||
}
|
||||
if value, ok := _c.mutation.IPWhitelist(); ok {
|
||||
_spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value)
|
||||
_node.IPWhitelist = value
|
||||
}
|
||||
if value, ok := _c.mutation.IPBlacklist(); ok {
|
||||
_spec.SetField(apikey.FieldIPBlacklist, field.TypeJSON, value)
|
||||
_node.IPBlacklist = value
|
||||
}
|
||||
if nodes := _c.mutation.UserIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2O,
|
||||
@@ -483,6 +503,42 @@ func (u *APIKeyUpsert) UpdateStatus() *APIKeyUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetIPWhitelist sets the "ip_whitelist" field.
|
||||
func (u *APIKeyUpsert) SetIPWhitelist(v []string) *APIKeyUpsert {
|
||||
u.Set(apikey.FieldIPWhitelist, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateIPWhitelist sets the "ip_whitelist" field to the value that was provided on create.
|
||||
func (u *APIKeyUpsert) UpdateIPWhitelist() *APIKeyUpsert {
|
||||
u.SetExcluded(apikey.FieldIPWhitelist)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearIPWhitelist clears the value of the "ip_whitelist" field.
|
||||
func (u *APIKeyUpsert) ClearIPWhitelist() *APIKeyUpsert {
|
||||
u.SetNull(apikey.FieldIPWhitelist)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetIPBlacklist sets the "ip_blacklist" field.
|
||||
func (u *APIKeyUpsert) SetIPBlacklist(v []string) *APIKeyUpsert {
|
||||
u.Set(apikey.FieldIPBlacklist, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateIPBlacklist sets the "ip_blacklist" field to the value that was provided on create.
|
||||
func (u *APIKeyUpsert) UpdateIPBlacklist() *APIKeyUpsert {
|
||||
u.SetExcluded(apikey.FieldIPBlacklist)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearIPBlacklist clears the value of the "ip_blacklist" field.
|
||||
func (u *APIKeyUpsert) ClearIPBlacklist() *APIKeyUpsert {
|
||||
u.SetNull(apikey.FieldIPBlacklist)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||
// Using this option is equivalent to using:
|
||||
//
|
||||
@@ -640,6 +696,48 @@ func (u *APIKeyUpsertOne) UpdateStatus() *APIKeyUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetIPWhitelist sets the "ip_whitelist" field.
|
||||
func (u *APIKeyUpsertOne) SetIPWhitelist(v []string) *APIKeyUpsertOne {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.SetIPWhitelist(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateIPWhitelist sets the "ip_whitelist" field to the value that was provided on create.
|
||||
func (u *APIKeyUpsertOne) UpdateIPWhitelist() *APIKeyUpsertOne {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.UpdateIPWhitelist()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearIPWhitelist clears the value of the "ip_whitelist" field.
|
||||
func (u *APIKeyUpsertOne) ClearIPWhitelist() *APIKeyUpsertOne {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.ClearIPWhitelist()
|
||||
})
|
||||
}
|
||||
|
||||
// SetIPBlacklist sets the "ip_blacklist" field.
|
||||
func (u *APIKeyUpsertOne) SetIPBlacklist(v []string) *APIKeyUpsertOne {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.SetIPBlacklist(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateIPBlacklist sets the "ip_blacklist" field to the value that was provided on create.
|
||||
func (u *APIKeyUpsertOne) UpdateIPBlacklist() *APIKeyUpsertOne {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.UpdateIPBlacklist()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearIPBlacklist clears the value of the "ip_blacklist" field.
|
||||
func (u *APIKeyUpsertOne) ClearIPBlacklist() *APIKeyUpsertOne {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.ClearIPBlacklist()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *APIKeyUpsertOne) Exec(ctx context.Context) error {
|
||||
if len(u.create.conflict) == 0 {
|
||||
@@ -963,6 +1061,48 @@ func (u *APIKeyUpsertBulk) UpdateStatus() *APIKeyUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetIPWhitelist sets the "ip_whitelist" field.
|
||||
func (u *APIKeyUpsertBulk) SetIPWhitelist(v []string) *APIKeyUpsertBulk {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.SetIPWhitelist(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateIPWhitelist sets the "ip_whitelist" field to the value that was provided on create.
|
||||
func (u *APIKeyUpsertBulk) UpdateIPWhitelist() *APIKeyUpsertBulk {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.UpdateIPWhitelist()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearIPWhitelist clears the value of the "ip_whitelist" field.
|
||||
func (u *APIKeyUpsertBulk) ClearIPWhitelist() *APIKeyUpsertBulk {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.ClearIPWhitelist()
|
||||
})
|
||||
}
|
||||
|
||||
// SetIPBlacklist sets the "ip_blacklist" field.
|
||||
func (u *APIKeyUpsertBulk) SetIPBlacklist(v []string) *APIKeyUpsertBulk {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.SetIPBlacklist(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateIPBlacklist sets the "ip_blacklist" field to the value that was provided on create.
|
||||
func (u *APIKeyUpsertBulk) UpdateIPBlacklist() *APIKeyUpsertBulk {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.UpdateIPBlacklist()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearIPBlacklist clears the value of the "ip_blacklist" field.
|
||||
func (u *APIKeyUpsertBulk) ClearIPBlacklist() *APIKeyUpsertBulk {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.ClearIPBlacklist()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *APIKeyUpsertBulk) Exec(ctx context.Context) error {
|
||||
if u.create.err != nil {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"math"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/schema/field"
|
||||
@@ -29,6 +30,7 @@ type APIKeyQuery struct {
|
||||
withUser *UserQuery
|
||||
withGroup *GroupQuery
|
||||
withUsageLogs *UsageLogQuery
|
||||
modifiers []func(*sql.Selector)
|
||||
// intermediate query (i.e. traversal path).
|
||||
sql *sql.Selector
|
||||
path func(context.Context) (*sql.Selector, error)
|
||||
@@ -458,6 +460,9 @@ func (_q *APIKeyQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*APIKe
|
||||
node.Edges.loadedTypes = loadedTypes
|
||||
return node.assignValues(columns, values)
|
||||
}
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
for i := range hooks {
|
||||
hooks[i](ctx, _spec)
|
||||
}
|
||||
@@ -583,6 +588,9 @@ func (_q *APIKeyQuery) loadUsageLogs(ctx context.Context, query *UsageLogQuery,
|
||||
|
||||
func (_q *APIKeyQuery) sqlCount(ctx context.Context) (int, error) {
|
||||
_spec := _q.querySpec()
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
_spec.Node.Columns = _q.ctx.Fields
|
||||
if len(_q.ctx.Fields) > 0 {
|
||||
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
|
||||
@@ -651,6 +659,9 @@ func (_q *APIKeyQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
if _q.ctx.Unique != nil && *_q.ctx.Unique {
|
||||
selector.Distinct()
|
||||
}
|
||||
for _, m := range _q.modifiers {
|
||||
m(selector)
|
||||
}
|
||||
for _, p := range _q.predicates {
|
||||
p(selector)
|
||||
}
|
||||
@@ -668,6 +679,32 @@ func (_q *APIKeyQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
return selector
|
||||
}
|
||||
|
||||
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
|
||||
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
|
||||
// either committed or rolled-back.
|
||||
func (_q *APIKeyQuery) ForUpdate(opts ...sql.LockOption) *APIKeyQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForUpdate(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
|
||||
// on any rows that are read. Other sessions can read the rows, but cannot modify them
|
||||
// until your transaction commits.
|
||||
func (_q *APIKeyQuery) ForShare(opts ...sql.LockOption) *APIKeyQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForShare(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// APIKeyGroupBy is the group-by builder for APIKey entities.
|
||||
type APIKeyGroupBy struct {
|
||||
selector
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/dialect/sql/sqljson"
|
||||
"entgo.io/ent/schema/field"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||
@@ -133,6 +134,42 @@ func (_u *APIKeyUpdate) SetNillableStatus(v *string) *APIKeyUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetIPWhitelist sets the "ip_whitelist" field.
|
||||
func (_u *APIKeyUpdate) SetIPWhitelist(v []string) *APIKeyUpdate {
|
||||
_u.mutation.SetIPWhitelist(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AppendIPWhitelist appends value to the "ip_whitelist" field.
|
||||
func (_u *APIKeyUpdate) AppendIPWhitelist(v []string) *APIKeyUpdate {
|
||||
_u.mutation.AppendIPWhitelist(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearIPWhitelist clears the value of the "ip_whitelist" field.
|
||||
func (_u *APIKeyUpdate) ClearIPWhitelist() *APIKeyUpdate {
|
||||
_u.mutation.ClearIPWhitelist()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetIPBlacklist sets the "ip_blacklist" field.
|
||||
func (_u *APIKeyUpdate) SetIPBlacklist(v []string) *APIKeyUpdate {
|
||||
_u.mutation.SetIPBlacklist(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AppendIPBlacklist appends value to the "ip_blacklist" field.
|
||||
func (_u *APIKeyUpdate) AppendIPBlacklist(v []string) *APIKeyUpdate {
|
||||
_u.mutation.AppendIPBlacklist(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearIPBlacklist clears the value of the "ip_blacklist" field.
|
||||
func (_u *APIKeyUpdate) ClearIPBlacklist() *APIKeyUpdate {
|
||||
_u.mutation.ClearIPBlacklist()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUser sets the "user" edge to the User entity.
|
||||
func (_u *APIKeyUpdate) SetUser(v *User) *APIKeyUpdate {
|
||||
return _u.SetUserID(v.ID)
|
||||
@@ -291,6 +328,28 @@ func (_u *APIKeyUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if value, ok := _u.mutation.Status(); ok {
|
||||
_spec.SetField(apikey.FieldStatus, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.IPWhitelist(); ok {
|
||||
_spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AppendedIPWhitelist(); ok {
|
||||
_spec.AddModifier(func(u *sql.UpdateBuilder) {
|
||||
sqljson.Append(u, apikey.FieldIPWhitelist, value)
|
||||
})
|
||||
}
|
||||
if _u.mutation.IPWhitelistCleared() {
|
||||
_spec.ClearField(apikey.FieldIPWhitelist, field.TypeJSON)
|
||||
}
|
||||
if value, ok := _u.mutation.IPBlacklist(); ok {
|
||||
_spec.SetField(apikey.FieldIPBlacklist, field.TypeJSON, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AppendedIPBlacklist(); ok {
|
||||
_spec.AddModifier(func(u *sql.UpdateBuilder) {
|
||||
sqljson.Append(u, apikey.FieldIPBlacklist, value)
|
||||
})
|
||||
}
|
||||
if _u.mutation.IPBlacklistCleared() {
|
||||
_spec.ClearField(apikey.FieldIPBlacklist, field.TypeJSON)
|
||||
}
|
||||
if _u.mutation.UserCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2O,
|
||||
@@ -516,6 +575,42 @@ func (_u *APIKeyUpdateOne) SetNillableStatus(v *string) *APIKeyUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetIPWhitelist sets the "ip_whitelist" field.
|
||||
func (_u *APIKeyUpdateOne) SetIPWhitelist(v []string) *APIKeyUpdateOne {
|
||||
_u.mutation.SetIPWhitelist(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AppendIPWhitelist appends value to the "ip_whitelist" field.
|
||||
func (_u *APIKeyUpdateOne) AppendIPWhitelist(v []string) *APIKeyUpdateOne {
|
||||
_u.mutation.AppendIPWhitelist(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearIPWhitelist clears the value of the "ip_whitelist" field.
|
||||
func (_u *APIKeyUpdateOne) ClearIPWhitelist() *APIKeyUpdateOne {
|
||||
_u.mutation.ClearIPWhitelist()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetIPBlacklist sets the "ip_blacklist" field.
|
||||
func (_u *APIKeyUpdateOne) SetIPBlacklist(v []string) *APIKeyUpdateOne {
|
||||
_u.mutation.SetIPBlacklist(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AppendIPBlacklist appends value to the "ip_blacklist" field.
|
||||
func (_u *APIKeyUpdateOne) AppendIPBlacklist(v []string) *APIKeyUpdateOne {
|
||||
_u.mutation.AppendIPBlacklist(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearIPBlacklist clears the value of the "ip_blacklist" field.
|
||||
func (_u *APIKeyUpdateOne) ClearIPBlacklist() *APIKeyUpdateOne {
|
||||
_u.mutation.ClearIPBlacklist()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUser sets the "user" edge to the User entity.
|
||||
func (_u *APIKeyUpdateOne) SetUser(v *User) *APIKeyUpdateOne {
|
||||
return _u.SetUserID(v.ID)
|
||||
@@ -704,6 +799,28 @@ func (_u *APIKeyUpdateOne) sqlSave(ctx context.Context) (_node *APIKey, err erro
|
||||
if value, ok := _u.mutation.Status(); ok {
|
||||
_spec.SetField(apikey.FieldStatus, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.IPWhitelist(); ok {
|
||||
_spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AppendedIPWhitelist(); ok {
|
||||
_spec.AddModifier(func(u *sql.UpdateBuilder) {
|
||||
sqljson.Append(u, apikey.FieldIPWhitelist, value)
|
||||
})
|
||||
}
|
||||
if _u.mutation.IPWhitelistCleared() {
|
||||
_spec.ClearField(apikey.FieldIPWhitelist, field.TypeJSON)
|
||||
}
|
||||
if value, ok := _u.mutation.IPBlacklist(); ok {
|
||||
_spec.SetField(apikey.FieldIPBlacklist, field.TypeJSON, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AppendedIPBlacklist(); ok {
|
||||
_spec.AddModifier(func(u *sql.UpdateBuilder) {
|
||||
sqljson.Append(u, apikey.FieldIPBlacklist, value)
|
||||
})
|
||||
}
|
||||
if _u.mutation.IPBlacklistCleared() {
|
||||
_spec.ClearField(apikey.FieldIPBlacklist, field.TypeJSON)
|
||||
}
|
||||
if _u.mutation.UserCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2O,
|
||||
|
||||
@@ -19,6 +19,8 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/accountgroup"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
|
||||
"github.com/Wei-Shaw/sub2api/ent/proxy"
|
||||
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/setting"
|
||||
@@ -45,6 +47,10 @@ type Client struct {
|
||||
AccountGroup *AccountGroupClient
|
||||
// Group is the client for interacting with the Group builders.
|
||||
Group *GroupClient
|
||||
// PromoCode is the client for interacting with the PromoCode builders.
|
||||
PromoCode *PromoCodeClient
|
||||
// PromoCodeUsage is the client for interacting with the PromoCodeUsage builders.
|
||||
PromoCodeUsage *PromoCodeUsageClient
|
||||
// Proxy is the client for interacting with the Proxy builders.
|
||||
Proxy *ProxyClient
|
||||
// RedeemCode is the client for interacting with the RedeemCode builders.
|
||||
@@ -78,6 +84,8 @@ func (c *Client) init() {
|
||||
c.Account = NewAccountClient(c.config)
|
||||
c.AccountGroup = NewAccountGroupClient(c.config)
|
||||
c.Group = NewGroupClient(c.config)
|
||||
c.PromoCode = NewPromoCodeClient(c.config)
|
||||
c.PromoCodeUsage = NewPromoCodeUsageClient(c.config)
|
||||
c.Proxy = NewProxyClient(c.config)
|
||||
c.RedeemCode = NewRedeemCodeClient(c.config)
|
||||
c.Setting = NewSettingClient(c.config)
|
||||
@@ -183,6 +191,8 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
|
||||
Account: NewAccountClient(cfg),
|
||||
AccountGroup: NewAccountGroupClient(cfg),
|
||||
Group: NewGroupClient(cfg),
|
||||
PromoCode: NewPromoCodeClient(cfg),
|
||||
PromoCodeUsage: NewPromoCodeUsageClient(cfg),
|
||||
Proxy: NewProxyClient(cfg),
|
||||
RedeemCode: NewRedeemCodeClient(cfg),
|
||||
Setting: NewSettingClient(cfg),
|
||||
@@ -215,6 +225,8 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
|
||||
Account: NewAccountClient(cfg),
|
||||
AccountGroup: NewAccountGroupClient(cfg),
|
||||
Group: NewGroupClient(cfg),
|
||||
PromoCode: NewPromoCodeClient(cfg),
|
||||
PromoCodeUsage: NewPromoCodeUsageClient(cfg),
|
||||
Proxy: NewProxyClient(cfg),
|
||||
RedeemCode: NewRedeemCodeClient(cfg),
|
||||
Setting: NewSettingClient(cfg),
|
||||
@@ -253,9 +265,9 @@ func (c *Client) Close() error {
|
||||
// In order to add hooks to a specific client, call: `client.Node.Use(...)`.
|
||||
func (c *Client) Use(hooks ...Hook) {
|
||||
for _, n := range []interface{ Use(...Hook) }{
|
||||
c.APIKey, c.Account, c.AccountGroup, c.Group, c.Proxy, c.RedeemCode, c.Setting,
|
||||
c.UsageLog, c.User, c.UserAllowedGroup, c.UserAttributeDefinition,
|
||||
c.UserAttributeValue, c.UserSubscription,
|
||||
c.APIKey, c.Account, c.AccountGroup, c.Group, c.PromoCode, c.PromoCodeUsage,
|
||||
c.Proxy, c.RedeemCode, c.Setting, c.UsageLog, c.User, c.UserAllowedGroup,
|
||||
c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription,
|
||||
} {
|
||||
n.Use(hooks...)
|
||||
}
|
||||
@@ -265,9 +277,9 @@ func (c *Client) Use(hooks ...Hook) {
|
||||
// In order to add interceptors to a specific client, call: `client.Node.Intercept(...)`.
|
||||
func (c *Client) Intercept(interceptors ...Interceptor) {
|
||||
for _, n := range []interface{ Intercept(...Interceptor) }{
|
||||
c.APIKey, c.Account, c.AccountGroup, c.Group, c.Proxy, c.RedeemCode, c.Setting,
|
||||
c.UsageLog, c.User, c.UserAllowedGroup, c.UserAttributeDefinition,
|
||||
c.UserAttributeValue, c.UserSubscription,
|
||||
c.APIKey, c.Account, c.AccountGroup, c.Group, c.PromoCode, c.PromoCodeUsage,
|
||||
c.Proxy, c.RedeemCode, c.Setting, c.UsageLog, c.User, c.UserAllowedGroup,
|
||||
c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription,
|
||||
} {
|
||||
n.Intercept(interceptors...)
|
||||
}
|
||||
@@ -284,6 +296,10 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
|
||||
return c.AccountGroup.mutate(ctx, m)
|
||||
case *GroupMutation:
|
||||
return c.Group.mutate(ctx, m)
|
||||
case *PromoCodeMutation:
|
||||
return c.PromoCode.mutate(ctx, m)
|
||||
case *PromoCodeUsageMutation:
|
||||
return c.PromoCodeUsage.mutate(ctx, m)
|
||||
case *ProxyMutation:
|
||||
return c.Proxy.mutate(ctx, m)
|
||||
case *RedeemCodeMutation:
|
||||
@@ -1068,6 +1084,320 @@ func (c *GroupClient) mutate(ctx context.Context, m *GroupMutation) (Value, erro
|
||||
}
|
||||
}
|
||||
|
||||
// PromoCodeClient is a client for the PromoCode schema.
|
||||
type PromoCodeClient struct {
|
||||
config
|
||||
}
|
||||
|
||||
// NewPromoCodeClient returns a client for the PromoCode from the given config.
|
||||
func NewPromoCodeClient(c config) *PromoCodeClient {
|
||||
return &PromoCodeClient{config: c}
|
||||
}
|
||||
|
||||
// Use adds a list of mutation hooks to the hooks stack.
|
||||
// A call to `Use(f, g, h)` equals to `promocode.Hooks(f(g(h())))`.
|
||||
func (c *PromoCodeClient) Use(hooks ...Hook) {
|
||||
c.hooks.PromoCode = append(c.hooks.PromoCode, hooks...)
|
||||
}
|
||||
|
||||
// Intercept adds a list of query interceptors to the interceptors stack.
|
||||
// A call to `Intercept(f, g, h)` equals to `promocode.Intercept(f(g(h())))`.
|
||||
func (c *PromoCodeClient) Intercept(interceptors ...Interceptor) {
|
||||
c.inters.PromoCode = append(c.inters.PromoCode, interceptors...)
|
||||
}
|
||||
|
||||
// Create returns a builder for creating a PromoCode entity.
|
||||
func (c *PromoCodeClient) Create() *PromoCodeCreate {
|
||||
mutation := newPromoCodeMutation(c.config, OpCreate)
|
||||
return &PromoCodeCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// CreateBulk returns a builder for creating a bulk of PromoCode entities.
|
||||
func (c *PromoCodeClient) CreateBulk(builders ...*PromoCodeCreate) *PromoCodeCreateBulk {
|
||||
return &PromoCodeCreateBulk{config: c.config, builders: builders}
|
||||
}
|
||||
|
||||
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
|
||||
// a builder and applies setFunc on it.
|
||||
func (c *PromoCodeClient) MapCreateBulk(slice any, setFunc func(*PromoCodeCreate, int)) *PromoCodeCreateBulk {
|
||||
rv := reflect.ValueOf(slice)
|
||||
if rv.Kind() != reflect.Slice {
|
||||
return &PromoCodeCreateBulk{err: fmt.Errorf("calling to PromoCodeClient.MapCreateBulk with wrong type %T, need slice", slice)}
|
||||
}
|
||||
builders := make([]*PromoCodeCreate, rv.Len())
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
builders[i] = c.Create()
|
||||
setFunc(builders[i], i)
|
||||
}
|
||||
return &PromoCodeCreateBulk{config: c.config, builders: builders}
|
||||
}
|
||||
|
||||
// Update returns an update builder for PromoCode.
|
||||
func (c *PromoCodeClient) Update() *PromoCodeUpdate {
|
||||
mutation := newPromoCodeMutation(c.config, OpUpdate)
|
||||
return &PromoCodeUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// UpdateOne returns an update builder for the given entity.
|
||||
func (c *PromoCodeClient) UpdateOne(_m *PromoCode) *PromoCodeUpdateOne {
|
||||
mutation := newPromoCodeMutation(c.config, OpUpdateOne, withPromoCode(_m))
|
||||
return &PromoCodeUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// UpdateOneID returns an update builder for the given id.
|
||||
func (c *PromoCodeClient) UpdateOneID(id int64) *PromoCodeUpdateOne {
|
||||
mutation := newPromoCodeMutation(c.config, OpUpdateOne, withPromoCodeID(id))
|
||||
return &PromoCodeUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// Delete returns a delete builder for PromoCode.
|
||||
func (c *PromoCodeClient) Delete() *PromoCodeDelete {
|
||||
mutation := newPromoCodeMutation(c.config, OpDelete)
|
||||
return &PromoCodeDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// DeleteOne returns a builder for deleting the given entity.
|
||||
func (c *PromoCodeClient) DeleteOne(_m *PromoCode) *PromoCodeDeleteOne {
|
||||
return c.DeleteOneID(_m.ID)
|
||||
}
|
||||
|
||||
// DeleteOneID returns a builder for deleting the given entity by its id.
|
||||
func (c *PromoCodeClient) DeleteOneID(id int64) *PromoCodeDeleteOne {
|
||||
builder := c.Delete().Where(promocode.ID(id))
|
||||
builder.mutation.id = &id
|
||||
builder.mutation.op = OpDeleteOne
|
||||
return &PromoCodeDeleteOne{builder}
|
||||
}
|
||||
|
||||
// Query returns a query builder for PromoCode.
|
||||
func (c *PromoCodeClient) Query() *PromoCodeQuery {
|
||||
return &PromoCodeQuery{
|
||||
config: c.config,
|
||||
ctx: &QueryContext{Type: TypePromoCode},
|
||||
inters: c.Interceptors(),
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns a PromoCode entity by its id.
|
||||
func (c *PromoCodeClient) Get(ctx context.Context, id int64) (*PromoCode, error) {
|
||||
return c.Query().Where(promocode.ID(id)).Only(ctx)
|
||||
}
|
||||
|
||||
// GetX is like Get, but panics if an error occurs.
|
||||
func (c *PromoCodeClient) GetX(ctx context.Context, id int64) *PromoCode {
|
||||
obj, err := c.Get(ctx, id)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return obj
|
||||
}
|
||||
|
||||
// QueryUsageRecords queries the usage_records edge of a PromoCode.
|
||||
func (c *PromoCodeClient) QueryUsageRecords(_m *PromoCode) *PromoCodeUsageQuery {
|
||||
query := (&PromoCodeUsageClient{config: c.config}).Query()
|
||||
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
|
||||
id := _m.ID
|
||||
step := sqlgraph.NewStep(
|
||||
sqlgraph.From(promocode.Table, promocode.FieldID, id),
|
||||
sqlgraph.To(promocodeusage.Table, promocodeusage.FieldID),
|
||||
sqlgraph.Edge(sqlgraph.O2M, false, promocode.UsageRecordsTable, promocode.UsageRecordsColumn),
|
||||
)
|
||||
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
|
||||
return fromV, nil
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
// Hooks returns the client hooks.
|
||||
func (c *PromoCodeClient) Hooks() []Hook {
|
||||
return c.hooks.PromoCode
|
||||
}
|
||||
|
||||
// Interceptors returns the client interceptors.
|
||||
func (c *PromoCodeClient) Interceptors() []Interceptor {
|
||||
return c.inters.PromoCode
|
||||
}
|
||||
|
||||
func (c *PromoCodeClient) mutate(ctx context.Context, m *PromoCodeMutation) (Value, error) {
|
||||
switch m.Op() {
|
||||
case OpCreate:
|
||||
return (&PromoCodeCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
|
||||
case OpUpdate:
|
||||
return (&PromoCodeUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
|
||||
case OpUpdateOne:
|
||||
return (&PromoCodeUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
|
||||
case OpDelete, OpDeleteOne:
|
||||
return (&PromoCodeDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
|
||||
default:
|
||||
return nil, fmt.Errorf("ent: unknown PromoCode mutation op: %q", m.Op())
|
||||
}
|
||||
}
|
||||
|
||||
// PromoCodeUsageClient is a client for the PromoCodeUsage schema.
|
||||
type PromoCodeUsageClient struct {
|
||||
config
|
||||
}
|
||||
|
||||
// NewPromoCodeUsageClient returns a client for the PromoCodeUsage from the given config.
|
||||
func NewPromoCodeUsageClient(c config) *PromoCodeUsageClient {
|
||||
return &PromoCodeUsageClient{config: c}
|
||||
}
|
||||
|
||||
// Use adds a list of mutation hooks to the hooks stack.
|
||||
// A call to `Use(f, g, h)` equals to `promocodeusage.Hooks(f(g(h())))`.
|
||||
func (c *PromoCodeUsageClient) Use(hooks ...Hook) {
|
||||
c.hooks.PromoCodeUsage = append(c.hooks.PromoCodeUsage, hooks...)
|
||||
}
|
||||
|
||||
// Intercept adds a list of query interceptors to the interceptors stack.
|
||||
// A call to `Intercept(f, g, h)` equals to `promocodeusage.Intercept(f(g(h())))`.
|
||||
func (c *PromoCodeUsageClient) Intercept(interceptors ...Interceptor) {
|
||||
c.inters.PromoCodeUsage = append(c.inters.PromoCodeUsage, interceptors...)
|
||||
}
|
||||
|
||||
// Create returns a builder for creating a PromoCodeUsage entity.
|
||||
func (c *PromoCodeUsageClient) Create() *PromoCodeUsageCreate {
|
||||
mutation := newPromoCodeUsageMutation(c.config, OpCreate)
|
||||
return &PromoCodeUsageCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// CreateBulk returns a builder for creating a bulk of PromoCodeUsage entities.
|
||||
func (c *PromoCodeUsageClient) CreateBulk(builders ...*PromoCodeUsageCreate) *PromoCodeUsageCreateBulk {
|
||||
return &PromoCodeUsageCreateBulk{config: c.config, builders: builders}
|
||||
}
|
||||
|
||||
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
|
||||
// a builder and applies setFunc on it.
|
||||
func (c *PromoCodeUsageClient) MapCreateBulk(slice any, setFunc func(*PromoCodeUsageCreate, int)) *PromoCodeUsageCreateBulk {
|
||||
rv := reflect.ValueOf(slice)
|
||||
if rv.Kind() != reflect.Slice {
|
||||
return &PromoCodeUsageCreateBulk{err: fmt.Errorf("calling to PromoCodeUsageClient.MapCreateBulk with wrong type %T, need slice", slice)}
|
||||
}
|
||||
builders := make([]*PromoCodeUsageCreate, rv.Len())
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
builders[i] = c.Create()
|
||||
setFunc(builders[i], i)
|
||||
}
|
||||
return &PromoCodeUsageCreateBulk{config: c.config, builders: builders}
|
||||
}
|
||||
|
||||
// Update returns an update builder for PromoCodeUsage.
|
||||
func (c *PromoCodeUsageClient) Update() *PromoCodeUsageUpdate {
|
||||
mutation := newPromoCodeUsageMutation(c.config, OpUpdate)
|
||||
return &PromoCodeUsageUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// UpdateOne returns an update builder for the given entity.
|
||||
func (c *PromoCodeUsageClient) UpdateOne(_m *PromoCodeUsage) *PromoCodeUsageUpdateOne {
|
||||
mutation := newPromoCodeUsageMutation(c.config, OpUpdateOne, withPromoCodeUsage(_m))
|
||||
return &PromoCodeUsageUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// UpdateOneID returns an update builder for the given id.
|
||||
func (c *PromoCodeUsageClient) UpdateOneID(id int64) *PromoCodeUsageUpdateOne {
|
||||
mutation := newPromoCodeUsageMutation(c.config, OpUpdateOne, withPromoCodeUsageID(id))
|
||||
return &PromoCodeUsageUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// Delete returns a delete builder for PromoCodeUsage.
|
||||
func (c *PromoCodeUsageClient) Delete() *PromoCodeUsageDelete {
|
||||
mutation := newPromoCodeUsageMutation(c.config, OpDelete)
|
||||
return &PromoCodeUsageDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// DeleteOne returns a builder for deleting the given entity.
|
||||
func (c *PromoCodeUsageClient) DeleteOne(_m *PromoCodeUsage) *PromoCodeUsageDeleteOne {
|
||||
return c.DeleteOneID(_m.ID)
|
||||
}
|
||||
|
||||
// DeleteOneID returns a builder for deleting the given entity by its id.
|
||||
func (c *PromoCodeUsageClient) DeleteOneID(id int64) *PromoCodeUsageDeleteOne {
|
||||
builder := c.Delete().Where(promocodeusage.ID(id))
|
||||
builder.mutation.id = &id
|
||||
builder.mutation.op = OpDeleteOne
|
||||
return &PromoCodeUsageDeleteOne{builder}
|
||||
}
|
||||
|
||||
// Query returns a query builder for PromoCodeUsage.
|
||||
func (c *PromoCodeUsageClient) Query() *PromoCodeUsageQuery {
|
||||
return &PromoCodeUsageQuery{
|
||||
config: c.config,
|
||||
ctx: &QueryContext{Type: TypePromoCodeUsage},
|
||||
inters: c.Interceptors(),
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns a PromoCodeUsage entity by its id.
|
||||
func (c *PromoCodeUsageClient) Get(ctx context.Context, id int64) (*PromoCodeUsage, error) {
|
||||
return c.Query().Where(promocodeusage.ID(id)).Only(ctx)
|
||||
}
|
||||
|
||||
// GetX is like Get, but panics if an error occurs.
|
||||
func (c *PromoCodeUsageClient) GetX(ctx context.Context, id int64) *PromoCodeUsage {
|
||||
obj, err := c.Get(ctx, id)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return obj
|
||||
}
|
||||
|
||||
// QueryPromoCode queries the promo_code edge of a PromoCodeUsage.
|
||||
func (c *PromoCodeUsageClient) QueryPromoCode(_m *PromoCodeUsage) *PromoCodeQuery {
|
||||
query := (&PromoCodeClient{config: c.config}).Query()
|
||||
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
|
||||
id := _m.ID
|
||||
step := sqlgraph.NewStep(
|
||||
sqlgraph.From(promocodeusage.Table, promocodeusage.FieldID, id),
|
||||
sqlgraph.To(promocode.Table, promocode.FieldID),
|
||||
sqlgraph.Edge(sqlgraph.M2O, true, promocodeusage.PromoCodeTable, promocodeusage.PromoCodeColumn),
|
||||
)
|
||||
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
|
||||
return fromV, nil
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
// QueryUser queries the user edge of a PromoCodeUsage.
|
||||
func (c *PromoCodeUsageClient) QueryUser(_m *PromoCodeUsage) *UserQuery {
|
||||
query := (&UserClient{config: c.config}).Query()
|
||||
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
|
||||
id := _m.ID
|
||||
step := sqlgraph.NewStep(
|
||||
sqlgraph.From(promocodeusage.Table, promocodeusage.FieldID, id),
|
||||
sqlgraph.To(user.Table, user.FieldID),
|
||||
sqlgraph.Edge(sqlgraph.M2O, true, promocodeusage.UserTable, promocodeusage.UserColumn),
|
||||
)
|
||||
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
|
||||
return fromV, nil
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
// Hooks returns the client hooks.
|
||||
func (c *PromoCodeUsageClient) Hooks() []Hook {
|
||||
return c.hooks.PromoCodeUsage
|
||||
}
|
||||
|
||||
// Interceptors returns the client interceptors.
|
||||
func (c *PromoCodeUsageClient) Interceptors() []Interceptor {
|
||||
return c.inters.PromoCodeUsage
|
||||
}
|
||||
|
||||
func (c *PromoCodeUsageClient) mutate(ctx context.Context, m *PromoCodeUsageMutation) (Value, error) {
|
||||
switch m.Op() {
|
||||
case OpCreate:
|
||||
return (&PromoCodeUsageCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
|
||||
case OpUpdate:
|
||||
return (&PromoCodeUsageUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
|
||||
case OpUpdateOne:
|
||||
return (&PromoCodeUsageUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
|
||||
case OpDelete, OpDeleteOne:
|
||||
return (&PromoCodeUsageDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
|
||||
default:
|
||||
return nil, fmt.Errorf("ent: unknown PromoCodeUsage mutation op: %q", m.Op())
|
||||
}
|
||||
}
|
||||
|
||||
// ProxyClient is a client for the Proxy schema.
|
||||
type ProxyClient struct {
|
||||
config
|
||||
@@ -1950,6 +2280,22 @@ func (c *UserClient) QueryAttributeValues(_m *User) *UserAttributeValueQuery {
|
||||
return query
|
||||
}
|
||||
|
||||
// QueryPromoCodeUsages queries the promo_code_usages edge of a User.
|
||||
func (c *UserClient) QueryPromoCodeUsages(_m *User) *PromoCodeUsageQuery {
|
||||
query := (&PromoCodeUsageClient{config: c.config}).Query()
|
||||
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
|
||||
id := _m.ID
|
||||
step := sqlgraph.NewStep(
|
||||
sqlgraph.From(user.Table, user.FieldID, id),
|
||||
sqlgraph.To(promocodeusage.Table, promocodeusage.FieldID),
|
||||
sqlgraph.Edge(sqlgraph.O2M, false, user.PromoCodeUsagesTable, user.PromoCodeUsagesColumn),
|
||||
)
|
||||
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
|
||||
return fromV, nil
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
// QueryUserAllowedGroups queries the user_allowed_groups edge of a User.
|
||||
func (c *UserClient) QueryUserAllowedGroups(_m *User) *UserAllowedGroupQuery {
|
||||
query := (&UserAllowedGroupClient{config: c.config}).Query()
|
||||
@@ -2627,14 +2973,14 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription
|
||||
// hooks and interceptors per client, for fast access.
|
||||
type (
|
||||
hooks struct {
|
||||
APIKey, Account, AccountGroup, Group, Proxy, RedeemCode, Setting, UsageLog,
|
||||
User, UserAllowedGroup, UserAttributeDefinition, UserAttributeValue,
|
||||
UserSubscription []ent.Hook
|
||||
APIKey, Account, AccountGroup, Group, PromoCode, PromoCodeUsage, Proxy,
|
||||
RedeemCode, Setting, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
|
||||
UserAttributeValue, UserSubscription []ent.Hook
|
||||
}
|
||||
inters struct {
|
||||
APIKey, Account, AccountGroup, Group, Proxy, RedeemCode, Setting, UsageLog,
|
||||
User, UserAllowedGroup, UserAttributeDefinition, UserAttributeValue,
|
||||
UserSubscription []ent.Interceptor
|
||||
APIKey, Account, AccountGroup, Group, PromoCode, PromoCodeUsage, Proxy,
|
||||
RedeemCode, Setting, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
|
||||
UserAttributeValue, UserSubscription []ent.Interceptor
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -16,6 +16,8 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/accountgroup"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
|
||||
"github.com/Wei-Shaw/sub2api/ent/proxy"
|
||||
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/setting"
|
||||
@@ -89,6 +91,8 @@ func checkColumn(t, c string) error {
|
||||
account.Table: account.ValidColumn,
|
||||
accountgroup.Table: accountgroup.ValidColumn,
|
||||
group.Table: group.ValidColumn,
|
||||
promocode.Table: promocode.ValidColumn,
|
||||
promocodeusage.Table: promocodeusage.ValidColumn,
|
||||
proxy.Table: proxy.ValidColumn,
|
||||
redeemcode.Table: redeemcode.ValidColumn,
|
||||
setting.Table: setting.ValidColumn,
|
||||
|
||||
@@ -2,4 +2,5 @@
|
||||
package ent
|
||||
|
||||
// 启用 sql/execquery 以生成 ExecContext/QueryContext 的透传接口,便于事务内执行原生 SQL。
|
||||
//go:generate go run -mod=mod entgo.io/ent/cmd/ent generate --feature sql/upsert,intercept,sql/execquery --idtype int64 ./schema
|
||||
// 启用 sql/lock 以支持 FOR UPDATE 行锁。
|
||||
//go:generate go run -mod=mod entgo.io/ent/cmd/ent generate --feature sql/upsert,intercept,sql/execquery,sql/lock --idtype int64 ./schema
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"math"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/schema/field"
|
||||
@@ -39,6 +40,7 @@ type GroupQuery struct {
|
||||
withAllowedUsers *UserQuery
|
||||
withAccountGroups *AccountGroupQuery
|
||||
withUserAllowedGroups *UserAllowedGroupQuery
|
||||
modifiers []func(*sql.Selector)
|
||||
// intermediate query (i.e. traversal path).
|
||||
sql *sql.Selector
|
||||
path func(context.Context) (*sql.Selector, error)
|
||||
@@ -643,6 +645,9 @@ func (_q *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group,
|
||||
node.Edges.loadedTypes = loadedTypes
|
||||
return node.assignValues(columns, values)
|
||||
}
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
for i := range hooks {
|
||||
hooks[i](ctx, _spec)
|
||||
}
|
||||
@@ -1025,6 +1030,9 @@ func (_q *GroupQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllo
|
||||
|
||||
func (_q *GroupQuery) sqlCount(ctx context.Context) (int, error) {
|
||||
_spec := _q.querySpec()
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
_spec.Node.Columns = _q.ctx.Fields
|
||||
if len(_q.ctx.Fields) > 0 {
|
||||
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
|
||||
@@ -1087,6 +1095,9 @@ func (_q *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
if _q.ctx.Unique != nil && *_q.ctx.Unique {
|
||||
selector.Distinct()
|
||||
}
|
||||
for _, m := range _q.modifiers {
|
||||
m(selector)
|
||||
}
|
||||
for _, p := range _q.predicates {
|
||||
p(selector)
|
||||
}
|
||||
@@ -1104,6 +1115,32 @@ func (_q *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
return selector
|
||||
}
|
||||
|
||||
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
|
||||
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
|
||||
// either committed or rolled-back.
|
||||
func (_q *GroupQuery) ForUpdate(opts ...sql.LockOption) *GroupQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForUpdate(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
|
||||
// on any rows that are read. Other sessions can read the rows, but cannot modify them
|
||||
// until your transaction commits.
|
||||
func (_q *GroupQuery) ForShare(opts ...sql.LockOption) *GroupQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForShare(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// GroupGroupBy is the group-by builder for Group entities.
|
||||
type GroupGroupBy struct {
|
||||
selector
|
||||
|
||||
@@ -57,6 +57,30 @@ func (f GroupFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error
|
||||
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.GroupMutation", m)
|
||||
}
|
||||
|
||||
// The PromoCodeFunc type is an adapter to allow the use of ordinary
|
||||
// function as PromoCode mutator.
|
||||
type PromoCodeFunc func(context.Context, *ent.PromoCodeMutation) (ent.Value, error)
|
||||
|
||||
// Mutate calls f(ctx, m).
|
||||
func (f PromoCodeFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
|
||||
if mv, ok := m.(*ent.PromoCodeMutation); ok {
|
||||
return f(ctx, mv)
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.PromoCodeMutation", m)
|
||||
}
|
||||
|
||||
// The PromoCodeUsageFunc type is an adapter to allow the use of ordinary
|
||||
// function as PromoCodeUsage mutator.
|
||||
type PromoCodeUsageFunc func(context.Context, *ent.PromoCodeUsageMutation) (ent.Value, error)
|
||||
|
||||
// Mutate calls f(ctx, m).
|
||||
func (f PromoCodeUsageFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
|
||||
if mv, ok := m.(*ent.PromoCodeUsageMutation); ok {
|
||||
return f(ctx, mv)
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.PromoCodeUsageMutation", m)
|
||||
}
|
||||
|
||||
// The ProxyFunc type is an adapter to allow the use of ordinary
|
||||
// function as Proxy mutator.
|
||||
type ProxyFunc func(context.Context, *ent.ProxyMutation) (ent.Value, error)
|
||||
|
||||
@@ -13,6 +13,8 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
|
||||
"github.com/Wei-Shaw/sub2api/ent/proxy"
|
||||
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/setting"
|
||||
@@ -188,6 +190,60 @@ func (f TraverseGroup) Traverse(ctx context.Context, q ent.Query) error {
|
||||
return fmt.Errorf("unexpected query type %T. expect *ent.GroupQuery", q)
|
||||
}
|
||||
|
||||
// The PromoCodeFunc type is an adapter to allow the use of ordinary function as a Querier.
|
||||
type PromoCodeFunc func(context.Context, *ent.PromoCodeQuery) (ent.Value, error)
|
||||
|
||||
// Query calls f(ctx, q).
|
||||
func (f PromoCodeFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
|
||||
if q, ok := q.(*ent.PromoCodeQuery); ok {
|
||||
return f(ctx, q)
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected query type %T. expect *ent.PromoCodeQuery", q)
|
||||
}
|
||||
|
||||
// The TraversePromoCode type is an adapter to allow the use of ordinary function as Traverser.
|
||||
type TraversePromoCode func(context.Context, *ent.PromoCodeQuery) error
|
||||
|
||||
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
|
||||
func (f TraversePromoCode) Intercept(next ent.Querier) ent.Querier {
|
||||
return next
|
||||
}
|
||||
|
||||
// Traverse calls f(ctx, q).
|
||||
func (f TraversePromoCode) Traverse(ctx context.Context, q ent.Query) error {
|
||||
if q, ok := q.(*ent.PromoCodeQuery); ok {
|
||||
return f(ctx, q)
|
||||
}
|
||||
return fmt.Errorf("unexpected query type %T. expect *ent.PromoCodeQuery", q)
|
||||
}
|
||||
|
||||
// The PromoCodeUsageFunc type is an adapter to allow the use of ordinary function as a Querier.
|
||||
type PromoCodeUsageFunc func(context.Context, *ent.PromoCodeUsageQuery) (ent.Value, error)
|
||||
|
||||
// Query calls f(ctx, q).
|
||||
func (f PromoCodeUsageFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
|
||||
if q, ok := q.(*ent.PromoCodeUsageQuery); ok {
|
||||
return f(ctx, q)
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected query type %T. expect *ent.PromoCodeUsageQuery", q)
|
||||
}
|
||||
|
||||
// The TraversePromoCodeUsage type is an adapter to allow the use of ordinary function as Traverser.
|
||||
type TraversePromoCodeUsage func(context.Context, *ent.PromoCodeUsageQuery) error
|
||||
|
||||
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
|
||||
func (f TraversePromoCodeUsage) Intercept(next ent.Querier) ent.Querier {
|
||||
return next
|
||||
}
|
||||
|
||||
// Traverse calls f(ctx, q).
|
||||
func (f TraversePromoCodeUsage) Traverse(ctx context.Context, q ent.Query) error {
|
||||
if q, ok := q.(*ent.PromoCodeUsageQuery); ok {
|
||||
return f(ctx, q)
|
||||
}
|
||||
return fmt.Errorf("unexpected query type %T. expect *ent.PromoCodeUsageQuery", q)
|
||||
}
|
||||
|
||||
// The ProxyFunc type is an adapter to allow the use of ordinary function as a Querier.
|
||||
type ProxyFunc func(context.Context, *ent.ProxyQuery) (ent.Value, error)
|
||||
|
||||
@@ -442,6 +498,10 @@ func NewQuery(q ent.Query) (Query, error) {
|
||||
return &query[*ent.AccountGroupQuery, predicate.AccountGroup, accountgroup.OrderOption]{typ: ent.TypeAccountGroup, tq: q}, nil
|
||||
case *ent.GroupQuery:
|
||||
return &query[*ent.GroupQuery, predicate.Group, group.OrderOption]{typ: ent.TypeGroup, tq: q}, nil
|
||||
case *ent.PromoCodeQuery:
|
||||
return &query[*ent.PromoCodeQuery, predicate.PromoCode, promocode.OrderOption]{typ: ent.TypePromoCode, tq: q}, nil
|
||||
case *ent.PromoCodeUsageQuery:
|
||||
return &query[*ent.PromoCodeUsageQuery, predicate.PromoCodeUsage, promocodeusage.OrderOption]{typ: ent.TypePromoCodeUsage, tq: q}, nil
|
||||
case *ent.ProxyQuery:
|
||||
return &query[*ent.ProxyQuery, predicate.Proxy, proxy.OrderOption]{typ: ent.TypeProxy, tq: q}, nil
|
||||
case *ent.RedeemCodeQuery:
|
||||
|
||||
@@ -18,6 +18,8 @@ var (
|
||||
{Name: "key", Type: field.TypeString, Unique: true, Size: 128},
|
||||
{Name: "name", Type: field.TypeString, Size: 100},
|
||||
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
|
||||
{Name: "ip_whitelist", Type: field.TypeJSON, Nullable: true},
|
||||
{Name: "ip_blacklist", Type: field.TypeJSON, Nullable: true},
|
||||
{Name: "group_id", Type: field.TypeInt64, Nullable: true},
|
||||
{Name: "user_id", Type: field.TypeInt64},
|
||||
}
|
||||
@@ -29,13 +31,13 @@ var (
|
||||
ForeignKeys: []*schema.ForeignKey{
|
||||
{
|
||||
Symbol: "api_keys_groups_api_keys",
|
||||
Columns: []*schema.Column{APIKeysColumns[7]},
|
||||
Columns: []*schema.Column{APIKeysColumns[9]},
|
||||
RefColumns: []*schema.Column{GroupsColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
{
|
||||
Symbol: "api_keys_users_api_keys",
|
||||
Columns: []*schema.Column{APIKeysColumns[8]},
|
||||
Columns: []*schema.Column{APIKeysColumns[10]},
|
||||
RefColumns: []*schema.Column{UsersColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
@@ -44,12 +46,12 @@ var (
|
||||
{
|
||||
Name: "apikey_user_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{APIKeysColumns[8]},
|
||||
Columns: []*schema.Column{APIKeysColumns[10]},
|
||||
},
|
||||
{
|
||||
Name: "apikey_group_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{APIKeysColumns[7]},
|
||||
Columns: []*schema.Column{APIKeysColumns[9]},
|
||||
},
|
||||
{
|
||||
Name: "apikey_status",
|
||||
@@ -257,6 +259,82 @@ var (
|
||||
},
|
||||
},
|
||||
}
|
||||
// PromoCodesColumns holds the columns for the "promo_codes" table.
|
||||
PromoCodesColumns = []*schema.Column{
|
||||
{Name: "id", Type: field.TypeInt64, Increment: true},
|
||||
{Name: "code", Type: field.TypeString, Unique: true, Size: 32},
|
||||
{Name: "bonus_amount", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "max_uses", Type: field.TypeInt, Default: 0},
|
||||
{Name: "used_count", Type: field.TypeInt, Default: 0},
|
||||
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
|
||||
{Name: "expires_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "notes", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
|
||||
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
}
|
||||
// PromoCodesTable holds the schema information for the "promo_codes" table.
|
||||
PromoCodesTable = &schema.Table{
|
||||
Name: "promo_codes",
|
||||
Columns: PromoCodesColumns,
|
||||
PrimaryKey: []*schema.Column{PromoCodesColumns[0]},
|
||||
Indexes: []*schema.Index{
|
||||
{
|
||||
Name: "promocode_status",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{PromoCodesColumns[5]},
|
||||
},
|
||||
{
|
||||
Name: "promocode_expires_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{PromoCodesColumns[6]},
|
||||
},
|
||||
},
|
||||
}
|
||||
// PromoCodeUsagesColumns holds the columns for the "promo_code_usages" table.
|
||||
PromoCodeUsagesColumns = []*schema.Column{
|
||||
{Name: "id", Type: field.TypeInt64, Increment: true},
|
||||
{Name: "bonus_amount", Type: field.TypeFloat64, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "used_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "promo_code_id", Type: field.TypeInt64},
|
||||
{Name: "user_id", Type: field.TypeInt64},
|
||||
}
|
||||
// PromoCodeUsagesTable holds the schema information for the "promo_code_usages" table.
|
||||
PromoCodeUsagesTable = &schema.Table{
|
||||
Name: "promo_code_usages",
|
||||
Columns: PromoCodeUsagesColumns,
|
||||
PrimaryKey: []*schema.Column{PromoCodeUsagesColumns[0]},
|
||||
ForeignKeys: []*schema.ForeignKey{
|
||||
{
|
||||
Symbol: "promo_code_usages_promo_codes_usage_records",
|
||||
Columns: []*schema.Column{PromoCodeUsagesColumns[3]},
|
||||
RefColumns: []*schema.Column{PromoCodesColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
{
|
||||
Symbol: "promo_code_usages_users_promo_code_usages",
|
||||
Columns: []*schema.Column{PromoCodeUsagesColumns[4]},
|
||||
RefColumns: []*schema.Column{UsersColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
},
|
||||
Indexes: []*schema.Index{
|
||||
{
|
||||
Name: "promocodeusage_promo_code_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{PromoCodeUsagesColumns[3]},
|
||||
},
|
||||
{
|
||||
Name: "promocodeusage_user_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{PromoCodeUsagesColumns[4]},
|
||||
},
|
||||
{
|
||||
Name: "promocodeusage_promo_code_id_user_id",
|
||||
Unique: true,
|
||||
Columns: []*schema.Column{PromoCodeUsagesColumns[3], PromoCodeUsagesColumns[4]},
|
||||
},
|
||||
},
|
||||
}
|
||||
// ProxiesColumns holds the columns for the "proxies" table.
|
||||
ProxiesColumns = []*schema.Column{
|
||||
{Name: "id", Type: field.TypeInt64, Increment: true},
|
||||
@@ -376,6 +454,7 @@ var (
|
||||
{Name: "duration_ms", Type: field.TypeInt, Nullable: true},
|
||||
{Name: "first_token_ms", Type: field.TypeInt, Nullable: true},
|
||||
{Name: "user_agent", Type: field.TypeString, Nullable: true, Size: 512},
|
||||
{Name: "ip_address", Type: field.TypeString, Nullable: true, Size: 45},
|
||||
{Name: "image_count", Type: field.TypeInt, Default: 0},
|
||||
{Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10},
|
||||
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
@@ -393,31 +472,31 @@ var (
|
||||
ForeignKeys: []*schema.ForeignKey{
|
||||
{
|
||||
Symbol: "usage_logs_api_keys_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[24]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[25]},
|
||||
RefColumns: []*schema.Column{APIKeysColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_accounts_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[25]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[26]},
|
||||
RefColumns: []*schema.Column{AccountsColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_groups_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[26]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[27]},
|
||||
RefColumns: []*schema.Column{GroupsColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_users_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[27]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
||||
RefColumns: []*schema.Column{UsersColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_user_subscriptions_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
@@ -426,32 +505,32 @@ var (
|
||||
{
|
||||
Name: "usagelog_user_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[27]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_api_key_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[24]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[25]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_account_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[25]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[26]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_group_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[26]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[27]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_subscription_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[23]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[24]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_model",
|
||||
@@ -466,12 +545,12 @@ var (
|
||||
{
|
||||
Name: "usagelog_user_id_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[27], UsageLogsColumns[23]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[24]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_api_key_id_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[24], UsageLogsColumns[23]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[25], UsageLogsColumns[24]},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -717,6 +796,8 @@ var (
|
||||
AccountsTable,
|
||||
AccountGroupsTable,
|
||||
GroupsTable,
|
||||
PromoCodesTable,
|
||||
PromoCodeUsagesTable,
|
||||
ProxiesTable,
|
||||
RedeemCodesTable,
|
||||
SettingsTable,
|
||||
@@ -747,6 +828,14 @@ func init() {
|
||||
GroupsTable.Annotation = &entsql.Annotation{
|
||||
Table: "groups",
|
||||
}
|
||||
PromoCodesTable.Annotation = &entsql.Annotation{
|
||||
Table: "promo_codes",
|
||||
}
|
||||
PromoCodeUsagesTable.ForeignKeys[0].RefTable = PromoCodesTable
|
||||
PromoCodeUsagesTable.ForeignKeys[1].RefTable = UsersTable
|
||||
PromoCodeUsagesTable.Annotation = &entsql.Annotation{
|
||||
Table: "promo_code_usages",
|
||||
}
|
||||
ProxiesTable.Annotation = &entsql.Annotation{
|
||||
Table: "proxies",
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -18,6 +18,12 @@ type AccountGroup func(*sql.Selector)
|
||||
// Group is the predicate function for group builders.
|
||||
type Group func(*sql.Selector)
|
||||
|
||||
// PromoCode is the predicate function for promocode builders.
|
||||
type PromoCode func(*sql.Selector)
|
||||
|
||||
// PromoCodeUsage is the predicate function for promocodeusage builders.
|
||||
type PromoCodeUsage func(*sql.Selector)
|
||||
|
||||
// Proxy is the predicate function for proxy builders.
|
||||
type Proxy func(*sql.Selector)
|
||||
|
||||
|
||||
228
backend/ent/promocode.go
Normal file
228
backend/ent/promocode.go
Normal file
@@ -0,0 +1,228 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package ent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocode"
|
||||
)
|
||||
|
||||
// PromoCode is the model entity for the PromoCode schema.
|
||||
type PromoCode struct {
|
||||
config `json:"-"`
|
||||
// ID of the ent.
|
||||
ID int64 `json:"id,omitempty"`
|
||||
// 优惠码
|
||||
Code string `json:"code,omitempty"`
|
||||
// 赠送余额金额
|
||||
BonusAmount float64 `json:"bonus_amount,omitempty"`
|
||||
// 最大使用次数,0表示无限制
|
||||
MaxUses int `json:"max_uses,omitempty"`
|
||||
// 已使用次数
|
||||
UsedCount int `json:"used_count,omitempty"`
|
||||
// 状态: active, disabled
|
||||
Status string `json:"status,omitempty"`
|
||||
// 过期时间,null表示永不过期
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
// 备注
|
||||
Notes *string `json:"notes,omitempty"`
|
||||
// CreatedAt holds the value of the "created_at" field.
|
||||
CreatedAt time.Time `json:"created_at,omitempty"`
|
||||
// UpdatedAt holds the value of the "updated_at" field.
|
||||
UpdatedAt time.Time `json:"updated_at,omitempty"`
|
||||
// Edges holds the relations/edges for other nodes in the graph.
|
||||
// The values are being populated by the PromoCodeQuery when eager-loading is set.
|
||||
Edges PromoCodeEdges `json:"edges"`
|
||||
selectValues sql.SelectValues
|
||||
}
|
||||
|
||||
// PromoCodeEdges holds the relations/edges for other nodes in the graph.
|
||||
type PromoCodeEdges struct {
|
||||
// UsageRecords holds the value of the usage_records edge.
|
||||
UsageRecords []*PromoCodeUsage `json:"usage_records,omitempty"`
|
||||
// loadedTypes holds the information for reporting if a
|
||||
// type was loaded (or requested) in eager-loading or not.
|
||||
loadedTypes [1]bool
|
||||
}
|
||||
|
||||
// UsageRecordsOrErr returns the UsageRecords value or an error if the edge
|
||||
// was not loaded in eager-loading.
|
||||
func (e PromoCodeEdges) UsageRecordsOrErr() ([]*PromoCodeUsage, error) {
|
||||
if e.loadedTypes[0] {
|
||||
return e.UsageRecords, nil
|
||||
}
|
||||
return nil, &NotLoadedError{edge: "usage_records"}
|
||||
}
|
||||
|
||||
// scanValues returns the types for scanning values from sql.Rows.
|
||||
func (*PromoCode) scanValues(columns []string) ([]any, error) {
|
||||
values := make([]any, len(columns))
|
||||
for i := range columns {
|
||||
switch columns[i] {
|
||||
case promocode.FieldBonusAmount:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case promocode.FieldID, promocode.FieldMaxUses, promocode.FieldUsedCount:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case promocode.FieldCode, promocode.FieldStatus, promocode.FieldNotes:
|
||||
values[i] = new(sql.NullString)
|
||||
case promocode.FieldExpiresAt, promocode.FieldCreatedAt, promocode.FieldUpdatedAt:
|
||||
values[i] = new(sql.NullTime)
|
||||
default:
|
||||
values[i] = new(sql.UnknownType)
|
||||
}
|
||||
}
|
||||
return values, nil
|
||||
}
|
||||
|
||||
// assignValues assigns the values that were returned from sql.Rows (after scanning)
|
||||
// to the PromoCode fields.
|
||||
func (_m *PromoCode) assignValues(columns []string, values []any) error {
|
||||
if m, n := len(values), len(columns); m < n {
|
||||
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
|
||||
}
|
||||
for i := range columns {
|
||||
switch columns[i] {
|
||||
case promocode.FieldID:
|
||||
value, ok := values[i].(*sql.NullInt64)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field id", value)
|
||||
}
|
||||
_m.ID = int64(value.Int64)
|
||||
case promocode.FieldCode:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field code", values[i])
|
||||
} else if value.Valid {
|
||||
_m.Code = value.String
|
||||
}
|
||||
case promocode.FieldBonusAmount:
|
||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field bonus_amount", values[i])
|
||||
} else if value.Valid {
|
||||
_m.BonusAmount = value.Float64
|
||||
}
|
||||
case promocode.FieldMaxUses:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field max_uses", values[i])
|
||||
} else if value.Valid {
|
||||
_m.MaxUses = int(value.Int64)
|
||||
}
|
||||
case promocode.FieldUsedCount:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field used_count", values[i])
|
||||
} else if value.Valid {
|
||||
_m.UsedCount = int(value.Int64)
|
||||
}
|
||||
case promocode.FieldStatus:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field status", values[i])
|
||||
} else if value.Valid {
|
||||
_m.Status = value.String
|
||||
}
|
||||
case promocode.FieldExpiresAt:
|
||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field expires_at", values[i])
|
||||
} else if value.Valid {
|
||||
_m.ExpiresAt = new(time.Time)
|
||||
*_m.ExpiresAt = value.Time
|
||||
}
|
||||
case promocode.FieldNotes:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field notes", values[i])
|
||||
} else if value.Valid {
|
||||
_m.Notes = new(string)
|
||||
*_m.Notes = value.String
|
||||
}
|
||||
case promocode.FieldCreatedAt:
|
||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field created_at", values[i])
|
||||
} else if value.Valid {
|
||||
_m.CreatedAt = value.Time
|
||||
}
|
||||
case promocode.FieldUpdatedAt:
|
||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field updated_at", values[i])
|
||||
} else if value.Valid {
|
||||
_m.UpdatedAt = value.Time
|
||||
}
|
||||
default:
|
||||
_m.selectValues.Set(columns[i], values[i])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value returns the ent.Value that was dynamically selected and assigned to the PromoCode.
|
||||
// This includes values selected through modifiers, order, etc.
|
||||
func (_m *PromoCode) Value(name string) (ent.Value, error) {
|
||||
return _m.selectValues.Get(name)
|
||||
}
|
||||
|
||||
// QueryUsageRecords queries the "usage_records" edge of the PromoCode entity.
|
||||
func (_m *PromoCode) QueryUsageRecords() *PromoCodeUsageQuery {
|
||||
return NewPromoCodeClient(_m.config).QueryUsageRecords(_m)
|
||||
}
|
||||
|
||||
// Update returns a builder for updating this PromoCode.
|
||||
// Note that you need to call PromoCode.Unwrap() before calling this method if this PromoCode
|
||||
// was returned from a transaction, and the transaction was committed or rolled back.
|
||||
func (_m *PromoCode) Update() *PromoCodeUpdateOne {
|
||||
return NewPromoCodeClient(_m.config).UpdateOne(_m)
|
||||
}
|
||||
|
||||
// Unwrap unwraps the PromoCode entity that was returned from a transaction after it was closed,
|
||||
// so that all future queries will be executed through the driver which created the transaction.
|
||||
func (_m *PromoCode) Unwrap() *PromoCode {
|
||||
_tx, ok := _m.config.driver.(*txDriver)
|
||||
if !ok {
|
||||
panic("ent: PromoCode is not a transactional entity")
|
||||
}
|
||||
_m.config.driver = _tx.drv
|
||||
return _m
|
||||
}
|
||||
|
||||
// String implements the fmt.Stringer.
|
||||
func (_m *PromoCode) String() string {
|
||||
var builder strings.Builder
|
||||
builder.WriteString("PromoCode(")
|
||||
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
|
||||
builder.WriteString("code=")
|
||||
builder.WriteString(_m.Code)
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("bonus_amount=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.BonusAmount))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("max_uses=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.MaxUses))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("used_count=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.UsedCount))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("status=")
|
||||
builder.WriteString(_m.Status)
|
||||
builder.WriteString(", ")
|
||||
if v := _m.ExpiresAt; v != nil {
|
||||
builder.WriteString("expires_at=")
|
||||
builder.WriteString(v.Format(time.ANSIC))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.Notes; v != nil {
|
||||
builder.WriteString("notes=")
|
||||
builder.WriteString(*v)
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("created_at=")
|
||||
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("updated_at=")
|
||||
builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
|
||||
builder.WriteByte(')')
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// PromoCodes is a parsable slice of PromoCode.
|
||||
type PromoCodes []*PromoCode
|
||||
165
backend/ent/promocode/promocode.go
Normal file
165
backend/ent/promocode/promocode.go
Normal file
@@ -0,0 +1,165 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package promocode
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
)
|
||||
|
||||
const (
|
||||
// Label holds the string label denoting the promocode type in the database.
|
||||
Label = "promo_code"
|
||||
// FieldID holds the string denoting the id field in the database.
|
||||
FieldID = "id"
|
||||
// FieldCode holds the string denoting the code field in the database.
|
||||
FieldCode = "code"
|
||||
// FieldBonusAmount holds the string denoting the bonus_amount field in the database.
|
||||
FieldBonusAmount = "bonus_amount"
|
||||
// FieldMaxUses holds the string denoting the max_uses field in the database.
|
||||
FieldMaxUses = "max_uses"
|
||||
// FieldUsedCount holds the string denoting the used_count field in the database.
|
||||
FieldUsedCount = "used_count"
|
||||
// FieldStatus holds the string denoting the status field in the database.
|
||||
FieldStatus = "status"
|
||||
// FieldExpiresAt holds the string denoting the expires_at field in the database.
|
||||
FieldExpiresAt = "expires_at"
|
||||
// FieldNotes holds the string denoting the notes field in the database.
|
||||
FieldNotes = "notes"
|
||||
// FieldCreatedAt holds the string denoting the created_at field in the database.
|
||||
FieldCreatedAt = "created_at"
|
||||
// FieldUpdatedAt holds the string denoting the updated_at field in the database.
|
||||
FieldUpdatedAt = "updated_at"
|
||||
// EdgeUsageRecords holds the string denoting the usage_records edge name in mutations.
|
||||
EdgeUsageRecords = "usage_records"
|
||||
// Table holds the table name of the promocode in the database.
|
||||
Table = "promo_codes"
|
||||
// UsageRecordsTable is the table that holds the usage_records relation/edge.
|
||||
UsageRecordsTable = "promo_code_usages"
|
||||
// UsageRecordsInverseTable is the table name for the PromoCodeUsage entity.
|
||||
// It exists in this package in order to avoid circular dependency with the "promocodeusage" package.
|
||||
UsageRecordsInverseTable = "promo_code_usages"
|
||||
// UsageRecordsColumn is the table column denoting the usage_records relation/edge.
|
||||
UsageRecordsColumn = "promo_code_id"
|
||||
)
|
||||
|
||||
// Columns holds all SQL columns for promocode fields.
|
||||
var Columns = []string{
|
||||
FieldID,
|
||||
FieldCode,
|
||||
FieldBonusAmount,
|
||||
FieldMaxUses,
|
||||
FieldUsedCount,
|
||||
FieldStatus,
|
||||
FieldExpiresAt,
|
||||
FieldNotes,
|
||||
FieldCreatedAt,
|
||||
FieldUpdatedAt,
|
||||
}
|
||||
|
||||
// ValidColumn reports if the column name is valid (part of the table columns).
|
||||
func ValidColumn(column string) bool {
|
||||
for i := range Columns {
|
||||
if column == Columns[i] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var (
|
||||
// CodeValidator is a validator for the "code" field. It is called by the builders before save.
|
||||
CodeValidator func(string) error
|
||||
// DefaultBonusAmount holds the default value on creation for the "bonus_amount" field.
|
||||
DefaultBonusAmount float64
|
||||
// DefaultMaxUses holds the default value on creation for the "max_uses" field.
|
||||
DefaultMaxUses int
|
||||
// DefaultUsedCount holds the default value on creation for the "used_count" field.
|
||||
DefaultUsedCount int
|
||||
// DefaultStatus holds the default value on creation for the "status" field.
|
||||
DefaultStatus string
|
||||
// StatusValidator is a validator for the "status" field. It is called by the builders before save.
|
||||
StatusValidator func(string) error
|
||||
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
|
||||
DefaultCreatedAt func() time.Time
|
||||
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
|
||||
DefaultUpdatedAt func() time.Time
|
||||
// UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
|
||||
UpdateDefaultUpdatedAt func() time.Time
|
||||
)
|
||||
|
||||
// OrderOption defines the ordering options for the PromoCode queries.
|
||||
type OrderOption func(*sql.Selector)
|
||||
|
||||
// ByID orders the results by the id field.
|
||||
func ByID(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldID, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByCode orders the results by the code field.
|
||||
func ByCode(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldCode, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByBonusAmount orders the results by the bonus_amount field.
|
||||
func ByBonusAmount(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldBonusAmount, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByMaxUses orders the results by the max_uses field.
|
||||
func ByMaxUses(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldMaxUses, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByUsedCount orders the results by the used_count field.
|
||||
func ByUsedCount(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldUsedCount, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByStatus orders the results by the status field.
|
||||
func ByStatus(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldStatus, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByExpiresAt orders the results by the expires_at field.
|
||||
func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldExpiresAt, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByNotes orders the results by the notes field.
|
||||
func ByNotes(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldNotes, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByCreatedAt orders the results by the created_at field.
|
||||
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByUpdatedAt orders the results by the updated_at field.
|
||||
func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByUsageRecordsCount orders the results by usage_records count.
|
||||
func ByUsageRecordsCount(opts ...sql.OrderTermOption) OrderOption {
|
||||
return func(s *sql.Selector) {
|
||||
sqlgraph.OrderByNeighborsCount(s, newUsageRecordsStep(), opts...)
|
||||
}
|
||||
}
|
||||
|
||||
// ByUsageRecords orders the results by usage_records terms.
|
||||
func ByUsageRecords(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
|
||||
return func(s *sql.Selector) {
|
||||
sqlgraph.OrderByNeighborTerms(s, newUsageRecordsStep(), append([]sql.OrderTerm{term}, terms...)...)
|
||||
}
|
||||
}
|
||||
func newUsageRecordsStep() *sqlgraph.Step {
|
||||
return sqlgraph.NewStep(
|
||||
sqlgraph.From(Table, FieldID),
|
||||
sqlgraph.To(UsageRecordsInverseTable, FieldID),
|
||||
sqlgraph.Edge(sqlgraph.O2M, false, UsageRecordsTable, UsageRecordsColumn),
|
||||
)
|
||||
}
|
||||
594
backend/ent/promocode/where.go
Normal file
594
backend/ent/promocode/where.go
Normal file
@@ -0,0 +1,594 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package promocode
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
)
|
||||
|
||||
// ID filters vertices based on their ID field.
|
||||
func ID(id int64) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldEQ(FieldID, id))
|
||||
}
|
||||
|
||||
// IDEQ applies the EQ predicate on the ID field.
|
||||
func IDEQ(id int64) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldEQ(FieldID, id))
|
||||
}
|
||||
|
||||
// IDNEQ applies the NEQ predicate on the ID field.
|
||||
func IDNEQ(id int64) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldNEQ(FieldID, id))
|
||||
}
|
||||
|
||||
// IDIn applies the In predicate on the ID field.
|
||||
func IDIn(ids ...int64) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldIn(FieldID, ids...))
|
||||
}
|
||||
|
||||
// IDNotIn applies the NotIn predicate on the ID field.
|
||||
func IDNotIn(ids ...int64) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldNotIn(FieldID, ids...))
|
||||
}
|
||||
|
||||
// IDGT applies the GT predicate on the ID field.
|
||||
func IDGT(id int64) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldGT(FieldID, id))
|
||||
}
|
||||
|
||||
// IDGTE applies the GTE predicate on the ID field.
|
||||
func IDGTE(id int64) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldGTE(FieldID, id))
|
||||
}
|
||||
|
||||
// IDLT applies the LT predicate on the ID field.
|
||||
func IDLT(id int64) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldLT(FieldID, id))
|
||||
}
|
||||
|
||||
// IDLTE applies the LTE predicate on the ID field.
|
||||
func IDLTE(id int64) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldLTE(FieldID, id))
|
||||
}
|
||||
|
||||
// Code applies equality check predicate on the "code" field. It's identical to CodeEQ.
|
||||
func Code(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldEQ(FieldCode, v))
|
||||
}
|
||||
|
||||
// BonusAmount applies equality check predicate on the "bonus_amount" field. It's identical to BonusAmountEQ.
|
||||
func BonusAmount(v float64) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldEQ(FieldBonusAmount, v))
|
||||
}
|
||||
|
||||
// MaxUses applies equality check predicate on the "max_uses" field. It's identical to MaxUsesEQ.
|
||||
func MaxUses(v int) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldEQ(FieldMaxUses, v))
|
||||
}
|
||||
|
||||
// UsedCount applies equality check predicate on the "used_count" field. It's identical to UsedCountEQ.
|
||||
func UsedCount(v int) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldEQ(FieldUsedCount, v))
|
||||
}
|
||||
|
||||
// Status applies equality check predicate on the "status" field. It's identical to StatusEQ.
|
||||
func Status(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldEQ(FieldStatus, v))
|
||||
}
|
||||
|
||||
// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ.
|
||||
func ExpiresAt(v time.Time) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldEQ(FieldExpiresAt, v))
|
||||
}
|
||||
|
||||
// Notes applies equality check predicate on the "notes" field. It's identical to NotesEQ.
|
||||
func Notes(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldEQ(FieldNotes, v))
|
||||
}
|
||||
|
||||
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
|
||||
func CreatedAt(v time.Time) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldEQ(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
|
||||
func UpdatedAt(v time.Time) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldEQ(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// CodeEQ applies the EQ predicate on the "code" field.
|
||||
func CodeEQ(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldEQ(FieldCode, v))
|
||||
}
|
||||
|
||||
// CodeNEQ applies the NEQ predicate on the "code" field.
|
||||
func CodeNEQ(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldNEQ(FieldCode, v))
|
||||
}
|
||||
|
||||
// CodeIn applies the In predicate on the "code" field.
|
||||
func CodeIn(vs ...string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldIn(FieldCode, vs...))
|
||||
}
|
||||
|
||||
// CodeNotIn applies the NotIn predicate on the "code" field.
|
||||
func CodeNotIn(vs ...string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldNotIn(FieldCode, vs...))
|
||||
}
|
||||
|
||||
// CodeGT applies the GT predicate on the "code" field.
|
||||
func CodeGT(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldGT(FieldCode, v))
|
||||
}
|
||||
|
||||
// CodeGTE applies the GTE predicate on the "code" field.
|
||||
func CodeGTE(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldGTE(FieldCode, v))
|
||||
}
|
||||
|
||||
// CodeLT applies the LT predicate on the "code" field.
|
||||
func CodeLT(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldLT(FieldCode, v))
|
||||
}
|
||||
|
||||
// CodeLTE applies the LTE predicate on the "code" field.
|
||||
func CodeLTE(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldLTE(FieldCode, v))
|
||||
}
|
||||
|
||||
// CodeContains applies the Contains predicate on the "code" field.
|
||||
func CodeContains(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldContains(FieldCode, v))
|
||||
}
|
||||
|
||||
// CodeHasPrefix applies the HasPrefix predicate on the "code" field.
|
||||
func CodeHasPrefix(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldHasPrefix(FieldCode, v))
|
||||
}
|
||||
|
||||
// CodeHasSuffix applies the HasSuffix predicate on the "code" field.
|
||||
func CodeHasSuffix(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldHasSuffix(FieldCode, v))
|
||||
}
|
||||
|
||||
// CodeEqualFold applies the EqualFold predicate on the "code" field.
|
||||
func CodeEqualFold(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldEqualFold(FieldCode, v))
|
||||
}
|
||||
|
||||
// CodeContainsFold applies the ContainsFold predicate on the "code" field.
|
||||
func CodeContainsFold(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldContainsFold(FieldCode, v))
|
||||
}
|
||||
|
||||
// BonusAmountEQ applies the EQ predicate on the "bonus_amount" field.
|
||||
func BonusAmountEQ(v float64) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldEQ(FieldBonusAmount, v))
|
||||
}
|
||||
|
||||
// BonusAmountNEQ applies the NEQ predicate on the "bonus_amount" field.
|
||||
func BonusAmountNEQ(v float64) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldNEQ(FieldBonusAmount, v))
|
||||
}
|
||||
|
||||
// BonusAmountIn applies the In predicate on the "bonus_amount" field.
|
||||
func BonusAmountIn(vs ...float64) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldIn(FieldBonusAmount, vs...))
|
||||
}
|
||||
|
||||
// BonusAmountNotIn applies the NotIn predicate on the "bonus_amount" field.
|
||||
func BonusAmountNotIn(vs ...float64) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldNotIn(FieldBonusAmount, vs...))
|
||||
}
|
||||
|
||||
// BonusAmountGT applies the GT predicate on the "bonus_amount" field.
|
||||
func BonusAmountGT(v float64) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldGT(FieldBonusAmount, v))
|
||||
}
|
||||
|
||||
// BonusAmountGTE applies the GTE predicate on the "bonus_amount" field.
|
||||
func BonusAmountGTE(v float64) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldGTE(FieldBonusAmount, v))
|
||||
}
|
||||
|
||||
// BonusAmountLT applies the LT predicate on the "bonus_amount" field.
|
||||
func BonusAmountLT(v float64) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldLT(FieldBonusAmount, v))
|
||||
}
|
||||
|
||||
// BonusAmountLTE applies the LTE predicate on the "bonus_amount" field.
|
||||
func BonusAmountLTE(v float64) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldLTE(FieldBonusAmount, v))
|
||||
}
|
||||
|
||||
// MaxUsesEQ applies the EQ predicate on the "max_uses" field.
|
||||
func MaxUsesEQ(v int) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldEQ(FieldMaxUses, v))
|
||||
}
|
||||
|
||||
// MaxUsesNEQ applies the NEQ predicate on the "max_uses" field.
|
||||
func MaxUsesNEQ(v int) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldNEQ(FieldMaxUses, v))
|
||||
}
|
||||
|
||||
// MaxUsesIn applies the In predicate on the "max_uses" field.
|
||||
func MaxUsesIn(vs ...int) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldIn(FieldMaxUses, vs...))
|
||||
}
|
||||
|
||||
// MaxUsesNotIn applies the NotIn predicate on the "max_uses" field.
|
||||
func MaxUsesNotIn(vs ...int) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldNotIn(FieldMaxUses, vs...))
|
||||
}
|
||||
|
||||
// MaxUsesGT applies the GT predicate on the "max_uses" field.
|
||||
func MaxUsesGT(v int) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldGT(FieldMaxUses, v))
|
||||
}
|
||||
|
||||
// MaxUsesGTE applies the GTE predicate on the "max_uses" field.
|
||||
func MaxUsesGTE(v int) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldGTE(FieldMaxUses, v))
|
||||
}
|
||||
|
||||
// MaxUsesLT applies the LT predicate on the "max_uses" field.
|
||||
func MaxUsesLT(v int) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldLT(FieldMaxUses, v))
|
||||
}
|
||||
|
||||
// MaxUsesLTE applies the LTE predicate on the "max_uses" field.
|
||||
func MaxUsesLTE(v int) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldLTE(FieldMaxUses, v))
|
||||
}
|
||||
|
||||
// UsedCountEQ applies the EQ predicate on the "used_count" field.
|
||||
func UsedCountEQ(v int) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldEQ(FieldUsedCount, v))
|
||||
}
|
||||
|
||||
// UsedCountNEQ applies the NEQ predicate on the "used_count" field.
|
||||
func UsedCountNEQ(v int) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldNEQ(FieldUsedCount, v))
|
||||
}
|
||||
|
||||
// UsedCountIn applies the In predicate on the "used_count" field.
|
||||
func UsedCountIn(vs ...int) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldIn(FieldUsedCount, vs...))
|
||||
}
|
||||
|
||||
// UsedCountNotIn applies the NotIn predicate on the "used_count" field.
|
||||
func UsedCountNotIn(vs ...int) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldNotIn(FieldUsedCount, vs...))
|
||||
}
|
||||
|
||||
// UsedCountGT applies the GT predicate on the "used_count" field.
|
||||
func UsedCountGT(v int) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldGT(FieldUsedCount, v))
|
||||
}
|
||||
|
||||
// UsedCountGTE applies the GTE predicate on the "used_count" field.
|
||||
func UsedCountGTE(v int) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldGTE(FieldUsedCount, v))
|
||||
}
|
||||
|
||||
// UsedCountLT applies the LT predicate on the "used_count" field.
|
||||
func UsedCountLT(v int) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldLT(FieldUsedCount, v))
|
||||
}
|
||||
|
||||
// UsedCountLTE applies the LTE predicate on the "used_count" field.
|
||||
func UsedCountLTE(v int) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldLTE(FieldUsedCount, v))
|
||||
}
|
||||
|
||||
// StatusEQ applies the EQ predicate on the "status" field.
|
||||
func StatusEQ(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldEQ(FieldStatus, v))
|
||||
}
|
||||
|
||||
// StatusNEQ applies the NEQ predicate on the "status" field.
|
||||
func StatusNEQ(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldNEQ(FieldStatus, v))
|
||||
}
|
||||
|
||||
// StatusIn applies the In predicate on the "status" field.
|
||||
func StatusIn(vs ...string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldIn(FieldStatus, vs...))
|
||||
}
|
||||
|
||||
// StatusNotIn applies the NotIn predicate on the "status" field.
|
||||
func StatusNotIn(vs ...string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldNotIn(FieldStatus, vs...))
|
||||
}
|
||||
|
||||
// StatusGT applies the GT predicate on the "status" field.
|
||||
func StatusGT(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldGT(FieldStatus, v))
|
||||
}
|
||||
|
||||
// StatusGTE applies the GTE predicate on the "status" field.
|
||||
func StatusGTE(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldGTE(FieldStatus, v))
|
||||
}
|
||||
|
||||
// StatusLT applies the LT predicate on the "status" field.
|
||||
func StatusLT(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldLT(FieldStatus, v))
|
||||
}
|
||||
|
||||
// StatusLTE applies the LTE predicate on the "status" field.
|
||||
func StatusLTE(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldLTE(FieldStatus, v))
|
||||
}
|
||||
|
||||
// StatusContains applies the Contains predicate on the "status" field.
|
||||
func StatusContains(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldContains(FieldStatus, v))
|
||||
}
|
||||
|
||||
// StatusHasPrefix applies the HasPrefix predicate on the "status" field.
|
||||
func StatusHasPrefix(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldHasPrefix(FieldStatus, v))
|
||||
}
|
||||
|
||||
// StatusHasSuffix applies the HasSuffix predicate on the "status" field.
|
||||
func StatusHasSuffix(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldHasSuffix(FieldStatus, v))
|
||||
}
|
||||
|
||||
// StatusEqualFold applies the EqualFold predicate on the "status" field.
|
||||
func StatusEqualFold(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldEqualFold(FieldStatus, v))
|
||||
}
|
||||
|
||||
// StatusContainsFold applies the ContainsFold predicate on the "status" field.
|
||||
func StatusContainsFold(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldContainsFold(FieldStatus, v))
|
||||
}
|
||||
|
||||
// ExpiresAtEQ applies the EQ predicate on the "expires_at" field.
|
||||
func ExpiresAtEQ(v time.Time) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldEQ(FieldExpiresAt, v))
|
||||
}
|
||||
|
||||
// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field.
|
||||
func ExpiresAtNEQ(v time.Time) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldNEQ(FieldExpiresAt, v))
|
||||
}
|
||||
|
||||
// ExpiresAtIn applies the In predicate on the "expires_at" field.
|
||||
func ExpiresAtIn(vs ...time.Time) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldIn(FieldExpiresAt, vs...))
|
||||
}
|
||||
|
||||
// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field.
|
||||
func ExpiresAtNotIn(vs ...time.Time) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldNotIn(FieldExpiresAt, vs...))
|
||||
}
|
||||
|
||||
// ExpiresAtGT applies the GT predicate on the "expires_at" field.
|
||||
func ExpiresAtGT(v time.Time) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldGT(FieldExpiresAt, v))
|
||||
}
|
||||
|
||||
// ExpiresAtGTE applies the GTE predicate on the "expires_at" field.
|
||||
func ExpiresAtGTE(v time.Time) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldGTE(FieldExpiresAt, v))
|
||||
}
|
||||
|
||||
// ExpiresAtLT applies the LT predicate on the "expires_at" field.
|
||||
func ExpiresAtLT(v time.Time) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldLT(FieldExpiresAt, v))
|
||||
}
|
||||
|
||||
// ExpiresAtLTE applies the LTE predicate on the "expires_at" field.
|
||||
func ExpiresAtLTE(v time.Time) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldLTE(FieldExpiresAt, v))
|
||||
}
|
||||
|
||||
// ExpiresAtIsNil applies the IsNil predicate on the "expires_at" field.
|
||||
func ExpiresAtIsNil() predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldIsNull(FieldExpiresAt))
|
||||
}
|
||||
|
||||
// ExpiresAtNotNil applies the NotNil predicate on the "expires_at" field.
|
||||
func ExpiresAtNotNil() predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldNotNull(FieldExpiresAt))
|
||||
}
|
||||
|
||||
// NotesEQ applies the EQ predicate on the "notes" field.
|
||||
func NotesEQ(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldEQ(FieldNotes, v))
|
||||
}
|
||||
|
||||
// NotesNEQ applies the NEQ predicate on the "notes" field.
|
||||
func NotesNEQ(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldNEQ(FieldNotes, v))
|
||||
}
|
||||
|
||||
// NotesIn applies the In predicate on the "notes" field.
|
||||
func NotesIn(vs ...string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldIn(FieldNotes, vs...))
|
||||
}
|
||||
|
||||
// NotesNotIn applies the NotIn predicate on the "notes" field.
|
||||
func NotesNotIn(vs ...string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldNotIn(FieldNotes, vs...))
|
||||
}
|
||||
|
||||
// NotesGT applies the GT predicate on the "notes" field.
|
||||
func NotesGT(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldGT(FieldNotes, v))
|
||||
}
|
||||
|
||||
// NotesGTE applies the GTE predicate on the "notes" field.
|
||||
func NotesGTE(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldGTE(FieldNotes, v))
|
||||
}
|
||||
|
||||
// NotesLT applies the LT predicate on the "notes" field.
|
||||
func NotesLT(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldLT(FieldNotes, v))
|
||||
}
|
||||
|
||||
// NotesLTE applies the LTE predicate on the "notes" field.
|
||||
func NotesLTE(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldLTE(FieldNotes, v))
|
||||
}
|
||||
|
||||
// NotesContains applies the Contains predicate on the "notes" field.
|
||||
func NotesContains(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldContains(FieldNotes, v))
|
||||
}
|
||||
|
||||
// NotesHasPrefix applies the HasPrefix predicate on the "notes" field.
|
||||
func NotesHasPrefix(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldHasPrefix(FieldNotes, v))
|
||||
}
|
||||
|
||||
// NotesHasSuffix applies the HasSuffix predicate on the "notes" field.
|
||||
func NotesHasSuffix(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldHasSuffix(FieldNotes, v))
|
||||
}
|
||||
|
||||
// NotesIsNil applies the IsNil predicate on the "notes" field.
|
||||
func NotesIsNil() predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldIsNull(FieldNotes))
|
||||
}
|
||||
|
||||
// NotesNotNil applies the NotNil predicate on the "notes" field.
|
||||
func NotesNotNil() predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldNotNull(FieldNotes))
|
||||
}
|
||||
|
||||
// NotesEqualFold applies the EqualFold predicate on the "notes" field.
|
||||
func NotesEqualFold(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldEqualFold(FieldNotes, v))
|
||||
}
|
||||
|
||||
// NotesContainsFold applies the ContainsFold predicate on the "notes" field.
|
||||
func NotesContainsFold(v string) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldContainsFold(FieldNotes, v))
|
||||
}
|
||||
|
||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||
func CreatedAtEQ(v time.Time) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldEQ(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
|
||||
func CreatedAtNEQ(v time.Time) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldNEQ(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// CreatedAtIn applies the In predicate on the "created_at" field.
|
||||
func CreatedAtIn(vs ...time.Time) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldIn(FieldCreatedAt, vs...))
|
||||
}
|
||||
|
||||
// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
|
||||
func CreatedAtNotIn(vs ...time.Time) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldNotIn(FieldCreatedAt, vs...))
|
||||
}
|
||||
|
||||
// CreatedAtGT applies the GT predicate on the "created_at" field.
|
||||
func CreatedAtGT(v time.Time) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldGT(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// CreatedAtGTE applies the GTE predicate on the "created_at" field.
|
||||
func CreatedAtGTE(v time.Time) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldGTE(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// CreatedAtLT applies the LT predicate on the "created_at" field.
|
||||
func CreatedAtLT(v time.Time) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldLT(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// CreatedAtLTE applies the LTE predicate on the "created_at" field.
|
||||
func CreatedAtLTE(v time.Time) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldLTE(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
|
||||
func UpdatedAtEQ(v time.Time) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldEQ(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
|
||||
func UpdatedAtNEQ(v time.Time) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldNEQ(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAtIn applies the In predicate on the "updated_at" field.
|
||||
func UpdatedAtIn(vs ...time.Time) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldIn(FieldUpdatedAt, vs...))
|
||||
}
|
||||
|
||||
// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
|
||||
func UpdatedAtNotIn(vs ...time.Time) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldNotIn(FieldUpdatedAt, vs...))
|
||||
}
|
||||
|
||||
// UpdatedAtGT applies the GT predicate on the "updated_at" field.
|
||||
func UpdatedAtGT(v time.Time) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldGT(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
|
||||
func UpdatedAtGTE(v time.Time) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldGTE(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAtLT applies the LT predicate on the "updated_at" field.
|
||||
func UpdatedAtLT(v time.Time) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldLT(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
|
||||
func UpdatedAtLTE(v time.Time) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.FieldLTE(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// HasUsageRecords applies the HasEdge predicate on the "usage_records" edge.
|
||||
func HasUsageRecords() predicate.PromoCode {
|
||||
return predicate.PromoCode(func(s *sql.Selector) {
|
||||
step := sqlgraph.NewStep(
|
||||
sqlgraph.From(Table, FieldID),
|
||||
sqlgraph.Edge(sqlgraph.O2M, false, UsageRecordsTable, UsageRecordsColumn),
|
||||
)
|
||||
sqlgraph.HasNeighbors(s, step)
|
||||
})
|
||||
}
|
||||
|
||||
// HasUsageRecordsWith applies the HasEdge predicate on the "usage_records" edge with a given conditions (other predicates).
|
||||
func HasUsageRecordsWith(preds ...predicate.PromoCodeUsage) predicate.PromoCode {
|
||||
return predicate.PromoCode(func(s *sql.Selector) {
|
||||
step := newUsageRecordsStep()
|
||||
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
|
||||
for _, p := range preds {
|
||||
p(s)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// And groups predicates with the AND operator between them.
|
||||
func And(predicates ...predicate.PromoCode) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.AndPredicates(predicates...))
|
||||
}
|
||||
|
||||
// Or groups predicates with the OR operator between them.
|
||||
func Or(predicates ...predicate.PromoCode) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.OrPredicates(predicates...))
|
||||
}
|
||||
|
||||
// Not applies the not operator on the given predicate.
|
||||
func Not(p predicate.PromoCode) predicate.PromoCode {
|
||||
return predicate.PromoCode(sql.NotPredicates(p))
|
||||
}
|
||||
1081
backend/ent/promocode_create.go
Normal file
1081
backend/ent/promocode_create.go
Normal file
File diff suppressed because it is too large
Load Diff
88
backend/ent/promocode_delete.go
Normal file
88
backend/ent/promocode_delete.go
Normal file
@@ -0,0 +1,88 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package ent
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/schema/field"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocode"
|
||||
)
|
||||
|
||||
// PromoCodeDelete is the builder for deleting a PromoCode entity.
|
||||
type PromoCodeDelete struct {
|
||||
config
|
||||
hooks []Hook
|
||||
mutation *PromoCodeMutation
|
||||
}
|
||||
|
||||
// Where appends a list predicates to the PromoCodeDelete builder.
|
||||
func (_d *PromoCodeDelete) Where(ps ...predicate.PromoCode) *PromoCodeDelete {
|
||||
_d.mutation.Where(ps...)
|
||||
return _d
|
||||
}
|
||||
|
||||
// Exec executes the deletion query and returns how many vertices were deleted.
|
||||
func (_d *PromoCodeDelete) Exec(ctx context.Context) (int, error) {
|
||||
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
|
||||
}
|
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_d *PromoCodeDelete) ExecX(ctx context.Context) int {
|
||||
n, err := _d.Exec(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func (_d *PromoCodeDelete) sqlExec(ctx context.Context) (int, error) {
|
||||
_spec := sqlgraph.NewDeleteSpec(promocode.Table, sqlgraph.NewFieldSpec(promocode.FieldID, field.TypeInt64))
|
||||
if ps := _d.mutation.predicates; len(ps) > 0 {
|
||||
_spec.Predicate = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
|
||||
if err != nil && sqlgraph.IsConstraintError(err) {
|
||||
err = &ConstraintError{msg: err.Error(), wrap: err}
|
||||
}
|
||||
_d.mutation.done = true
|
||||
return affected, err
|
||||
}
|
||||
|
||||
// PromoCodeDeleteOne is the builder for deleting a single PromoCode entity.
|
||||
type PromoCodeDeleteOne struct {
|
||||
_d *PromoCodeDelete
|
||||
}
|
||||
|
||||
// Where appends a list predicates to the PromoCodeDelete builder.
|
||||
func (_d *PromoCodeDeleteOne) Where(ps ...predicate.PromoCode) *PromoCodeDeleteOne {
|
||||
_d._d.mutation.Where(ps...)
|
||||
return _d
|
||||
}
|
||||
|
||||
// Exec executes the deletion query.
|
||||
func (_d *PromoCodeDeleteOne) Exec(ctx context.Context) error {
|
||||
n, err := _d._d.Exec(ctx)
|
||||
switch {
|
||||
case err != nil:
|
||||
return err
|
||||
case n == 0:
|
||||
return &NotFoundError{promocode.Label}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_d *PromoCodeDeleteOne) ExecX(ctx context.Context) {
|
||||
if err := _d.Exec(ctx); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
643
backend/ent/promocode_query.go
Normal file
643
backend/ent/promocode_query.go
Normal file
@@ -0,0 +1,643 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package ent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/schema/field"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
|
||||
)
|
||||
|
||||
// PromoCodeQuery is the builder for querying PromoCode entities.
|
||||
type PromoCodeQuery struct {
|
||||
config
|
||||
ctx *QueryContext
|
||||
order []promocode.OrderOption
|
||||
inters []Interceptor
|
||||
predicates []predicate.PromoCode
|
||||
withUsageRecords *PromoCodeUsageQuery
|
||||
modifiers []func(*sql.Selector)
|
||||
// intermediate query (i.e. traversal path).
|
||||
sql *sql.Selector
|
||||
path func(context.Context) (*sql.Selector, error)
|
||||
}
|
||||
|
||||
// Where adds a new predicate for the PromoCodeQuery builder.
|
||||
func (_q *PromoCodeQuery) Where(ps ...predicate.PromoCode) *PromoCodeQuery {
|
||||
_q.predicates = append(_q.predicates, ps...)
|
||||
return _q
|
||||
}
|
||||
|
||||
// Limit the number of records to be returned by this query.
|
||||
func (_q *PromoCodeQuery) Limit(limit int) *PromoCodeQuery {
|
||||
_q.ctx.Limit = &limit
|
||||
return _q
|
||||
}
|
||||
|
||||
// Offset to start from.
|
||||
func (_q *PromoCodeQuery) Offset(offset int) *PromoCodeQuery {
|
||||
_q.ctx.Offset = &offset
|
||||
return _q
|
||||
}
|
||||
|
||||
// Unique configures the query builder to filter duplicate records on query.
|
||||
// By default, unique is set to true, and can be disabled using this method.
|
||||
func (_q *PromoCodeQuery) Unique(unique bool) *PromoCodeQuery {
|
||||
_q.ctx.Unique = &unique
|
||||
return _q
|
||||
}
|
||||
|
||||
// Order specifies how the records should be ordered.
|
||||
func (_q *PromoCodeQuery) Order(o ...promocode.OrderOption) *PromoCodeQuery {
|
||||
_q.order = append(_q.order, o...)
|
||||
return _q
|
||||
}
|
||||
|
||||
// QueryUsageRecords chains the current query on the "usage_records" edge.
|
||||
func (_q *PromoCodeQuery) QueryUsageRecords() *PromoCodeUsageQuery {
|
||||
query := (&PromoCodeUsageClient{config: _q.config}).Query()
|
||||
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
|
||||
if err := _q.prepareQuery(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selector := _q.sqlQuery(ctx)
|
||||
if err := selector.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
step := sqlgraph.NewStep(
|
||||
sqlgraph.From(promocode.Table, promocode.FieldID, selector),
|
||||
sqlgraph.To(promocodeusage.Table, promocodeusage.FieldID),
|
||||
sqlgraph.Edge(sqlgraph.O2M, false, promocode.UsageRecordsTable, promocode.UsageRecordsColumn),
|
||||
)
|
||||
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
|
||||
return fromU, nil
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
// First returns the first PromoCode entity from the query.
|
||||
// Returns a *NotFoundError when no PromoCode was found.
|
||||
func (_q *PromoCodeQuery) First(ctx context.Context) (*PromoCode, error) {
|
||||
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(nodes) == 0 {
|
||||
return nil, &NotFoundError{promocode.Label}
|
||||
}
|
||||
return nodes[0], nil
|
||||
}
|
||||
|
||||
// FirstX is like First, but panics if an error occurs.
|
||||
func (_q *PromoCodeQuery) FirstX(ctx context.Context) *PromoCode {
|
||||
node, err := _q.First(ctx)
|
||||
if err != nil && !IsNotFound(err) {
|
||||
panic(err)
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
// FirstID returns the first PromoCode ID from the query.
|
||||
// Returns a *NotFoundError when no PromoCode ID was found.
|
||||
func (_q *PromoCodeQuery) FirstID(ctx context.Context) (id int64, err error) {
|
||||
var ids []int64
|
||||
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
|
||||
return
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
err = &NotFoundError{promocode.Label}
|
||||
return
|
||||
}
|
||||
return ids[0], nil
|
||||
}
|
||||
|
||||
// FirstIDX is like FirstID, but panics if an error occurs.
|
||||
func (_q *PromoCodeQuery) FirstIDX(ctx context.Context) int64 {
|
||||
id, err := _q.FirstID(ctx)
|
||||
if err != nil && !IsNotFound(err) {
|
||||
panic(err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// Only returns a single PromoCode entity found by the query, ensuring it only returns one.
|
||||
// Returns a *NotSingularError when more than one PromoCode entity is found.
|
||||
// Returns a *NotFoundError when no PromoCode entities are found.
|
||||
func (_q *PromoCodeQuery) Only(ctx context.Context) (*PromoCode, error) {
|
||||
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch len(nodes) {
|
||||
case 1:
|
||||
return nodes[0], nil
|
||||
case 0:
|
||||
return nil, &NotFoundError{promocode.Label}
|
||||
default:
|
||||
return nil, &NotSingularError{promocode.Label}
|
||||
}
|
||||
}
|
||||
|
||||
// OnlyX is like Only, but panics if an error occurs.
|
||||
func (_q *PromoCodeQuery) OnlyX(ctx context.Context) *PromoCode {
|
||||
node, err := _q.Only(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
// OnlyID is like Only, but returns the only PromoCode ID in the query.
|
||||
// Returns a *NotSingularError when more than one PromoCode ID is found.
|
||||
// Returns a *NotFoundError when no entities are found.
|
||||
func (_q *PromoCodeQuery) OnlyID(ctx context.Context) (id int64, err error) {
|
||||
var ids []int64
|
||||
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
|
||||
return
|
||||
}
|
||||
switch len(ids) {
|
||||
case 1:
|
||||
id = ids[0]
|
||||
case 0:
|
||||
err = &NotFoundError{promocode.Label}
|
||||
default:
|
||||
err = &NotSingularError{promocode.Label}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// OnlyIDX is like OnlyID, but panics if an error occurs.
|
||||
func (_q *PromoCodeQuery) OnlyIDX(ctx context.Context) int64 {
|
||||
id, err := _q.OnlyID(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// All executes the query and returns a list of PromoCodes.
|
||||
func (_q *PromoCodeQuery) All(ctx context.Context) ([]*PromoCode, error) {
|
||||
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
|
||||
if err := _q.prepareQuery(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
qr := querierAll[[]*PromoCode, *PromoCodeQuery]()
|
||||
return withInterceptors[[]*PromoCode](ctx, _q, qr, _q.inters)
|
||||
}
|
||||
|
||||
// AllX is like All, but panics if an error occurs.
|
||||
func (_q *PromoCodeQuery) AllX(ctx context.Context) []*PromoCode {
|
||||
nodes, err := _q.All(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return nodes
|
||||
}
|
||||
|
||||
// IDs executes the query and returns a list of PromoCode IDs.
|
||||
func (_q *PromoCodeQuery) IDs(ctx context.Context) (ids []int64, err error) {
|
||||
if _q.ctx.Unique == nil && _q.path != nil {
|
||||
_q.Unique(true)
|
||||
}
|
||||
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
|
||||
if err = _q.Select(promocode.FieldID).Scan(ctx, &ids); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// IDsX is like IDs, but panics if an error occurs.
|
||||
func (_q *PromoCodeQuery) IDsX(ctx context.Context) []int64 {
|
||||
ids, err := _q.IDs(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// Count returns the count of the given query.
|
||||
func (_q *PromoCodeQuery) Count(ctx context.Context) (int, error) {
|
||||
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
|
||||
if err := _q.prepareQuery(ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return withInterceptors[int](ctx, _q, querierCount[*PromoCodeQuery](), _q.inters)
|
||||
}
|
||||
|
||||
// CountX is like Count, but panics if an error occurs.
|
||||
func (_q *PromoCodeQuery) CountX(ctx context.Context) int {
|
||||
count, err := _q.Count(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// Exist returns true if the query has elements in the graph.
|
||||
func (_q *PromoCodeQuery) Exist(ctx context.Context) (bool, error) {
|
||||
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
|
||||
switch _, err := _q.FirstID(ctx); {
|
||||
case IsNotFound(err):
|
||||
return false, nil
|
||||
case err != nil:
|
||||
return false, fmt.Errorf("ent: check existence: %w", err)
|
||||
default:
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
// ExistX is like Exist, but panics if an error occurs.
|
||||
func (_q *PromoCodeQuery) ExistX(ctx context.Context) bool {
|
||||
exist, err := _q.Exist(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return exist
|
||||
}
|
||||
|
||||
// Clone returns a duplicate of the PromoCodeQuery builder, including all associated steps. It can be
|
||||
// used to prepare common query builders and use them differently after the clone is made.
|
||||
func (_q *PromoCodeQuery) Clone() *PromoCodeQuery {
|
||||
if _q == nil {
|
||||
return nil
|
||||
}
|
||||
return &PromoCodeQuery{
|
||||
config: _q.config,
|
||||
ctx: _q.ctx.Clone(),
|
||||
order: append([]promocode.OrderOption{}, _q.order...),
|
||||
inters: append([]Interceptor{}, _q.inters...),
|
||||
predicates: append([]predicate.PromoCode{}, _q.predicates...),
|
||||
withUsageRecords: _q.withUsageRecords.Clone(),
|
||||
// clone intermediate query.
|
||||
sql: _q.sql.Clone(),
|
||||
path: _q.path,
|
||||
}
|
||||
}
|
||||
|
||||
// WithUsageRecords tells the query-builder to eager-load the nodes that are connected to
|
||||
// the "usage_records" edge. The optional arguments are used to configure the query builder of the edge.
|
||||
func (_q *PromoCodeQuery) WithUsageRecords(opts ...func(*PromoCodeUsageQuery)) *PromoCodeQuery {
|
||||
query := (&PromoCodeUsageClient{config: _q.config}).Query()
|
||||
for _, opt := range opts {
|
||||
opt(query)
|
||||
}
|
||||
_q.withUsageRecords = query
|
||||
return _q
|
||||
}
|
||||
|
||||
// GroupBy is used to group vertices by one or more fields/columns.
|
||||
// It is often used with aggregate functions, like: count, max, mean, min, sum.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// var v []struct {
|
||||
// Code string `json:"code,omitempty"`
|
||||
// Count int `json:"count,omitempty"`
|
||||
// }
|
||||
//
|
||||
// client.PromoCode.Query().
|
||||
// GroupBy(promocode.FieldCode).
|
||||
// Aggregate(ent.Count()).
|
||||
// Scan(ctx, &v)
|
||||
func (_q *PromoCodeQuery) GroupBy(field string, fields ...string) *PromoCodeGroupBy {
|
||||
_q.ctx.Fields = append([]string{field}, fields...)
|
||||
grbuild := &PromoCodeGroupBy{build: _q}
|
||||
grbuild.flds = &_q.ctx.Fields
|
||||
grbuild.label = promocode.Label
|
||||
grbuild.scan = grbuild.Scan
|
||||
return grbuild
|
||||
}
|
||||
|
||||
// Select allows the selection one or more fields/columns for the given query,
|
||||
// instead of selecting all fields in the entity.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// var v []struct {
|
||||
// Code string `json:"code,omitempty"`
|
||||
// }
|
||||
//
|
||||
// client.PromoCode.Query().
|
||||
// Select(promocode.FieldCode).
|
||||
// Scan(ctx, &v)
|
||||
func (_q *PromoCodeQuery) Select(fields ...string) *PromoCodeSelect {
|
||||
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
|
||||
sbuild := &PromoCodeSelect{PromoCodeQuery: _q}
|
||||
sbuild.label = promocode.Label
|
||||
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
|
||||
return sbuild
|
||||
}
|
||||
|
||||
// Aggregate returns a PromoCodeSelect configured with the given aggregations.
|
||||
func (_q *PromoCodeQuery) Aggregate(fns ...AggregateFunc) *PromoCodeSelect {
|
||||
return _q.Select().Aggregate(fns...)
|
||||
}
|
||||
|
||||
func (_q *PromoCodeQuery) prepareQuery(ctx context.Context) error {
|
||||
for _, inter := range _q.inters {
|
||||
if inter == nil {
|
||||
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
|
||||
}
|
||||
if trv, ok := inter.(Traverser); ok {
|
||||
if err := trv.Traverse(ctx, _q); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, f := range _q.ctx.Fields {
|
||||
if !promocode.ValidColumn(f) {
|
||||
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
|
||||
}
|
||||
}
|
||||
if _q.path != nil {
|
||||
prev, err := _q.path(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_q.sql = prev
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (_q *PromoCodeQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*PromoCode, error) {
|
||||
var (
|
||||
nodes = []*PromoCode{}
|
||||
_spec = _q.querySpec()
|
||||
loadedTypes = [1]bool{
|
||||
_q.withUsageRecords != nil,
|
||||
}
|
||||
)
|
||||
_spec.ScanValues = func(columns []string) ([]any, error) {
|
||||
return (*PromoCode).scanValues(nil, columns)
|
||||
}
|
||||
_spec.Assign = func(columns []string, values []any) error {
|
||||
node := &PromoCode{config: _q.config}
|
||||
nodes = append(nodes, node)
|
||||
node.Edges.loadedTypes = loadedTypes
|
||||
return node.assignValues(columns, values)
|
||||
}
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
for i := range hooks {
|
||||
hooks[i](ctx, _spec)
|
||||
}
|
||||
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(nodes) == 0 {
|
||||
return nodes, nil
|
||||
}
|
||||
if query := _q.withUsageRecords; query != nil {
|
||||
if err := _q.loadUsageRecords(ctx, query, nodes,
|
||||
func(n *PromoCode) { n.Edges.UsageRecords = []*PromoCodeUsage{} },
|
||||
func(n *PromoCode, e *PromoCodeUsage) { n.Edges.UsageRecords = append(n.Edges.UsageRecords, e) }); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
func (_q *PromoCodeQuery) loadUsageRecords(ctx context.Context, query *PromoCodeUsageQuery, nodes []*PromoCode, init func(*PromoCode), assign func(*PromoCode, *PromoCodeUsage)) error {
|
||||
fks := make([]driver.Value, 0, len(nodes))
|
||||
nodeids := make(map[int64]*PromoCode)
|
||||
for i := range nodes {
|
||||
fks = append(fks, nodes[i].ID)
|
||||
nodeids[nodes[i].ID] = nodes[i]
|
||||
if init != nil {
|
||||
init(nodes[i])
|
||||
}
|
||||
}
|
||||
if len(query.ctx.Fields) > 0 {
|
||||
query.ctx.AppendFieldOnce(promocodeusage.FieldPromoCodeID)
|
||||
}
|
||||
query.Where(predicate.PromoCodeUsage(func(s *sql.Selector) {
|
||||
s.Where(sql.InValues(s.C(promocode.UsageRecordsColumn), fks...))
|
||||
}))
|
||||
neighbors, err := query.All(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, n := range neighbors {
|
||||
fk := n.PromoCodeID
|
||||
node, ok := nodeids[fk]
|
||||
if !ok {
|
||||
return fmt.Errorf(`unexpected referenced foreign-key "promo_code_id" returned %v for node %v`, fk, n.ID)
|
||||
}
|
||||
assign(node, n)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (_q *PromoCodeQuery) sqlCount(ctx context.Context) (int, error) {
|
||||
_spec := _q.querySpec()
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
_spec.Node.Columns = _q.ctx.Fields
|
||||
if len(_q.ctx.Fields) > 0 {
|
||||
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
|
||||
}
|
||||
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
|
||||
}
|
||||
|
||||
func (_q *PromoCodeQuery) querySpec() *sqlgraph.QuerySpec {
|
||||
_spec := sqlgraph.NewQuerySpec(promocode.Table, promocode.Columns, sqlgraph.NewFieldSpec(promocode.FieldID, field.TypeInt64))
|
||||
_spec.From = _q.sql
|
||||
if unique := _q.ctx.Unique; unique != nil {
|
||||
_spec.Unique = *unique
|
||||
} else if _q.path != nil {
|
||||
_spec.Unique = true
|
||||
}
|
||||
if fields := _q.ctx.Fields; len(fields) > 0 {
|
||||
_spec.Node.Columns = make([]string, 0, len(fields))
|
||||
_spec.Node.Columns = append(_spec.Node.Columns, promocode.FieldID)
|
||||
for i := range fields {
|
||||
if fields[i] != promocode.FieldID {
|
||||
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
if ps := _q.predicates; len(ps) > 0 {
|
||||
_spec.Predicate = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
if limit := _q.ctx.Limit; limit != nil {
|
||||
_spec.Limit = *limit
|
||||
}
|
||||
if offset := _q.ctx.Offset; offset != nil {
|
||||
_spec.Offset = *offset
|
||||
}
|
||||
if ps := _q.order; len(ps) > 0 {
|
||||
_spec.Order = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
return _spec
|
||||
}
|
||||
|
||||
func (_q *PromoCodeQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
builder := sql.Dialect(_q.driver.Dialect())
|
||||
t1 := builder.Table(promocode.Table)
|
||||
columns := _q.ctx.Fields
|
||||
if len(columns) == 0 {
|
||||
columns = promocode.Columns
|
||||
}
|
||||
selector := builder.Select(t1.Columns(columns...)...).From(t1)
|
||||
if _q.sql != nil {
|
||||
selector = _q.sql
|
||||
selector.Select(selector.Columns(columns...)...)
|
||||
}
|
||||
if _q.ctx.Unique != nil && *_q.ctx.Unique {
|
||||
selector.Distinct()
|
||||
}
|
||||
for _, m := range _q.modifiers {
|
||||
m(selector)
|
||||
}
|
||||
for _, p := range _q.predicates {
|
||||
p(selector)
|
||||
}
|
||||
for _, p := range _q.order {
|
||||
p(selector)
|
||||
}
|
||||
if offset := _q.ctx.Offset; offset != nil {
|
||||
// limit is mandatory for offset clause. We start
|
||||
// with default value, and override it below if needed.
|
||||
selector.Offset(*offset).Limit(math.MaxInt32)
|
||||
}
|
||||
if limit := _q.ctx.Limit; limit != nil {
|
||||
selector.Limit(*limit)
|
||||
}
|
||||
return selector
|
||||
}
|
||||
|
||||
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
|
||||
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
|
||||
// either committed or rolled-back.
|
||||
func (_q *PromoCodeQuery) ForUpdate(opts ...sql.LockOption) *PromoCodeQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForUpdate(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
|
||||
// on any rows that are read. Other sessions can read the rows, but cannot modify them
|
||||
// until your transaction commits.
|
||||
func (_q *PromoCodeQuery) ForShare(opts ...sql.LockOption) *PromoCodeQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForShare(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// PromoCodeGroupBy is the group-by builder for PromoCode entities.
|
||||
type PromoCodeGroupBy struct {
|
||||
selector
|
||||
build *PromoCodeQuery
|
||||
}
|
||||
|
||||
// Aggregate adds the given aggregation functions to the group-by query.
|
||||
func (_g *PromoCodeGroupBy) Aggregate(fns ...AggregateFunc) *PromoCodeGroupBy {
|
||||
_g.fns = append(_g.fns, fns...)
|
||||
return _g
|
||||
}
|
||||
|
||||
// Scan applies the selector query and scans the result into the given value.
|
||||
func (_g *PromoCodeGroupBy) Scan(ctx context.Context, v any) error {
|
||||
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
|
||||
if err := _g.build.prepareQuery(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return scanWithInterceptors[*PromoCodeQuery, *PromoCodeGroupBy](ctx, _g.build, _g, _g.build.inters, v)
|
||||
}
|
||||
|
||||
func (_g *PromoCodeGroupBy) sqlScan(ctx context.Context, root *PromoCodeQuery, v any) error {
|
||||
selector := root.sqlQuery(ctx).Select()
|
||||
aggregation := make([]string, 0, len(_g.fns))
|
||||
for _, fn := range _g.fns {
|
||||
aggregation = append(aggregation, fn(selector))
|
||||
}
|
||||
if len(selector.SelectedColumns()) == 0 {
|
||||
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
|
||||
for _, f := range *_g.flds {
|
||||
columns = append(columns, selector.C(f))
|
||||
}
|
||||
columns = append(columns, aggregation...)
|
||||
selector.Select(columns...)
|
||||
}
|
||||
selector.GroupBy(selector.Columns(*_g.flds...)...)
|
||||
if err := selector.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
rows := &sql.Rows{}
|
||||
query, args := selector.Query()
|
||||
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
return sql.ScanSlice(rows, v)
|
||||
}
|
||||
|
||||
// PromoCodeSelect is the builder for selecting fields of PromoCode entities.
|
||||
type PromoCodeSelect struct {
|
||||
*PromoCodeQuery
|
||||
selector
|
||||
}
|
||||
|
||||
// Aggregate adds the given aggregation functions to the selector query.
|
||||
func (_s *PromoCodeSelect) Aggregate(fns ...AggregateFunc) *PromoCodeSelect {
|
||||
_s.fns = append(_s.fns, fns...)
|
||||
return _s
|
||||
}
|
||||
|
||||
// Scan applies the selector query and scans the result into the given value.
|
||||
func (_s *PromoCodeSelect) Scan(ctx context.Context, v any) error {
|
||||
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
|
||||
if err := _s.prepareQuery(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return scanWithInterceptors[*PromoCodeQuery, *PromoCodeSelect](ctx, _s.PromoCodeQuery, _s, _s.inters, v)
|
||||
}
|
||||
|
||||
func (_s *PromoCodeSelect) sqlScan(ctx context.Context, root *PromoCodeQuery, v any) error {
|
||||
selector := root.sqlQuery(ctx)
|
||||
aggregation := make([]string, 0, len(_s.fns))
|
||||
for _, fn := range _s.fns {
|
||||
aggregation = append(aggregation, fn(selector))
|
||||
}
|
||||
switch n := len(*_s.selector.flds); {
|
||||
case n == 0 && len(aggregation) > 0:
|
||||
selector.Select(aggregation...)
|
||||
case n != 0 && len(aggregation) > 0:
|
||||
selector.AppendSelect(aggregation...)
|
||||
}
|
||||
rows := &sql.Rows{}
|
||||
query, args := selector.Query()
|
||||
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
return sql.ScanSlice(rows, v)
|
||||
}
|
||||
745
backend/ent/promocode_update.go
Normal file
745
backend/ent/promocode_update.go
Normal file
@@ -0,0 +1,745 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package ent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/schema/field"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
|
||||
)
|
||||
|
||||
// PromoCodeUpdate is the builder for updating PromoCode entities.
|
||||
type PromoCodeUpdate struct {
|
||||
config
|
||||
hooks []Hook
|
||||
mutation *PromoCodeMutation
|
||||
}
|
||||
|
||||
// Where appends a list predicates to the PromoCodeUpdate builder.
|
||||
func (_u *PromoCodeUpdate) Where(ps ...predicate.PromoCode) *PromoCodeUpdate {
|
||||
_u.mutation.Where(ps...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetCode sets the "code" field.
|
||||
func (_u *PromoCodeUpdate) SetCode(v string) *PromoCodeUpdate {
|
||||
_u.mutation.SetCode(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableCode sets the "code" field if the given value is not nil.
|
||||
func (_u *PromoCodeUpdate) SetNillableCode(v *string) *PromoCodeUpdate {
|
||||
if v != nil {
|
||||
_u.SetCode(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetBonusAmount sets the "bonus_amount" field.
|
||||
func (_u *PromoCodeUpdate) SetBonusAmount(v float64) *PromoCodeUpdate {
|
||||
_u.mutation.ResetBonusAmount()
|
||||
_u.mutation.SetBonusAmount(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableBonusAmount sets the "bonus_amount" field if the given value is not nil.
|
||||
func (_u *PromoCodeUpdate) SetNillableBonusAmount(v *float64) *PromoCodeUpdate {
|
||||
if v != nil {
|
||||
_u.SetBonusAmount(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddBonusAmount adds value to the "bonus_amount" field.
|
||||
func (_u *PromoCodeUpdate) AddBonusAmount(v float64) *PromoCodeUpdate {
|
||||
_u.mutation.AddBonusAmount(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetMaxUses sets the "max_uses" field.
|
||||
func (_u *PromoCodeUpdate) SetMaxUses(v int) *PromoCodeUpdate {
|
||||
_u.mutation.ResetMaxUses()
|
||||
_u.mutation.SetMaxUses(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableMaxUses sets the "max_uses" field if the given value is not nil.
|
||||
func (_u *PromoCodeUpdate) SetNillableMaxUses(v *int) *PromoCodeUpdate {
|
||||
if v != nil {
|
||||
_u.SetMaxUses(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddMaxUses adds value to the "max_uses" field.
|
||||
func (_u *PromoCodeUpdate) AddMaxUses(v int) *PromoCodeUpdate {
|
||||
_u.mutation.AddMaxUses(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUsedCount sets the "used_count" field.
|
||||
func (_u *PromoCodeUpdate) SetUsedCount(v int) *PromoCodeUpdate {
|
||||
_u.mutation.ResetUsedCount()
|
||||
_u.mutation.SetUsedCount(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableUsedCount sets the "used_count" field if the given value is not nil.
|
||||
func (_u *PromoCodeUpdate) SetNillableUsedCount(v *int) *PromoCodeUpdate {
|
||||
if v != nil {
|
||||
_u.SetUsedCount(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddUsedCount adds value to the "used_count" field.
|
||||
func (_u *PromoCodeUpdate) AddUsedCount(v int) *PromoCodeUpdate {
|
||||
_u.mutation.AddUsedCount(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetStatus sets the "status" field.
|
||||
func (_u *PromoCodeUpdate) SetStatus(v string) *PromoCodeUpdate {
|
||||
_u.mutation.SetStatus(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableStatus sets the "status" field if the given value is not nil.
|
||||
func (_u *PromoCodeUpdate) SetNillableStatus(v *string) *PromoCodeUpdate {
|
||||
if v != nil {
|
||||
_u.SetStatus(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetExpiresAt sets the "expires_at" field.
|
||||
func (_u *PromoCodeUpdate) SetExpiresAt(v time.Time) *PromoCodeUpdate {
|
||||
_u.mutation.SetExpiresAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
|
||||
func (_u *PromoCodeUpdate) SetNillableExpiresAt(v *time.Time) *PromoCodeUpdate {
|
||||
if v != nil {
|
||||
_u.SetExpiresAt(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearExpiresAt clears the value of the "expires_at" field.
|
||||
func (_u *PromoCodeUpdate) ClearExpiresAt() *PromoCodeUpdate {
|
||||
_u.mutation.ClearExpiresAt()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNotes sets the "notes" field.
|
||||
func (_u *PromoCodeUpdate) SetNotes(v string) *PromoCodeUpdate {
|
||||
_u.mutation.SetNotes(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableNotes sets the "notes" field if the given value is not nil.
|
||||
func (_u *PromoCodeUpdate) SetNillableNotes(v *string) *PromoCodeUpdate {
|
||||
if v != nil {
|
||||
_u.SetNotes(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearNotes clears the value of the "notes" field.
|
||||
func (_u *PromoCodeUpdate) ClearNotes() *PromoCodeUpdate {
|
||||
_u.mutation.ClearNotes()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUpdatedAt sets the "updated_at" field.
|
||||
func (_u *PromoCodeUpdate) SetUpdatedAt(v time.Time) *PromoCodeUpdate {
|
||||
_u.mutation.SetUpdatedAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddUsageRecordIDs adds the "usage_records" edge to the PromoCodeUsage entity by IDs.
|
||||
func (_u *PromoCodeUpdate) AddUsageRecordIDs(ids ...int64) *PromoCodeUpdate {
|
||||
_u.mutation.AddUsageRecordIDs(ids...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddUsageRecords adds the "usage_records" edges to the PromoCodeUsage entity.
|
||||
func (_u *PromoCodeUpdate) AddUsageRecords(v ...*PromoCodeUsage) *PromoCodeUpdate {
|
||||
ids := make([]int64, len(v))
|
||||
for i := range v {
|
||||
ids[i] = v[i].ID
|
||||
}
|
||||
return _u.AddUsageRecordIDs(ids...)
|
||||
}
|
||||
|
||||
// Mutation returns the PromoCodeMutation object of the builder.
|
||||
func (_u *PromoCodeUpdate) Mutation() *PromoCodeMutation {
|
||||
return _u.mutation
|
||||
}
|
||||
|
||||
// ClearUsageRecords clears all "usage_records" edges to the PromoCodeUsage entity.
|
||||
func (_u *PromoCodeUpdate) ClearUsageRecords() *PromoCodeUpdate {
|
||||
_u.mutation.ClearUsageRecords()
|
||||
return _u
|
||||
}
|
||||
|
||||
// RemoveUsageRecordIDs removes the "usage_records" edge to PromoCodeUsage entities by IDs.
|
||||
func (_u *PromoCodeUpdate) RemoveUsageRecordIDs(ids ...int64) *PromoCodeUpdate {
|
||||
_u.mutation.RemoveUsageRecordIDs(ids...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// RemoveUsageRecords removes "usage_records" edges to PromoCodeUsage entities.
|
||||
func (_u *PromoCodeUpdate) RemoveUsageRecords(v ...*PromoCodeUsage) *PromoCodeUpdate {
|
||||
ids := make([]int64, len(v))
|
||||
for i := range v {
|
||||
ids[i] = v[i].ID
|
||||
}
|
||||
return _u.RemoveUsageRecordIDs(ids...)
|
||||
}
|
||||
|
||||
// Save executes the query and returns the number of nodes affected by the update operation.
|
||||
func (_u *PromoCodeUpdate) Save(ctx context.Context) (int, error) {
|
||||
_u.defaults()
|
||||
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
|
||||
}
|
||||
|
||||
// SaveX is like Save, but panics if an error occurs.
|
||||
func (_u *PromoCodeUpdate) SaveX(ctx context.Context) int {
|
||||
affected, err := _u.Save(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return affected
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (_u *PromoCodeUpdate) Exec(ctx context.Context) error {
|
||||
_, err := _u.Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_u *PromoCodeUpdate) ExecX(ctx context.Context) {
|
||||
if err := _u.Exec(ctx); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// defaults sets the default values of the builder before save.
|
||||
func (_u *PromoCodeUpdate) defaults() {
|
||||
if _, ok := _u.mutation.UpdatedAt(); !ok {
|
||||
v := promocode.UpdateDefaultUpdatedAt()
|
||||
_u.mutation.SetUpdatedAt(v)
|
||||
}
|
||||
}
|
||||
|
||||
// check runs all checks and user-defined validators on the builder.
|
||||
func (_u *PromoCodeUpdate) check() error {
|
||||
if v, ok := _u.mutation.Code(); ok {
|
||||
if err := promocode.CodeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "code", err: fmt.Errorf(`ent: validator failed for field "PromoCode.code": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.Status(); ok {
|
||||
if err := promocode.StatusValidator(v); err != nil {
|
||||
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "PromoCode.status": %w`, err)}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (_u *PromoCodeUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if err := _u.check(); err != nil {
|
||||
return _node, err
|
||||
}
|
||||
_spec := sqlgraph.NewUpdateSpec(promocode.Table, promocode.Columns, sqlgraph.NewFieldSpec(promocode.FieldID, field.TypeInt64))
|
||||
if ps := _u.mutation.predicates; len(ps) > 0 {
|
||||
_spec.Predicate = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
if value, ok := _u.mutation.Code(); ok {
|
||||
_spec.SetField(promocode.FieldCode, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.BonusAmount(); ok {
|
||||
_spec.SetField(promocode.FieldBonusAmount, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedBonusAmount(); ok {
|
||||
_spec.AddField(promocode.FieldBonusAmount, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.MaxUses(); ok {
|
||||
_spec.SetField(promocode.FieldMaxUses, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedMaxUses(); ok {
|
||||
_spec.AddField(promocode.FieldMaxUses, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.UsedCount(); ok {
|
||||
_spec.SetField(promocode.FieldUsedCount, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedUsedCount(); ok {
|
||||
_spec.AddField(promocode.FieldUsedCount, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Status(); ok {
|
||||
_spec.SetField(promocode.FieldStatus, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.ExpiresAt(); ok {
|
||||
_spec.SetField(promocode.FieldExpiresAt, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.ExpiresAtCleared() {
|
||||
_spec.ClearField(promocode.FieldExpiresAt, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.Notes(); ok {
|
||||
_spec.SetField(promocode.FieldNotes, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.NotesCleared() {
|
||||
_spec.ClearField(promocode.FieldNotes, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.UpdatedAt(); ok {
|
||||
_spec.SetField(promocode.FieldUpdatedAt, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.UsageRecordsCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Inverse: false,
|
||||
Table: promocode.UsageRecordsTable,
|
||||
Columns: []string{promocode.UsageRecordsColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||
}
|
||||
if nodes := _u.mutation.RemovedUsageRecordsIDs(); len(nodes) > 0 && !_u.mutation.UsageRecordsCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Inverse: false,
|
||||
Table: promocode.UsageRecordsTable,
|
||||
Columns: []string{promocode.UsageRecordsColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
for _, k := range nodes {
|
||||
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||
}
|
||||
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||
}
|
||||
if nodes := _u.mutation.UsageRecordsIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Inverse: false,
|
||||
Table: promocode.UsageRecordsTable,
|
||||
Columns: []string{promocode.UsageRecordsColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
for _, k := range nodes {
|
||||
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||
}
|
||||
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||
}
|
||||
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
|
||||
if _, ok := err.(*sqlgraph.NotFoundError); ok {
|
||||
err = &NotFoundError{promocode.Label}
|
||||
} else if sqlgraph.IsConstraintError(err) {
|
||||
err = &ConstraintError{msg: err.Error(), wrap: err}
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
_u.mutation.done = true
|
||||
return _node, nil
|
||||
}
|
||||
|
||||
// PromoCodeUpdateOne is the builder for updating a single PromoCode entity.
|
||||
type PromoCodeUpdateOne struct {
|
||||
config
|
||||
fields []string
|
||||
hooks []Hook
|
||||
mutation *PromoCodeMutation
|
||||
}
|
||||
|
||||
// SetCode sets the "code" field.
|
||||
func (_u *PromoCodeUpdateOne) SetCode(v string) *PromoCodeUpdateOne {
|
||||
_u.mutation.SetCode(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableCode sets the "code" field if the given value is not nil.
|
||||
func (_u *PromoCodeUpdateOne) SetNillableCode(v *string) *PromoCodeUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetCode(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetBonusAmount sets the "bonus_amount" field.
|
||||
func (_u *PromoCodeUpdateOne) SetBonusAmount(v float64) *PromoCodeUpdateOne {
|
||||
_u.mutation.ResetBonusAmount()
|
||||
_u.mutation.SetBonusAmount(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableBonusAmount sets the "bonus_amount" field if the given value is not nil.
|
||||
func (_u *PromoCodeUpdateOne) SetNillableBonusAmount(v *float64) *PromoCodeUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetBonusAmount(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddBonusAmount adds value to the "bonus_amount" field.
|
||||
func (_u *PromoCodeUpdateOne) AddBonusAmount(v float64) *PromoCodeUpdateOne {
|
||||
_u.mutation.AddBonusAmount(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetMaxUses sets the "max_uses" field.
|
||||
func (_u *PromoCodeUpdateOne) SetMaxUses(v int) *PromoCodeUpdateOne {
|
||||
_u.mutation.ResetMaxUses()
|
||||
_u.mutation.SetMaxUses(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableMaxUses sets the "max_uses" field if the given value is not nil.
|
||||
func (_u *PromoCodeUpdateOne) SetNillableMaxUses(v *int) *PromoCodeUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetMaxUses(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddMaxUses adds value to the "max_uses" field.
|
||||
func (_u *PromoCodeUpdateOne) AddMaxUses(v int) *PromoCodeUpdateOne {
|
||||
_u.mutation.AddMaxUses(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUsedCount sets the "used_count" field.
|
||||
func (_u *PromoCodeUpdateOne) SetUsedCount(v int) *PromoCodeUpdateOne {
|
||||
_u.mutation.ResetUsedCount()
|
||||
_u.mutation.SetUsedCount(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableUsedCount sets the "used_count" field if the given value is not nil.
|
||||
func (_u *PromoCodeUpdateOne) SetNillableUsedCount(v *int) *PromoCodeUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetUsedCount(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddUsedCount adds value to the "used_count" field.
|
||||
func (_u *PromoCodeUpdateOne) AddUsedCount(v int) *PromoCodeUpdateOne {
|
||||
_u.mutation.AddUsedCount(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetStatus sets the "status" field.
|
||||
func (_u *PromoCodeUpdateOne) SetStatus(v string) *PromoCodeUpdateOne {
|
||||
_u.mutation.SetStatus(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableStatus sets the "status" field if the given value is not nil.
|
||||
func (_u *PromoCodeUpdateOne) SetNillableStatus(v *string) *PromoCodeUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetStatus(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetExpiresAt sets the "expires_at" field.
|
||||
func (_u *PromoCodeUpdateOne) SetExpiresAt(v time.Time) *PromoCodeUpdateOne {
|
||||
_u.mutation.SetExpiresAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
|
||||
func (_u *PromoCodeUpdateOne) SetNillableExpiresAt(v *time.Time) *PromoCodeUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetExpiresAt(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearExpiresAt clears the value of the "expires_at" field.
|
||||
func (_u *PromoCodeUpdateOne) ClearExpiresAt() *PromoCodeUpdateOne {
|
||||
_u.mutation.ClearExpiresAt()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNotes sets the "notes" field.
|
||||
func (_u *PromoCodeUpdateOne) SetNotes(v string) *PromoCodeUpdateOne {
|
||||
_u.mutation.SetNotes(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableNotes sets the "notes" field if the given value is not nil.
|
||||
func (_u *PromoCodeUpdateOne) SetNillableNotes(v *string) *PromoCodeUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetNotes(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearNotes clears the value of the "notes" field.
|
||||
func (_u *PromoCodeUpdateOne) ClearNotes() *PromoCodeUpdateOne {
|
||||
_u.mutation.ClearNotes()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUpdatedAt sets the "updated_at" field.
|
||||
func (_u *PromoCodeUpdateOne) SetUpdatedAt(v time.Time) *PromoCodeUpdateOne {
|
||||
_u.mutation.SetUpdatedAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddUsageRecordIDs adds the "usage_records" edge to the PromoCodeUsage entity by IDs.
|
||||
func (_u *PromoCodeUpdateOne) AddUsageRecordIDs(ids ...int64) *PromoCodeUpdateOne {
|
||||
_u.mutation.AddUsageRecordIDs(ids...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddUsageRecords adds the "usage_records" edges to the PromoCodeUsage entity.
|
||||
func (_u *PromoCodeUpdateOne) AddUsageRecords(v ...*PromoCodeUsage) *PromoCodeUpdateOne {
|
||||
ids := make([]int64, len(v))
|
||||
for i := range v {
|
||||
ids[i] = v[i].ID
|
||||
}
|
||||
return _u.AddUsageRecordIDs(ids...)
|
||||
}
|
||||
|
||||
// Mutation returns the PromoCodeMutation object of the builder.
|
||||
func (_u *PromoCodeUpdateOne) Mutation() *PromoCodeMutation {
|
||||
return _u.mutation
|
||||
}
|
||||
|
||||
// ClearUsageRecords clears all "usage_records" edges to the PromoCodeUsage entity.
|
||||
func (_u *PromoCodeUpdateOne) ClearUsageRecords() *PromoCodeUpdateOne {
|
||||
_u.mutation.ClearUsageRecords()
|
||||
return _u
|
||||
}
|
||||
|
||||
// RemoveUsageRecordIDs removes the "usage_records" edge to PromoCodeUsage entities by IDs.
|
||||
func (_u *PromoCodeUpdateOne) RemoveUsageRecordIDs(ids ...int64) *PromoCodeUpdateOne {
|
||||
_u.mutation.RemoveUsageRecordIDs(ids...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// RemoveUsageRecords removes "usage_records" edges to PromoCodeUsage entities.
|
||||
func (_u *PromoCodeUpdateOne) RemoveUsageRecords(v ...*PromoCodeUsage) *PromoCodeUpdateOne {
|
||||
ids := make([]int64, len(v))
|
||||
for i := range v {
|
||||
ids[i] = v[i].ID
|
||||
}
|
||||
return _u.RemoveUsageRecordIDs(ids...)
|
||||
}
|
||||
|
||||
// Where appends a list predicates to the PromoCodeUpdate builder.
|
||||
func (_u *PromoCodeUpdateOne) Where(ps ...predicate.PromoCode) *PromoCodeUpdateOne {
|
||||
_u.mutation.Where(ps...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// Select allows selecting one or more fields (columns) of the returned entity.
|
||||
// The default is selecting all fields defined in the entity schema.
|
||||
func (_u *PromoCodeUpdateOne) Select(field string, fields ...string) *PromoCodeUpdateOne {
|
||||
_u.fields = append([]string{field}, fields...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// Save executes the query and returns the updated PromoCode entity.
|
||||
func (_u *PromoCodeUpdateOne) Save(ctx context.Context) (*PromoCode, error) {
|
||||
_u.defaults()
|
||||
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
|
||||
}
|
||||
|
||||
// SaveX is like Save, but panics if an error occurs.
|
||||
func (_u *PromoCodeUpdateOne) SaveX(ctx context.Context) *PromoCode {
|
||||
node, err := _u.Save(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
// Exec executes the query on the entity.
|
||||
func (_u *PromoCodeUpdateOne) Exec(ctx context.Context) error {
|
||||
_, err := _u.Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_u *PromoCodeUpdateOne) ExecX(ctx context.Context) {
|
||||
if err := _u.Exec(ctx); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// defaults sets the default values of the builder before save.
|
||||
func (_u *PromoCodeUpdateOne) defaults() {
|
||||
if _, ok := _u.mutation.UpdatedAt(); !ok {
|
||||
v := promocode.UpdateDefaultUpdatedAt()
|
||||
_u.mutation.SetUpdatedAt(v)
|
||||
}
|
||||
}
|
||||
|
||||
// check runs all checks and user-defined validators on the builder.
|
||||
func (_u *PromoCodeUpdateOne) check() error {
|
||||
if v, ok := _u.mutation.Code(); ok {
|
||||
if err := promocode.CodeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "code", err: fmt.Errorf(`ent: validator failed for field "PromoCode.code": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.Status(); ok {
|
||||
if err := promocode.StatusValidator(v); err != nil {
|
||||
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "PromoCode.status": %w`, err)}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (_u *PromoCodeUpdateOne) sqlSave(ctx context.Context) (_node *PromoCode, err error) {
|
||||
if err := _u.check(); err != nil {
|
||||
return _node, err
|
||||
}
|
||||
_spec := sqlgraph.NewUpdateSpec(promocode.Table, promocode.Columns, sqlgraph.NewFieldSpec(promocode.FieldID, field.TypeInt64))
|
||||
id, ok := _u.mutation.ID()
|
||||
if !ok {
|
||||
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "PromoCode.id" for update`)}
|
||||
}
|
||||
_spec.Node.ID.Value = id
|
||||
if fields := _u.fields; len(fields) > 0 {
|
||||
_spec.Node.Columns = make([]string, 0, len(fields))
|
||||
_spec.Node.Columns = append(_spec.Node.Columns, promocode.FieldID)
|
||||
for _, f := range fields {
|
||||
if !promocode.ValidColumn(f) {
|
||||
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
|
||||
}
|
||||
if f != promocode.FieldID {
|
||||
_spec.Node.Columns = append(_spec.Node.Columns, f)
|
||||
}
|
||||
}
|
||||
}
|
||||
if ps := _u.mutation.predicates; len(ps) > 0 {
|
||||
_spec.Predicate = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
if value, ok := _u.mutation.Code(); ok {
|
||||
_spec.SetField(promocode.FieldCode, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.BonusAmount(); ok {
|
||||
_spec.SetField(promocode.FieldBonusAmount, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedBonusAmount(); ok {
|
||||
_spec.AddField(promocode.FieldBonusAmount, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.MaxUses(); ok {
|
||||
_spec.SetField(promocode.FieldMaxUses, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedMaxUses(); ok {
|
||||
_spec.AddField(promocode.FieldMaxUses, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.UsedCount(); ok {
|
||||
_spec.SetField(promocode.FieldUsedCount, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedUsedCount(); ok {
|
||||
_spec.AddField(promocode.FieldUsedCount, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Status(); ok {
|
||||
_spec.SetField(promocode.FieldStatus, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.ExpiresAt(); ok {
|
||||
_spec.SetField(promocode.FieldExpiresAt, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.ExpiresAtCleared() {
|
||||
_spec.ClearField(promocode.FieldExpiresAt, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.Notes(); ok {
|
||||
_spec.SetField(promocode.FieldNotes, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.NotesCleared() {
|
||||
_spec.ClearField(promocode.FieldNotes, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.UpdatedAt(); ok {
|
||||
_spec.SetField(promocode.FieldUpdatedAt, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.UsageRecordsCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Inverse: false,
|
||||
Table: promocode.UsageRecordsTable,
|
||||
Columns: []string{promocode.UsageRecordsColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||
}
|
||||
if nodes := _u.mutation.RemovedUsageRecordsIDs(); len(nodes) > 0 && !_u.mutation.UsageRecordsCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Inverse: false,
|
||||
Table: promocode.UsageRecordsTable,
|
||||
Columns: []string{promocode.UsageRecordsColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
for _, k := range nodes {
|
||||
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||
}
|
||||
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||
}
|
||||
if nodes := _u.mutation.UsageRecordsIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Inverse: false,
|
||||
Table: promocode.UsageRecordsTable,
|
||||
Columns: []string{promocode.UsageRecordsColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
for _, k := range nodes {
|
||||
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||
}
|
||||
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||
}
|
||||
_node = &PromoCode{config: _u.config}
|
||||
_spec.Assign = _node.assignValues
|
||||
_spec.ScanValues = _node.scanValues
|
||||
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
|
||||
if _, ok := err.(*sqlgraph.NotFoundError); ok {
|
||||
err = &NotFoundError{promocode.Label}
|
||||
} else if sqlgraph.IsConstraintError(err) {
|
||||
err = &ConstraintError{msg: err.Error(), wrap: err}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
_u.mutation.done = true
|
||||
return _node, nil
|
||||
}
|
||||
187
backend/ent/promocodeusage.go
Normal file
187
backend/ent/promocodeusage.go
Normal file
@@ -0,0 +1,187 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package ent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
)
|
||||
|
||||
// PromoCodeUsage is the model entity for the PromoCodeUsage schema.
|
||||
type PromoCodeUsage struct {
|
||||
config `json:"-"`
|
||||
// ID of the ent.
|
||||
ID int64 `json:"id,omitempty"`
|
||||
// 优惠码ID
|
||||
PromoCodeID int64 `json:"promo_code_id,omitempty"`
|
||||
// 使用用户ID
|
||||
UserID int64 `json:"user_id,omitempty"`
|
||||
// 实际赠送金额
|
||||
BonusAmount float64 `json:"bonus_amount,omitempty"`
|
||||
// 使用时间
|
||||
UsedAt time.Time `json:"used_at,omitempty"`
|
||||
// Edges holds the relations/edges for other nodes in the graph.
|
||||
// The values are being populated by the PromoCodeUsageQuery when eager-loading is set.
|
||||
Edges PromoCodeUsageEdges `json:"edges"`
|
||||
selectValues sql.SelectValues
|
||||
}
|
||||
|
||||
// PromoCodeUsageEdges holds the relations/edges for other nodes in the graph.
|
||||
type PromoCodeUsageEdges struct {
|
||||
// PromoCode holds the value of the promo_code edge.
|
||||
PromoCode *PromoCode `json:"promo_code,omitempty"`
|
||||
// User holds the value of the user edge.
|
||||
User *User `json:"user,omitempty"`
|
||||
// loadedTypes holds the information for reporting if a
|
||||
// type was loaded (or requested) in eager-loading or not.
|
||||
loadedTypes [2]bool
|
||||
}
|
||||
|
||||
// PromoCodeOrErr returns the PromoCode value or an error if the edge
|
||||
// was not loaded in eager-loading, or loaded but was not found.
|
||||
func (e PromoCodeUsageEdges) PromoCodeOrErr() (*PromoCode, error) {
|
||||
if e.PromoCode != nil {
|
||||
return e.PromoCode, nil
|
||||
} else if e.loadedTypes[0] {
|
||||
return nil, &NotFoundError{label: promocode.Label}
|
||||
}
|
||||
return nil, &NotLoadedError{edge: "promo_code"}
|
||||
}
|
||||
|
||||
// UserOrErr returns the User value or an error if the edge
|
||||
// was not loaded in eager-loading, or loaded but was not found.
|
||||
func (e PromoCodeUsageEdges) UserOrErr() (*User, error) {
|
||||
if e.User != nil {
|
||||
return e.User, nil
|
||||
} else if e.loadedTypes[1] {
|
||||
return nil, &NotFoundError{label: user.Label}
|
||||
}
|
||||
return nil, &NotLoadedError{edge: "user"}
|
||||
}
|
||||
|
||||
// scanValues returns the types for scanning values from sql.Rows.
|
||||
func (*PromoCodeUsage) scanValues(columns []string) ([]any, error) {
|
||||
values := make([]any, len(columns))
|
||||
for i := range columns {
|
||||
switch columns[i] {
|
||||
case promocodeusage.FieldBonusAmount:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case promocodeusage.FieldID, promocodeusage.FieldPromoCodeID, promocodeusage.FieldUserID:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case promocodeusage.FieldUsedAt:
|
||||
values[i] = new(sql.NullTime)
|
||||
default:
|
||||
values[i] = new(sql.UnknownType)
|
||||
}
|
||||
}
|
||||
return values, nil
|
||||
}
|
||||
|
||||
// assignValues assigns the values that were returned from sql.Rows (after scanning)
|
||||
// to the PromoCodeUsage fields.
|
||||
func (_m *PromoCodeUsage) assignValues(columns []string, values []any) error {
|
||||
if m, n := len(values), len(columns); m < n {
|
||||
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
|
||||
}
|
||||
for i := range columns {
|
||||
switch columns[i] {
|
||||
case promocodeusage.FieldID:
|
||||
value, ok := values[i].(*sql.NullInt64)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field id", value)
|
||||
}
|
||||
_m.ID = int64(value.Int64)
|
||||
case promocodeusage.FieldPromoCodeID:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field promo_code_id", values[i])
|
||||
} else if value.Valid {
|
||||
_m.PromoCodeID = value.Int64
|
||||
}
|
||||
case promocodeusage.FieldUserID:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field user_id", values[i])
|
||||
} else if value.Valid {
|
||||
_m.UserID = value.Int64
|
||||
}
|
||||
case promocodeusage.FieldBonusAmount:
|
||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field bonus_amount", values[i])
|
||||
} else if value.Valid {
|
||||
_m.BonusAmount = value.Float64
|
||||
}
|
||||
case promocodeusage.FieldUsedAt:
|
||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field used_at", values[i])
|
||||
} else if value.Valid {
|
||||
_m.UsedAt = value.Time
|
||||
}
|
||||
default:
|
||||
_m.selectValues.Set(columns[i], values[i])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value returns the ent.Value that was dynamically selected and assigned to the PromoCodeUsage.
|
||||
// This includes values selected through modifiers, order, etc.
|
||||
func (_m *PromoCodeUsage) Value(name string) (ent.Value, error) {
|
||||
return _m.selectValues.Get(name)
|
||||
}
|
||||
|
||||
// QueryPromoCode queries the "promo_code" edge of the PromoCodeUsage entity.
|
||||
func (_m *PromoCodeUsage) QueryPromoCode() *PromoCodeQuery {
|
||||
return NewPromoCodeUsageClient(_m.config).QueryPromoCode(_m)
|
||||
}
|
||||
|
||||
// QueryUser queries the "user" edge of the PromoCodeUsage entity.
|
||||
func (_m *PromoCodeUsage) QueryUser() *UserQuery {
|
||||
return NewPromoCodeUsageClient(_m.config).QueryUser(_m)
|
||||
}
|
||||
|
||||
// Update returns a builder for updating this PromoCodeUsage.
|
||||
// Note that you need to call PromoCodeUsage.Unwrap() before calling this method if this PromoCodeUsage
|
||||
// was returned from a transaction, and the transaction was committed or rolled back.
|
||||
func (_m *PromoCodeUsage) Update() *PromoCodeUsageUpdateOne {
|
||||
return NewPromoCodeUsageClient(_m.config).UpdateOne(_m)
|
||||
}
|
||||
|
||||
// Unwrap unwraps the PromoCodeUsage entity that was returned from a transaction after it was closed,
|
||||
// so that all future queries will be executed through the driver which created the transaction.
|
||||
func (_m *PromoCodeUsage) Unwrap() *PromoCodeUsage {
|
||||
_tx, ok := _m.config.driver.(*txDriver)
|
||||
if !ok {
|
||||
panic("ent: PromoCodeUsage is not a transactional entity")
|
||||
}
|
||||
_m.config.driver = _tx.drv
|
||||
return _m
|
||||
}
|
||||
|
||||
// String implements the fmt.Stringer.
|
||||
func (_m *PromoCodeUsage) String() string {
|
||||
var builder strings.Builder
|
||||
builder.WriteString("PromoCodeUsage(")
|
||||
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
|
||||
builder.WriteString("promo_code_id=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.PromoCodeID))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("user_id=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.UserID))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("bonus_amount=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.BonusAmount))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("used_at=")
|
||||
builder.WriteString(_m.UsedAt.Format(time.ANSIC))
|
||||
builder.WriteByte(')')
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// PromoCodeUsages is a parsable slice of PromoCodeUsage.
|
||||
type PromoCodeUsages []*PromoCodeUsage
|
||||
125
backend/ent/promocodeusage/promocodeusage.go
Normal file
125
backend/ent/promocodeusage/promocodeusage.go
Normal file
@@ -0,0 +1,125 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package promocodeusage
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
)
|
||||
|
||||
const (
|
||||
// Label holds the string label denoting the promocodeusage type in the database.
|
||||
Label = "promo_code_usage"
|
||||
// FieldID holds the string denoting the id field in the database.
|
||||
FieldID = "id"
|
||||
// FieldPromoCodeID holds the string denoting the promo_code_id field in the database.
|
||||
FieldPromoCodeID = "promo_code_id"
|
||||
// FieldUserID holds the string denoting the user_id field in the database.
|
||||
FieldUserID = "user_id"
|
||||
// FieldBonusAmount holds the string denoting the bonus_amount field in the database.
|
||||
FieldBonusAmount = "bonus_amount"
|
||||
// FieldUsedAt holds the string denoting the used_at field in the database.
|
||||
FieldUsedAt = "used_at"
|
||||
// EdgePromoCode holds the string denoting the promo_code edge name in mutations.
|
||||
EdgePromoCode = "promo_code"
|
||||
// EdgeUser holds the string denoting the user edge name in mutations.
|
||||
EdgeUser = "user"
|
||||
// Table holds the table name of the promocodeusage in the database.
|
||||
Table = "promo_code_usages"
|
||||
// PromoCodeTable is the table that holds the promo_code relation/edge.
|
||||
PromoCodeTable = "promo_code_usages"
|
||||
// PromoCodeInverseTable is the table name for the PromoCode entity.
|
||||
// It exists in this package in order to avoid circular dependency with the "promocode" package.
|
||||
PromoCodeInverseTable = "promo_codes"
|
||||
// PromoCodeColumn is the table column denoting the promo_code relation/edge.
|
||||
PromoCodeColumn = "promo_code_id"
|
||||
// UserTable is the table that holds the user relation/edge.
|
||||
UserTable = "promo_code_usages"
|
||||
// UserInverseTable is the table name for the User entity.
|
||||
// It exists in this package in order to avoid circular dependency with the "user" package.
|
||||
UserInverseTable = "users"
|
||||
// UserColumn is the table column denoting the user relation/edge.
|
||||
UserColumn = "user_id"
|
||||
)
|
||||
|
||||
// Columns holds all SQL columns for promocodeusage fields.
|
||||
var Columns = []string{
|
||||
FieldID,
|
||||
FieldPromoCodeID,
|
||||
FieldUserID,
|
||||
FieldBonusAmount,
|
||||
FieldUsedAt,
|
||||
}
|
||||
|
||||
// ValidColumn reports if the column name is valid (part of the table columns).
|
||||
func ValidColumn(column string) bool {
|
||||
for i := range Columns {
|
||||
if column == Columns[i] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var (
|
||||
// DefaultUsedAt holds the default value on creation for the "used_at" field.
|
||||
DefaultUsedAt func() time.Time
|
||||
)
|
||||
|
||||
// OrderOption defines the ordering options for the PromoCodeUsage queries.
|
||||
type OrderOption func(*sql.Selector)
|
||||
|
||||
// ByID orders the results by the id field.
|
||||
func ByID(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldID, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByPromoCodeID orders the results by the promo_code_id field.
|
||||
func ByPromoCodeID(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldPromoCodeID, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByUserID orders the results by the user_id field.
|
||||
func ByUserID(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldUserID, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByBonusAmount orders the results by the bonus_amount field.
|
||||
func ByBonusAmount(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldBonusAmount, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByUsedAt orders the results by the used_at field.
|
||||
func ByUsedAt(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldUsedAt, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByPromoCodeField orders the results by promo_code field.
|
||||
func ByPromoCodeField(field string, opts ...sql.OrderTermOption) OrderOption {
|
||||
return func(s *sql.Selector) {
|
||||
sqlgraph.OrderByNeighborTerms(s, newPromoCodeStep(), sql.OrderByField(field, opts...))
|
||||
}
|
||||
}
|
||||
|
||||
// ByUserField orders the results by user field.
|
||||
func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption {
|
||||
return func(s *sql.Selector) {
|
||||
sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...))
|
||||
}
|
||||
}
|
||||
func newPromoCodeStep() *sqlgraph.Step {
|
||||
return sqlgraph.NewStep(
|
||||
sqlgraph.From(Table, FieldID),
|
||||
sqlgraph.To(PromoCodeInverseTable, FieldID),
|
||||
sqlgraph.Edge(sqlgraph.M2O, true, PromoCodeTable, PromoCodeColumn),
|
||||
)
|
||||
}
|
||||
func newUserStep() *sqlgraph.Step {
|
||||
return sqlgraph.NewStep(
|
||||
sqlgraph.From(Table, FieldID),
|
||||
sqlgraph.To(UserInverseTable, FieldID),
|
||||
sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
|
||||
)
|
||||
}
|
||||
257
backend/ent/promocodeusage/where.go
Normal file
257
backend/ent/promocodeusage/where.go
Normal file
@@ -0,0 +1,257 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package promocodeusage
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
)
|
||||
|
||||
// ID filters vertices based on their ID field.
|
||||
func ID(id int64) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldEQ(FieldID, id))
|
||||
}
|
||||
|
||||
// IDEQ applies the EQ predicate on the ID field.
|
||||
func IDEQ(id int64) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldEQ(FieldID, id))
|
||||
}
|
||||
|
||||
// IDNEQ applies the NEQ predicate on the ID field.
|
||||
func IDNEQ(id int64) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldNEQ(FieldID, id))
|
||||
}
|
||||
|
||||
// IDIn applies the In predicate on the ID field.
|
||||
func IDIn(ids ...int64) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldIn(FieldID, ids...))
|
||||
}
|
||||
|
||||
// IDNotIn applies the NotIn predicate on the ID field.
|
||||
func IDNotIn(ids ...int64) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldNotIn(FieldID, ids...))
|
||||
}
|
||||
|
||||
// IDGT applies the GT predicate on the ID field.
|
||||
func IDGT(id int64) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldGT(FieldID, id))
|
||||
}
|
||||
|
||||
// IDGTE applies the GTE predicate on the ID field.
|
||||
func IDGTE(id int64) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldGTE(FieldID, id))
|
||||
}
|
||||
|
||||
// IDLT applies the LT predicate on the ID field.
|
||||
func IDLT(id int64) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldLT(FieldID, id))
|
||||
}
|
||||
|
||||
// IDLTE applies the LTE predicate on the ID field.
|
||||
func IDLTE(id int64) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldLTE(FieldID, id))
|
||||
}
|
||||
|
||||
// PromoCodeID applies equality check predicate on the "promo_code_id" field. It's identical to PromoCodeIDEQ.
|
||||
func PromoCodeID(v int64) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldEQ(FieldPromoCodeID, v))
|
||||
}
|
||||
|
||||
// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ.
|
||||
func UserID(v int64) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldEQ(FieldUserID, v))
|
||||
}
|
||||
|
||||
// BonusAmount applies equality check predicate on the "bonus_amount" field. It's identical to BonusAmountEQ.
|
||||
func BonusAmount(v float64) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldEQ(FieldBonusAmount, v))
|
||||
}
|
||||
|
||||
// UsedAt applies equality check predicate on the "used_at" field. It's identical to UsedAtEQ.
|
||||
func UsedAt(v time.Time) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldEQ(FieldUsedAt, v))
|
||||
}
|
||||
|
||||
// PromoCodeIDEQ applies the EQ predicate on the "promo_code_id" field.
|
||||
func PromoCodeIDEQ(v int64) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldEQ(FieldPromoCodeID, v))
|
||||
}
|
||||
|
||||
// PromoCodeIDNEQ applies the NEQ predicate on the "promo_code_id" field.
|
||||
func PromoCodeIDNEQ(v int64) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldNEQ(FieldPromoCodeID, v))
|
||||
}
|
||||
|
||||
// PromoCodeIDIn applies the In predicate on the "promo_code_id" field.
|
||||
func PromoCodeIDIn(vs ...int64) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldIn(FieldPromoCodeID, vs...))
|
||||
}
|
||||
|
||||
// PromoCodeIDNotIn applies the NotIn predicate on the "promo_code_id" field.
|
||||
func PromoCodeIDNotIn(vs ...int64) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldNotIn(FieldPromoCodeID, vs...))
|
||||
}
|
||||
|
||||
// UserIDEQ applies the EQ predicate on the "user_id" field.
|
||||
func UserIDEQ(v int64) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldEQ(FieldUserID, v))
|
||||
}
|
||||
|
||||
// UserIDNEQ applies the NEQ predicate on the "user_id" field.
|
||||
func UserIDNEQ(v int64) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldNEQ(FieldUserID, v))
|
||||
}
|
||||
|
||||
// UserIDIn applies the In predicate on the "user_id" field.
|
||||
func UserIDIn(vs ...int64) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldIn(FieldUserID, vs...))
|
||||
}
|
||||
|
||||
// UserIDNotIn applies the NotIn predicate on the "user_id" field.
|
||||
func UserIDNotIn(vs ...int64) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldNotIn(FieldUserID, vs...))
|
||||
}
|
||||
|
||||
// BonusAmountEQ applies the EQ predicate on the "bonus_amount" field.
|
||||
func BonusAmountEQ(v float64) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldEQ(FieldBonusAmount, v))
|
||||
}
|
||||
|
||||
// BonusAmountNEQ applies the NEQ predicate on the "bonus_amount" field.
|
||||
func BonusAmountNEQ(v float64) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldNEQ(FieldBonusAmount, v))
|
||||
}
|
||||
|
||||
// BonusAmountIn applies the In predicate on the "bonus_amount" field.
|
||||
func BonusAmountIn(vs ...float64) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldIn(FieldBonusAmount, vs...))
|
||||
}
|
||||
|
||||
// BonusAmountNotIn applies the NotIn predicate on the "bonus_amount" field.
|
||||
func BonusAmountNotIn(vs ...float64) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldNotIn(FieldBonusAmount, vs...))
|
||||
}
|
||||
|
||||
// BonusAmountGT applies the GT predicate on the "bonus_amount" field.
|
||||
func BonusAmountGT(v float64) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldGT(FieldBonusAmount, v))
|
||||
}
|
||||
|
||||
// BonusAmountGTE applies the GTE predicate on the "bonus_amount" field.
|
||||
func BonusAmountGTE(v float64) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldGTE(FieldBonusAmount, v))
|
||||
}
|
||||
|
||||
// BonusAmountLT applies the LT predicate on the "bonus_amount" field.
|
||||
func BonusAmountLT(v float64) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldLT(FieldBonusAmount, v))
|
||||
}
|
||||
|
||||
// BonusAmountLTE applies the LTE predicate on the "bonus_amount" field.
|
||||
func BonusAmountLTE(v float64) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldLTE(FieldBonusAmount, v))
|
||||
}
|
||||
|
||||
// UsedAtEQ applies the EQ predicate on the "used_at" field.
|
||||
func UsedAtEQ(v time.Time) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldEQ(FieldUsedAt, v))
|
||||
}
|
||||
|
||||
// UsedAtNEQ applies the NEQ predicate on the "used_at" field.
|
||||
func UsedAtNEQ(v time.Time) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldNEQ(FieldUsedAt, v))
|
||||
}
|
||||
|
||||
// UsedAtIn applies the In predicate on the "used_at" field.
|
||||
func UsedAtIn(vs ...time.Time) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldIn(FieldUsedAt, vs...))
|
||||
}
|
||||
|
||||
// UsedAtNotIn applies the NotIn predicate on the "used_at" field.
|
||||
func UsedAtNotIn(vs ...time.Time) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldNotIn(FieldUsedAt, vs...))
|
||||
}
|
||||
|
||||
// UsedAtGT applies the GT predicate on the "used_at" field.
|
||||
func UsedAtGT(v time.Time) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldGT(FieldUsedAt, v))
|
||||
}
|
||||
|
||||
// UsedAtGTE applies the GTE predicate on the "used_at" field.
|
||||
func UsedAtGTE(v time.Time) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldGTE(FieldUsedAt, v))
|
||||
}
|
||||
|
||||
// UsedAtLT applies the LT predicate on the "used_at" field.
|
||||
func UsedAtLT(v time.Time) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldLT(FieldUsedAt, v))
|
||||
}
|
||||
|
||||
// UsedAtLTE applies the LTE predicate on the "used_at" field.
|
||||
func UsedAtLTE(v time.Time) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.FieldLTE(FieldUsedAt, v))
|
||||
}
|
||||
|
||||
// HasPromoCode applies the HasEdge predicate on the "promo_code" edge.
|
||||
func HasPromoCode() predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(func(s *sql.Selector) {
|
||||
step := sqlgraph.NewStep(
|
||||
sqlgraph.From(Table, FieldID),
|
||||
sqlgraph.Edge(sqlgraph.M2O, true, PromoCodeTable, PromoCodeColumn),
|
||||
)
|
||||
sqlgraph.HasNeighbors(s, step)
|
||||
})
|
||||
}
|
||||
|
||||
// HasPromoCodeWith applies the HasEdge predicate on the "promo_code" edge with a given conditions (other predicates).
|
||||
func HasPromoCodeWith(preds ...predicate.PromoCode) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(func(s *sql.Selector) {
|
||||
step := newPromoCodeStep()
|
||||
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
|
||||
for _, p := range preds {
|
||||
p(s)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// HasUser applies the HasEdge predicate on the "user" edge.
|
||||
func HasUser() predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(func(s *sql.Selector) {
|
||||
step := sqlgraph.NewStep(
|
||||
sqlgraph.From(Table, FieldID),
|
||||
sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
|
||||
)
|
||||
sqlgraph.HasNeighbors(s, step)
|
||||
})
|
||||
}
|
||||
|
||||
// HasUserWith applies the HasEdge predicate on the "user" edge with a given conditions (other predicates).
|
||||
func HasUserWith(preds ...predicate.User) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(func(s *sql.Selector) {
|
||||
step := newUserStep()
|
||||
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
|
||||
for _, p := range preds {
|
||||
p(s)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// And groups predicates with the AND operator between them.
|
||||
func And(predicates ...predicate.PromoCodeUsage) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.AndPredicates(predicates...))
|
||||
}
|
||||
|
||||
// Or groups predicates with the OR operator between them.
|
||||
func Or(predicates ...predicate.PromoCodeUsage) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.OrPredicates(predicates...))
|
||||
}
|
||||
|
||||
// Not applies the not operator on the given predicate.
|
||||
func Not(p predicate.PromoCodeUsage) predicate.PromoCodeUsage {
|
||||
return predicate.PromoCodeUsage(sql.NotPredicates(p))
|
||||
}
|
||||
696
backend/ent/promocodeusage_create.go
Normal file
696
backend/ent/promocodeusage_create.go
Normal file
@@ -0,0 +1,696 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package ent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/schema/field"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
)
|
||||
|
||||
// PromoCodeUsageCreate is the builder for creating a PromoCodeUsage entity.
|
||||
type PromoCodeUsageCreate struct {
|
||||
config
|
||||
mutation *PromoCodeUsageMutation
|
||||
hooks []Hook
|
||||
conflict []sql.ConflictOption
|
||||
}
|
||||
|
||||
// SetPromoCodeID sets the "promo_code_id" field.
|
||||
func (_c *PromoCodeUsageCreate) SetPromoCodeID(v int64) *PromoCodeUsageCreate {
|
||||
_c.mutation.SetPromoCodeID(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetUserID sets the "user_id" field.
|
||||
func (_c *PromoCodeUsageCreate) SetUserID(v int64) *PromoCodeUsageCreate {
|
||||
_c.mutation.SetUserID(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetBonusAmount sets the "bonus_amount" field.
|
||||
func (_c *PromoCodeUsageCreate) SetBonusAmount(v float64) *PromoCodeUsageCreate {
|
||||
_c.mutation.SetBonusAmount(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetUsedAt sets the "used_at" field.
|
||||
func (_c *PromoCodeUsageCreate) SetUsedAt(v time.Time) *PromoCodeUsageCreate {
|
||||
_c.mutation.SetUsedAt(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableUsedAt sets the "used_at" field if the given value is not nil.
|
||||
func (_c *PromoCodeUsageCreate) SetNillableUsedAt(v *time.Time) *PromoCodeUsageCreate {
|
||||
if v != nil {
|
||||
_c.SetUsedAt(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetPromoCode sets the "promo_code" edge to the PromoCode entity.
|
||||
func (_c *PromoCodeUsageCreate) SetPromoCode(v *PromoCode) *PromoCodeUsageCreate {
|
||||
return _c.SetPromoCodeID(v.ID)
|
||||
}
|
||||
|
||||
// SetUser sets the "user" edge to the User entity.
|
||||
func (_c *PromoCodeUsageCreate) SetUser(v *User) *PromoCodeUsageCreate {
|
||||
return _c.SetUserID(v.ID)
|
||||
}
|
||||
|
||||
// Mutation returns the PromoCodeUsageMutation object of the builder.
|
||||
func (_c *PromoCodeUsageCreate) Mutation() *PromoCodeUsageMutation {
|
||||
return _c.mutation
|
||||
}
|
||||
|
||||
// Save creates the PromoCodeUsage in the database.
|
||||
func (_c *PromoCodeUsageCreate) Save(ctx context.Context) (*PromoCodeUsage, error) {
|
||||
_c.defaults()
|
||||
return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
|
||||
}
|
||||
|
||||
// SaveX calls Save and panics if Save returns an error.
|
||||
func (_c *PromoCodeUsageCreate) SaveX(ctx context.Context) *PromoCodeUsage {
|
||||
v, err := _c.Save(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (_c *PromoCodeUsageCreate) Exec(ctx context.Context) error {
|
||||
_, err := _c.Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_c *PromoCodeUsageCreate) ExecX(ctx context.Context) {
|
||||
if err := _c.Exec(ctx); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// defaults sets the default values of the builder before save.
|
||||
func (_c *PromoCodeUsageCreate) defaults() {
|
||||
if _, ok := _c.mutation.UsedAt(); !ok {
|
||||
v := promocodeusage.DefaultUsedAt()
|
||||
_c.mutation.SetUsedAt(v)
|
||||
}
|
||||
}
|
||||
|
||||
// check runs all checks and user-defined validators on the builder.
|
||||
func (_c *PromoCodeUsageCreate) check() error {
|
||||
if _, ok := _c.mutation.PromoCodeID(); !ok {
|
||||
return &ValidationError{Name: "promo_code_id", err: errors.New(`ent: missing required field "PromoCodeUsage.promo_code_id"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.UserID(); !ok {
|
||||
return &ValidationError{Name: "user_id", err: errors.New(`ent: missing required field "PromoCodeUsage.user_id"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.BonusAmount(); !ok {
|
||||
return &ValidationError{Name: "bonus_amount", err: errors.New(`ent: missing required field "PromoCodeUsage.bonus_amount"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.UsedAt(); !ok {
|
||||
return &ValidationError{Name: "used_at", err: errors.New(`ent: missing required field "PromoCodeUsage.used_at"`)}
|
||||
}
|
||||
if len(_c.mutation.PromoCodeIDs()) == 0 {
|
||||
return &ValidationError{Name: "promo_code", err: errors.New(`ent: missing required edge "PromoCodeUsage.promo_code"`)}
|
||||
}
|
||||
if len(_c.mutation.UserIDs()) == 0 {
|
||||
return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "PromoCodeUsage.user"`)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (_c *PromoCodeUsageCreate) sqlSave(ctx context.Context) (*PromoCodeUsage, error) {
|
||||
if err := _c.check(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_node, _spec := _c.createSpec()
|
||||
if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
|
||||
if sqlgraph.IsConstraintError(err) {
|
||||
err = &ConstraintError{msg: err.Error(), wrap: err}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
id := _spec.ID.Value.(int64)
|
||||
_node.ID = int64(id)
|
||||
_c.mutation.id = &_node.ID
|
||||
_c.mutation.done = true
|
||||
return _node, nil
|
||||
}
|
||||
|
||||
func (_c *PromoCodeUsageCreate) createSpec() (*PromoCodeUsage, *sqlgraph.CreateSpec) {
|
||||
var (
|
||||
_node = &PromoCodeUsage{config: _c.config}
|
||||
_spec = sqlgraph.NewCreateSpec(promocodeusage.Table, sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64))
|
||||
)
|
||||
_spec.OnConflict = _c.conflict
|
||||
if value, ok := _c.mutation.BonusAmount(); ok {
|
||||
_spec.SetField(promocodeusage.FieldBonusAmount, field.TypeFloat64, value)
|
||||
_node.BonusAmount = value
|
||||
}
|
||||
if value, ok := _c.mutation.UsedAt(); ok {
|
||||
_spec.SetField(promocodeusage.FieldUsedAt, field.TypeTime, value)
|
||||
_node.UsedAt = value
|
||||
}
|
||||
if nodes := _c.mutation.PromoCodeIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2O,
|
||||
Inverse: true,
|
||||
Table: promocodeusage.PromoCodeTable,
|
||||
Columns: []string{promocodeusage.PromoCodeColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(promocode.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
for _, k := range nodes {
|
||||
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||
}
|
||||
_node.PromoCodeID = nodes[0]
|
||||
_spec.Edges = append(_spec.Edges, edge)
|
||||
}
|
||||
if nodes := _c.mutation.UserIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2O,
|
||||
Inverse: true,
|
||||
Table: promocodeusage.UserTable,
|
||||
Columns: []string{promocodeusage.UserColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
for _, k := range nodes {
|
||||
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||
}
|
||||
_node.UserID = nodes[0]
|
||||
_spec.Edges = append(_spec.Edges, edge)
|
||||
}
|
||||
return _node, _spec
|
||||
}
|
||||
|
||||
// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
|
||||
// of the `INSERT` statement. For example:
|
||||
//
|
||||
// client.PromoCodeUsage.Create().
|
||||
// SetPromoCodeID(v).
|
||||
// OnConflict(
|
||||
// // Update the row with the new values
|
||||
// // the was proposed for insertion.
|
||||
// sql.ResolveWithNewValues(),
|
||||
// ).
|
||||
// // Override some of the fields with custom
|
||||
// // update values.
|
||||
// Update(func(u *ent.PromoCodeUsageUpsert) {
|
||||
// SetPromoCodeID(v+v).
|
||||
// }).
|
||||
// Exec(ctx)
|
||||
func (_c *PromoCodeUsageCreate) OnConflict(opts ...sql.ConflictOption) *PromoCodeUsageUpsertOne {
|
||||
_c.conflict = opts
|
||||
return &PromoCodeUsageUpsertOne{
|
||||
create: _c,
|
||||
}
|
||||
}
|
||||
|
||||
// OnConflictColumns calls `OnConflict` and configures the columns
|
||||
// as conflict target. Using this option is equivalent to using:
|
||||
//
|
||||
// client.PromoCodeUsage.Create().
|
||||
// OnConflict(sql.ConflictColumns(columns...)).
|
||||
// Exec(ctx)
|
||||
func (_c *PromoCodeUsageCreate) OnConflictColumns(columns ...string) *PromoCodeUsageUpsertOne {
|
||||
_c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
|
||||
return &PromoCodeUsageUpsertOne{
|
||||
create: _c,
|
||||
}
|
||||
}
|
||||
|
||||
type (
|
||||
// PromoCodeUsageUpsertOne is the builder for "upsert"-ing
|
||||
// one PromoCodeUsage node.
|
||||
PromoCodeUsageUpsertOne struct {
|
||||
create *PromoCodeUsageCreate
|
||||
}
|
||||
|
||||
// PromoCodeUsageUpsert is the "OnConflict" setter.
|
||||
PromoCodeUsageUpsert struct {
|
||||
*sql.UpdateSet
|
||||
}
|
||||
)
|
||||
|
||||
// SetPromoCodeID sets the "promo_code_id" field.
|
||||
func (u *PromoCodeUsageUpsert) SetPromoCodeID(v int64) *PromoCodeUsageUpsert {
|
||||
u.Set(promocodeusage.FieldPromoCodeID, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdatePromoCodeID sets the "promo_code_id" field to the value that was provided on create.
|
||||
func (u *PromoCodeUsageUpsert) UpdatePromoCodeID() *PromoCodeUsageUpsert {
|
||||
u.SetExcluded(promocodeusage.FieldPromoCodeID)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetUserID sets the "user_id" field.
|
||||
func (u *PromoCodeUsageUpsert) SetUserID(v int64) *PromoCodeUsageUpsert {
|
||||
u.Set(promocodeusage.FieldUserID, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateUserID sets the "user_id" field to the value that was provided on create.
|
||||
func (u *PromoCodeUsageUpsert) UpdateUserID() *PromoCodeUsageUpsert {
|
||||
u.SetExcluded(promocodeusage.FieldUserID)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetBonusAmount sets the "bonus_amount" field.
|
||||
func (u *PromoCodeUsageUpsert) SetBonusAmount(v float64) *PromoCodeUsageUpsert {
|
||||
u.Set(promocodeusage.FieldBonusAmount, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateBonusAmount sets the "bonus_amount" field to the value that was provided on create.
|
||||
func (u *PromoCodeUsageUpsert) UpdateBonusAmount() *PromoCodeUsageUpsert {
|
||||
u.SetExcluded(promocodeusage.FieldBonusAmount)
|
||||
return u
|
||||
}
|
||||
|
||||
// AddBonusAmount adds v to the "bonus_amount" field.
|
||||
func (u *PromoCodeUsageUpsert) AddBonusAmount(v float64) *PromoCodeUsageUpsert {
|
||||
u.Add(promocodeusage.FieldBonusAmount, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetUsedAt sets the "used_at" field.
|
||||
func (u *PromoCodeUsageUpsert) SetUsedAt(v time.Time) *PromoCodeUsageUpsert {
|
||||
u.Set(promocodeusage.FieldUsedAt, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateUsedAt sets the "used_at" field to the value that was provided on create.
|
||||
func (u *PromoCodeUsageUpsert) UpdateUsedAt() *PromoCodeUsageUpsert {
|
||||
u.SetExcluded(promocodeusage.FieldUsedAt)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||
// Using this option is equivalent to using:
|
||||
//
|
||||
// client.PromoCodeUsage.Create().
|
||||
// OnConflict(
|
||||
// sql.ResolveWithNewValues(),
|
||||
// ).
|
||||
// Exec(ctx)
|
||||
func (u *PromoCodeUsageUpsertOne) UpdateNewValues() *PromoCodeUsageUpsertOne {
|
||||
u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
|
||||
return u
|
||||
}
|
||||
|
||||
// Ignore sets each column to itself in case of conflict.
|
||||
// Using this option is equivalent to using:
|
||||
//
|
||||
// client.PromoCodeUsage.Create().
|
||||
// OnConflict(sql.ResolveWithIgnore()).
|
||||
// Exec(ctx)
|
||||
func (u *PromoCodeUsageUpsertOne) Ignore() *PromoCodeUsageUpsertOne {
|
||||
u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
|
||||
return u
|
||||
}
|
||||
|
||||
// DoNothing configures the conflict_action to `DO NOTHING`.
|
||||
// Supported only by SQLite and PostgreSQL.
|
||||
func (u *PromoCodeUsageUpsertOne) DoNothing() *PromoCodeUsageUpsertOne {
|
||||
u.create.conflict = append(u.create.conflict, sql.DoNothing())
|
||||
return u
|
||||
}
|
||||
|
||||
// Update allows overriding fields `UPDATE` values. See the PromoCodeUsageCreate.OnConflict
|
||||
// documentation for more info.
|
||||
func (u *PromoCodeUsageUpsertOne) Update(set func(*PromoCodeUsageUpsert)) *PromoCodeUsageUpsertOne {
|
||||
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
|
||||
set(&PromoCodeUsageUpsert{UpdateSet: update})
|
||||
}))
|
||||
return u
|
||||
}
|
||||
|
||||
// SetPromoCodeID sets the "promo_code_id" field.
|
||||
func (u *PromoCodeUsageUpsertOne) SetPromoCodeID(v int64) *PromoCodeUsageUpsertOne {
|
||||
return u.Update(func(s *PromoCodeUsageUpsert) {
|
||||
s.SetPromoCodeID(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdatePromoCodeID sets the "promo_code_id" field to the value that was provided on create.
|
||||
func (u *PromoCodeUsageUpsertOne) UpdatePromoCodeID() *PromoCodeUsageUpsertOne {
|
||||
return u.Update(func(s *PromoCodeUsageUpsert) {
|
||||
s.UpdatePromoCodeID()
|
||||
})
|
||||
}
|
||||
|
||||
// SetUserID sets the "user_id" field.
|
||||
func (u *PromoCodeUsageUpsertOne) SetUserID(v int64) *PromoCodeUsageUpsertOne {
|
||||
return u.Update(func(s *PromoCodeUsageUpsert) {
|
||||
s.SetUserID(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateUserID sets the "user_id" field to the value that was provided on create.
|
||||
func (u *PromoCodeUsageUpsertOne) UpdateUserID() *PromoCodeUsageUpsertOne {
|
||||
return u.Update(func(s *PromoCodeUsageUpsert) {
|
||||
s.UpdateUserID()
|
||||
})
|
||||
}
|
||||
|
||||
// SetBonusAmount sets the "bonus_amount" field.
|
||||
func (u *PromoCodeUsageUpsertOne) SetBonusAmount(v float64) *PromoCodeUsageUpsertOne {
|
||||
return u.Update(func(s *PromoCodeUsageUpsert) {
|
||||
s.SetBonusAmount(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddBonusAmount adds v to the "bonus_amount" field.
|
||||
func (u *PromoCodeUsageUpsertOne) AddBonusAmount(v float64) *PromoCodeUsageUpsertOne {
|
||||
return u.Update(func(s *PromoCodeUsageUpsert) {
|
||||
s.AddBonusAmount(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateBonusAmount sets the "bonus_amount" field to the value that was provided on create.
|
||||
func (u *PromoCodeUsageUpsertOne) UpdateBonusAmount() *PromoCodeUsageUpsertOne {
|
||||
return u.Update(func(s *PromoCodeUsageUpsert) {
|
||||
s.UpdateBonusAmount()
|
||||
})
|
||||
}
|
||||
|
||||
// SetUsedAt sets the "used_at" field.
|
||||
func (u *PromoCodeUsageUpsertOne) SetUsedAt(v time.Time) *PromoCodeUsageUpsertOne {
|
||||
return u.Update(func(s *PromoCodeUsageUpsert) {
|
||||
s.SetUsedAt(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateUsedAt sets the "used_at" field to the value that was provided on create.
|
||||
func (u *PromoCodeUsageUpsertOne) UpdateUsedAt() *PromoCodeUsageUpsertOne {
|
||||
return u.Update(func(s *PromoCodeUsageUpsert) {
|
||||
s.UpdateUsedAt()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *PromoCodeUsageUpsertOne) Exec(ctx context.Context) error {
|
||||
if len(u.create.conflict) == 0 {
|
||||
return errors.New("ent: missing options for PromoCodeUsageCreate.OnConflict")
|
||||
}
|
||||
return u.create.Exec(ctx)
|
||||
}
|
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (u *PromoCodeUsageUpsertOne) ExecX(ctx context.Context) {
|
||||
if err := u.create.Exec(ctx); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Exec executes the UPSERT query and returns the inserted/updated ID.
|
||||
func (u *PromoCodeUsageUpsertOne) ID(ctx context.Context) (id int64, err error) {
|
||||
node, err := u.create.Save(ctx)
|
||||
if err != nil {
|
||||
return id, err
|
||||
}
|
||||
return node.ID, nil
|
||||
}
|
||||
|
||||
// IDX is like ID, but panics if an error occurs.
|
||||
func (u *PromoCodeUsageUpsertOne) IDX(ctx context.Context) int64 {
|
||||
id, err := u.ID(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// PromoCodeUsageCreateBulk is the builder for creating many PromoCodeUsage entities in bulk.
|
||||
type PromoCodeUsageCreateBulk struct {
|
||||
config
|
||||
err error
|
||||
builders []*PromoCodeUsageCreate
|
||||
conflict []sql.ConflictOption
|
||||
}
|
||||
|
||||
// Save creates the PromoCodeUsage entities in the database.
|
||||
func (_c *PromoCodeUsageCreateBulk) Save(ctx context.Context) ([]*PromoCodeUsage, error) {
|
||||
if _c.err != nil {
|
||||
return nil, _c.err
|
||||
}
|
||||
specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
|
||||
nodes := make([]*PromoCodeUsage, len(_c.builders))
|
||||
mutators := make([]Mutator, len(_c.builders))
|
||||
for i := range _c.builders {
|
||||
func(i int, root context.Context) {
|
||||
builder := _c.builders[i]
|
||||
builder.defaults()
|
||||
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
|
||||
mutation, ok := m.(*PromoCodeUsageMutation)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected mutation type %T", m)
|
||||
}
|
||||
if err := builder.check(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
builder.mutation = mutation
|
||||
var err error
|
||||
nodes[i], specs[i] = builder.createSpec()
|
||||
if i < len(mutators)-1 {
|
||||
_, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
|
||||
} else {
|
||||
spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
|
||||
spec.OnConflict = _c.conflict
|
||||
// Invoke the actual operation on the latest mutation in the chain.
|
||||
if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
|
||||
if sqlgraph.IsConstraintError(err) {
|
||||
err = &ConstraintError{msg: err.Error(), wrap: err}
|
||||
}
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mutation.id = &nodes[i].ID
|
||||
if specs[i].ID.Value != nil {
|
||||
id := specs[i].ID.Value.(int64)
|
||||
nodes[i].ID = int64(id)
|
||||
}
|
||||
mutation.done = true
|
||||
return nodes[i], nil
|
||||
})
|
||||
for i := len(builder.hooks) - 1; i >= 0; i-- {
|
||||
mut = builder.hooks[i](mut)
|
||||
}
|
||||
mutators[i] = mut
|
||||
}(i, ctx)
|
||||
}
|
||||
if len(mutators) > 0 {
|
||||
if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
// SaveX is like Save, but panics if an error occurs.
|
||||
func (_c *PromoCodeUsageCreateBulk) SaveX(ctx context.Context) []*PromoCodeUsage {
|
||||
v, err := _c.Save(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (_c *PromoCodeUsageCreateBulk) Exec(ctx context.Context) error {
|
||||
_, err := _c.Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_c *PromoCodeUsageCreateBulk) ExecX(ctx context.Context) {
|
||||
if err := _c.Exec(ctx); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
|
||||
// of the `INSERT` statement. For example:
|
||||
//
|
||||
// client.PromoCodeUsage.CreateBulk(builders...).
|
||||
// OnConflict(
|
||||
// // Update the row with the new values
|
||||
// // the was proposed for insertion.
|
||||
// sql.ResolveWithNewValues(),
|
||||
// ).
|
||||
// // Override some of the fields with custom
|
||||
// // update values.
|
||||
// Update(func(u *ent.PromoCodeUsageUpsert) {
|
||||
// SetPromoCodeID(v+v).
|
||||
// }).
|
||||
// Exec(ctx)
|
||||
func (_c *PromoCodeUsageCreateBulk) OnConflict(opts ...sql.ConflictOption) *PromoCodeUsageUpsertBulk {
|
||||
_c.conflict = opts
|
||||
return &PromoCodeUsageUpsertBulk{
|
||||
create: _c,
|
||||
}
|
||||
}
|
||||
|
||||
// OnConflictColumns calls `OnConflict` and configures the columns
|
||||
// as conflict target. Using this option is equivalent to using:
|
||||
//
|
||||
// client.PromoCodeUsage.Create().
|
||||
// OnConflict(sql.ConflictColumns(columns...)).
|
||||
// Exec(ctx)
|
||||
func (_c *PromoCodeUsageCreateBulk) OnConflictColumns(columns ...string) *PromoCodeUsageUpsertBulk {
|
||||
_c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
|
||||
return &PromoCodeUsageUpsertBulk{
|
||||
create: _c,
|
||||
}
|
||||
}
|
||||
|
||||
// PromoCodeUsageUpsertBulk is the builder for "upsert"-ing
|
||||
// a bulk of PromoCodeUsage nodes.
|
||||
type PromoCodeUsageUpsertBulk struct {
|
||||
create *PromoCodeUsageCreateBulk
|
||||
}
|
||||
|
||||
// UpdateNewValues updates the mutable fields using the new values that
|
||||
// were set on create. Using this option is equivalent to using:
|
||||
//
|
||||
// client.PromoCodeUsage.Create().
|
||||
// OnConflict(
|
||||
// sql.ResolveWithNewValues(),
|
||||
// ).
|
||||
// Exec(ctx)
|
||||
func (u *PromoCodeUsageUpsertBulk) UpdateNewValues() *PromoCodeUsageUpsertBulk {
|
||||
u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
|
||||
return u
|
||||
}
|
||||
|
||||
// Ignore sets each column to itself in case of conflict.
|
||||
// Using this option is equivalent to using:
|
||||
//
|
||||
// client.PromoCodeUsage.Create().
|
||||
// OnConflict(sql.ResolveWithIgnore()).
|
||||
// Exec(ctx)
|
||||
func (u *PromoCodeUsageUpsertBulk) Ignore() *PromoCodeUsageUpsertBulk {
|
||||
u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
|
||||
return u
|
||||
}
|
||||
|
||||
// DoNothing configures the conflict_action to `DO NOTHING`.
|
||||
// Supported only by SQLite and PostgreSQL.
|
||||
func (u *PromoCodeUsageUpsertBulk) DoNothing() *PromoCodeUsageUpsertBulk {
|
||||
u.create.conflict = append(u.create.conflict, sql.DoNothing())
|
||||
return u
|
||||
}
|
||||
|
||||
// Update allows overriding fields `UPDATE` values. See the PromoCodeUsageCreateBulk.OnConflict
|
||||
// documentation for more info.
|
||||
func (u *PromoCodeUsageUpsertBulk) Update(set func(*PromoCodeUsageUpsert)) *PromoCodeUsageUpsertBulk {
|
||||
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
|
||||
set(&PromoCodeUsageUpsert{UpdateSet: update})
|
||||
}))
|
||||
return u
|
||||
}
|
||||
|
||||
// SetPromoCodeID sets the "promo_code_id" field.
|
||||
func (u *PromoCodeUsageUpsertBulk) SetPromoCodeID(v int64) *PromoCodeUsageUpsertBulk {
|
||||
return u.Update(func(s *PromoCodeUsageUpsert) {
|
||||
s.SetPromoCodeID(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdatePromoCodeID sets the "promo_code_id" field to the value that was provided on create.
|
||||
func (u *PromoCodeUsageUpsertBulk) UpdatePromoCodeID() *PromoCodeUsageUpsertBulk {
|
||||
return u.Update(func(s *PromoCodeUsageUpsert) {
|
||||
s.UpdatePromoCodeID()
|
||||
})
|
||||
}
|
||||
|
||||
// SetUserID sets the "user_id" field.
|
||||
func (u *PromoCodeUsageUpsertBulk) SetUserID(v int64) *PromoCodeUsageUpsertBulk {
|
||||
return u.Update(func(s *PromoCodeUsageUpsert) {
|
||||
s.SetUserID(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateUserID sets the "user_id" field to the value that was provided on create.
|
||||
func (u *PromoCodeUsageUpsertBulk) UpdateUserID() *PromoCodeUsageUpsertBulk {
|
||||
return u.Update(func(s *PromoCodeUsageUpsert) {
|
||||
s.UpdateUserID()
|
||||
})
|
||||
}
|
||||
|
||||
// SetBonusAmount sets the "bonus_amount" field.
|
||||
func (u *PromoCodeUsageUpsertBulk) SetBonusAmount(v float64) *PromoCodeUsageUpsertBulk {
|
||||
return u.Update(func(s *PromoCodeUsageUpsert) {
|
||||
s.SetBonusAmount(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddBonusAmount adds v to the "bonus_amount" field.
|
||||
func (u *PromoCodeUsageUpsertBulk) AddBonusAmount(v float64) *PromoCodeUsageUpsertBulk {
|
||||
return u.Update(func(s *PromoCodeUsageUpsert) {
|
||||
s.AddBonusAmount(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateBonusAmount sets the "bonus_amount" field to the value that was provided on create.
|
||||
func (u *PromoCodeUsageUpsertBulk) UpdateBonusAmount() *PromoCodeUsageUpsertBulk {
|
||||
return u.Update(func(s *PromoCodeUsageUpsert) {
|
||||
s.UpdateBonusAmount()
|
||||
})
|
||||
}
|
||||
|
||||
// SetUsedAt sets the "used_at" field.
|
||||
func (u *PromoCodeUsageUpsertBulk) SetUsedAt(v time.Time) *PromoCodeUsageUpsertBulk {
|
||||
return u.Update(func(s *PromoCodeUsageUpsert) {
|
||||
s.SetUsedAt(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateUsedAt sets the "used_at" field to the value that was provided on create.
|
||||
func (u *PromoCodeUsageUpsertBulk) UpdateUsedAt() *PromoCodeUsageUpsertBulk {
|
||||
return u.Update(func(s *PromoCodeUsageUpsert) {
|
||||
s.UpdateUsedAt()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *PromoCodeUsageUpsertBulk) Exec(ctx context.Context) error {
|
||||
if u.create.err != nil {
|
||||
return u.create.err
|
||||
}
|
||||
for i, b := range u.create.builders {
|
||||
if len(b.conflict) != 0 {
|
||||
return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the PromoCodeUsageCreateBulk instead", i)
|
||||
}
|
||||
}
|
||||
if len(u.create.conflict) == 0 {
|
||||
return errors.New("ent: missing options for PromoCodeUsageCreateBulk.OnConflict")
|
||||
}
|
||||
return u.create.Exec(ctx)
|
||||
}
|
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (u *PromoCodeUsageUpsertBulk) ExecX(ctx context.Context) {
|
||||
if err := u.create.Exec(ctx); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
88
backend/ent/promocodeusage_delete.go
Normal file
88
backend/ent/promocodeusage_delete.go
Normal file
@@ -0,0 +1,88 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package ent
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/schema/field"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
|
||||
)
|
||||
|
||||
// PromoCodeUsageDelete is the builder for deleting a PromoCodeUsage entity.
|
||||
type PromoCodeUsageDelete struct {
|
||||
config
|
||||
hooks []Hook
|
||||
mutation *PromoCodeUsageMutation
|
||||
}
|
||||
|
||||
// Where appends a list predicates to the PromoCodeUsageDelete builder.
|
||||
func (_d *PromoCodeUsageDelete) Where(ps ...predicate.PromoCodeUsage) *PromoCodeUsageDelete {
|
||||
_d.mutation.Where(ps...)
|
||||
return _d
|
||||
}
|
||||
|
||||
// Exec executes the deletion query and returns how many vertices were deleted.
|
||||
func (_d *PromoCodeUsageDelete) Exec(ctx context.Context) (int, error) {
|
||||
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
|
||||
}
|
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_d *PromoCodeUsageDelete) ExecX(ctx context.Context) int {
|
||||
n, err := _d.Exec(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func (_d *PromoCodeUsageDelete) sqlExec(ctx context.Context) (int, error) {
|
||||
_spec := sqlgraph.NewDeleteSpec(promocodeusage.Table, sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64))
|
||||
if ps := _d.mutation.predicates; len(ps) > 0 {
|
||||
_spec.Predicate = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
|
||||
if err != nil && sqlgraph.IsConstraintError(err) {
|
||||
err = &ConstraintError{msg: err.Error(), wrap: err}
|
||||
}
|
||||
_d.mutation.done = true
|
||||
return affected, err
|
||||
}
|
||||
|
||||
// PromoCodeUsageDeleteOne is the builder for deleting a single PromoCodeUsage entity.
|
||||
type PromoCodeUsageDeleteOne struct {
|
||||
_d *PromoCodeUsageDelete
|
||||
}
|
||||
|
||||
// Where appends a list predicates to the PromoCodeUsageDelete builder.
|
||||
func (_d *PromoCodeUsageDeleteOne) Where(ps ...predicate.PromoCodeUsage) *PromoCodeUsageDeleteOne {
|
||||
_d._d.mutation.Where(ps...)
|
||||
return _d
|
||||
}
|
||||
|
||||
// Exec executes the deletion query.
|
||||
func (_d *PromoCodeUsageDeleteOne) Exec(ctx context.Context) error {
|
||||
n, err := _d._d.Exec(ctx)
|
||||
switch {
|
||||
case err != nil:
|
||||
return err
|
||||
case n == 0:
|
||||
return &NotFoundError{promocodeusage.Label}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_d *PromoCodeUsageDeleteOne) ExecX(ctx context.Context) {
|
||||
if err := _d.Exec(ctx); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
718
backend/ent/promocodeusage_query.go
Normal file
718
backend/ent/promocodeusage_query.go
Normal file
@@ -0,0 +1,718 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package ent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/schema/field"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
)
|
||||
|
||||
// PromoCodeUsageQuery is the builder for querying PromoCodeUsage entities.
|
||||
type PromoCodeUsageQuery struct {
|
||||
config
|
||||
ctx *QueryContext
|
||||
order []promocodeusage.OrderOption
|
||||
inters []Interceptor
|
||||
predicates []predicate.PromoCodeUsage
|
||||
withPromoCode *PromoCodeQuery
|
||||
withUser *UserQuery
|
||||
modifiers []func(*sql.Selector)
|
||||
// intermediate query (i.e. traversal path).
|
||||
sql *sql.Selector
|
||||
path func(context.Context) (*sql.Selector, error)
|
||||
}
|
||||
|
||||
// Where adds a new predicate for the PromoCodeUsageQuery builder.
|
||||
func (_q *PromoCodeUsageQuery) Where(ps ...predicate.PromoCodeUsage) *PromoCodeUsageQuery {
|
||||
_q.predicates = append(_q.predicates, ps...)
|
||||
return _q
|
||||
}
|
||||
|
||||
// Limit the number of records to be returned by this query.
|
||||
func (_q *PromoCodeUsageQuery) Limit(limit int) *PromoCodeUsageQuery {
|
||||
_q.ctx.Limit = &limit
|
||||
return _q
|
||||
}
|
||||
|
||||
// Offset to start from.
|
||||
func (_q *PromoCodeUsageQuery) Offset(offset int) *PromoCodeUsageQuery {
|
||||
_q.ctx.Offset = &offset
|
||||
return _q
|
||||
}
|
||||
|
||||
// Unique configures the query builder to filter duplicate records on query.
|
||||
// By default, unique is set to true, and can be disabled using this method.
|
||||
func (_q *PromoCodeUsageQuery) Unique(unique bool) *PromoCodeUsageQuery {
|
||||
_q.ctx.Unique = &unique
|
||||
return _q
|
||||
}
|
||||
|
||||
// Order specifies how the records should be ordered.
|
||||
func (_q *PromoCodeUsageQuery) Order(o ...promocodeusage.OrderOption) *PromoCodeUsageQuery {
|
||||
_q.order = append(_q.order, o...)
|
||||
return _q
|
||||
}
|
||||
|
||||
// QueryPromoCode chains the current query on the "promo_code" edge.
|
||||
func (_q *PromoCodeUsageQuery) QueryPromoCode() *PromoCodeQuery {
|
||||
query := (&PromoCodeClient{config: _q.config}).Query()
|
||||
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
|
||||
if err := _q.prepareQuery(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selector := _q.sqlQuery(ctx)
|
||||
if err := selector.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
step := sqlgraph.NewStep(
|
||||
sqlgraph.From(promocodeusage.Table, promocodeusage.FieldID, selector),
|
||||
sqlgraph.To(promocode.Table, promocode.FieldID),
|
||||
sqlgraph.Edge(sqlgraph.M2O, true, promocodeusage.PromoCodeTable, promocodeusage.PromoCodeColumn),
|
||||
)
|
||||
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
|
||||
return fromU, nil
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
// QueryUser chains the current query on the "user" edge.
|
||||
func (_q *PromoCodeUsageQuery) QueryUser() *UserQuery {
|
||||
query := (&UserClient{config: _q.config}).Query()
|
||||
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
|
||||
if err := _q.prepareQuery(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selector := _q.sqlQuery(ctx)
|
||||
if err := selector.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
step := sqlgraph.NewStep(
|
||||
sqlgraph.From(promocodeusage.Table, promocodeusage.FieldID, selector),
|
||||
sqlgraph.To(user.Table, user.FieldID),
|
||||
sqlgraph.Edge(sqlgraph.M2O, true, promocodeusage.UserTable, promocodeusage.UserColumn),
|
||||
)
|
||||
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
|
||||
return fromU, nil
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
// First returns the first PromoCodeUsage entity from the query.
|
||||
// Returns a *NotFoundError when no PromoCodeUsage was found.
|
||||
func (_q *PromoCodeUsageQuery) First(ctx context.Context) (*PromoCodeUsage, error) {
|
||||
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(nodes) == 0 {
|
||||
return nil, &NotFoundError{promocodeusage.Label}
|
||||
}
|
||||
return nodes[0], nil
|
||||
}
|
||||
|
||||
// FirstX is like First, but panics if an error occurs.
|
||||
func (_q *PromoCodeUsageQuery) FirstX(ctx context.Context) *PromoCodeUsage {
|
||||
node, err := _q.First(ctx)
|
||||
if err != nil && !IsNotFound(err) {
|
||||
panic(err)
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
// FirstID returns the first PromoCodeUsage ID from the query.
|
||||
// Returns a *NotFoundError when no PromoCodeUsage ID was found.
|
||||
func (_q *PromoCodeUsageQuery) FirstID(ctx context.Context) (id int64, err error) {
|
||||
var ids []int64
|
||||
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
|
||||
return
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
err = &NotFoundError{promocodeusage.Label}
|
||||
return
|
||||
}
|
||||
return ids[0], nil
|
||||
}
|
||||
|
||||
// FirstIDX is like FirstID, but panics if an error occurs.
|
||||
func (_q *PromoCodeUsageQuery) FirstIDX(ctx context.Context) int64 {
|
||||
id, err := _q.FirstID(ctx)
|
||||
if err != nil && !IsNotFound(err) {
|
||||
panic(err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// Only returns a single PromoCodeUsage entity found by the query, ensuring it only returns one.
|
||||
// Returns a *NotSingularError when more than one PromoCodeUsage entity is found.
|
||||
// Returns a *NotFoundError when no PromoCodeUsage entities are found.
|
||||
func (_q *PromoCodeUsageQuery) Only(ctx context.Context) (*PromoCodeUsage, error) {
|
||||
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch len(nodes) {
|
||||
case 1:
|
||||
return nodes[0], nil
|
||||
case 0:
|
||||
return nil, &NotFoundError{promocodeusage.Label}
|
||||
default:
|
||||
return nil, &NotSingularError{promocodeusage.Label}
|
||||
}
|
||||
}
|
||||
|
||||
// OnlyX is like Only, but panics if an error occurs.
|
||||
func (_q *PromoCodeUsageQuery) OnlyX(ctx context.Context) *PromoCodeUsage {
|
||||
node, err := _q.Only(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
// OnlyID is like Only, but returns the only PromoCodeUsage ID in the query.
|
||||
// Returns a *NotSingularError when more than one PromoCodeUsage ID is found.
|
||||
// Returns a *NotFoundError when no entities are found.
|
||||
func (_q *PromoCodeUsageQuery) OnlyID(ctx context.Context) (id int64, err error) {
|
||||
var ids []int64
|
||||
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
|
||||
return
|
||||
}
|
||||
switch len(ids) {
|
||||
case 1:
|
||||
id = ids[0]
|
||||
case 0:
|
||||
err = &NotFoundError{promocodeusage.Label}
|
||||
default:
|
||||
err = &NotSingularError{promocodeusage.Label}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// OnlyIDX is like OnlyID, but panics if an error occurs.
|
||||
func (_q *PromoCodeUsageQuery) OnlyIDX(ctx context.Context) int64 {
|
||||
id, err := _q.OnlyID(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// All executes the query and returns a list of PromoCodeUsages.
|
||||
func (_q *PromoCodeUsageQuery) All(ctx context.Context) ([]*PromoCodeUsage, error) {
|
||||
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
|
||||
if err := _q.prepareQuery(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
qr := querierAll[[]*PromoCodeUsage, *PromoCodeUsageQuery]()
|
||||
return withInterceptors[[]*PromoCodeUsage](ctx, _q, qr, _q.inters)
|
||||
}
|
||||
|
||||
// AllX is like All, but panics if an error occurs.
|
||||
func (_q *PromoCodeUsageQuery) AllX(ctx context.Context) []*PromoCodeUsage {
|
||||
nodes, err := _q.All(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return nodes
|
||||
}
|
||||
|
||||
// IDs executes the query and returns a list of PromoCodeUsage IDs.
|
||||
func (_q *PromoCodeUsageQuery) IDs(ctx context.Context) (ids []int64, err error) {
|
||||
if _q.ctx.Unique == nil && _q.path != nil {
|
||||
_q.Unique(true)
|
||||
}
|
||||
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
|
||||
if err = _q.Select(promocodeusage.FieldID).Scan(ctx, &ids); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// IDsX is like IDs, but panics if an error occurs.
|
||||
func (_q *PromoCodeUsageQuery) IDsX(ctx context.Context) []int64 {
|
||||
ids, err := _q.IDs(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// Count returns the count of the given query.
|
||||
func (_q *PromoCodeUsageQuery) Count(ctx context.Context) (int, error) {
|
||||
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
|
||||
if err := _q.prepareQuery(ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return withInterceptors[int](ctx, _q, querierCount[*PromoCodeUsageQuery](), _q.inters)
|
||||
}
|
||||
|
||||
// CountX is like Count, but panics if an error occurs.
|
||||
func (_q *PromoCodeUsageQuery) CountX(ctx context.Context) int {
|
||||
count, err := _q.Count(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// Exist returns true if the query has elements in the graph.
|
||||
func (_q *PromoCodeUsageQuery) Exist(ctx context.Context) (bool, error) {
|
||||
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
|
||||
switch _, err := _q.FirstID(ctx); {
|
||||
case IsNotFound(err):
|
||||
return false, nil
|
||||
case err != nil:
|
||||
return false, fmt.Errorf("ent: check existence: %w", err)
|
||||
default:
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
// ExistX is like Exist, but panics if an error occurs.
|
||||
func (_q *PromoCodeUsageQuery) ExistX(ctx context.Context) bool {
|
||||
exist, err := _q.Exist(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return exist
|
||||
}
|
||||
|
||||
// Clone returns a duplicate of the PromoCodeUsageQuery builder, including all associated steps. It can be
|
||||
// used to prepare common query builders and use them differently after the clone is made.
|
||||
func (_q *PromoCodeUsageQuery) Clone() *PromoCodeUsageQuery {
|
||||
if _q == nil {
|
||||
return nil
|
||||
}
|
||||
return &PromoCodeUsageQuery{
|
||||
config: _q.config,
|
||||
ctx: _q.ctx.Clone(),
|
||||
order: append([]promocodeusage.OrderOption{}, _q.order...),
|
||||
inters: append([]Interceptor{}, _q.inters...),
|
||||
predicates: append([]predicate.PromoCodeUsage{}, _q.predicates...),
|
||||
withPromoCode: _q.withPromoCode.Clone(),
|
||||
withUser: _q.withUser.Clone(),
|
||||
// clone intermediate query.
|
||||
sql: _q.sql.Clone(),
|
||||
path: _q.path,
|
||||
}
|
||||
}
|
||||
|
||||
// WithPromoCode tells the query-builder to eager-load the nodes that are connected to
|
||||
// the "promo_code" edge. The optional arguments are used to configure the query builder of the edge.
|
||||
func (_q *PromoCodeUsageQuery) WithPromoCode(opts ...func(*PromoCodeQuery)) *PromoCodeUsageQuery {
|
||||
query := (&PromoCodeClient{config: _q.config}).Query()
|
||||
for _, opt := range opts {
|
||||
opt(query)
|
||||
}
|
||||
_q.withPromoCode = query
|
||||
return _q
|
||||
}
|
||||
|
||||
// WithUser tells the query-builder to eager-load the nodes that are connected to
|
||||
// the "user" edge. The optional arguments are used to configure the query builder of the edge.
|
||||
func (_q *PromoCodeUsageQuery) WithUser(opts ...func(*UserQuery)) *PromoCodeUsageQuery {
|
||||
query := (&UserClient{config: _q.config}).Query()
|
||||
for _, opt := range opts {
|
||||
opt(query)
|
||||
}
|
||||
_q.withUser = query
|
||||
return _q
|
||||
}
|
||||
|
||||
// GroupBy is used to group vertices by one or more fields/columns.
|
||||
// It is often used with aggregate functions, like: count, max, mean, min, sum.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// var v []struct {
|
||||
// PromoCodeID int64 `json:"promo_code_id,omitempty"`
|
||||
// Count int `json:"count,omitempty"`
|
||||
// }
|
||||
//
|
||||
// client.PromoCodeUsage.Query().
|
||||
// GroupBy(promocodeusage.FieldPromoCodeID).
|
||||
// Aggregate(ent.Count()).
|
||||
// Scan(ctx, &v)
|
||||
func (_q *PromoCodeUsageQuery) GroupBy(field string, fields ...string) *PromoCodeUsageGroupBy {
|
||||
_q.ctx.Fields = append([]string{field}, fields...)
|
||||
grbuild := &PromoCodeUsageGroupBy{build: _q}
|
||||
grbuild.flds = &_q.ctx.Fields
|
||||
grbuild.label = promocodeusage.Label
|
||||
grbuild.scan = grbuild.Scan
|
||||
return grbuild
|
||||
}
|
||||
|
||||
// Select allows the selection one or more fields/columns for the given query,
|
||||
// instead of selecting all fields in the entity.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// var v []struct {
|
||||
// PromoCodeID int64 `json:"promo_code_id,omitempty"`
|
||||
// }
|
||||
//
|
||||
// client.PromoCodeUsage.Query().
|
||||
// Select(promocodeusage.FieldPromoCodeID).
|
||||
// Scan(ctx, &v)
|
||||
func (_q *PromoCodeUsageQuery) Select(fields ...string) *PromoCodeUsageSelect {
|
||||
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
|
||||
sbuild := &PromoCodeUsageSelect{PromoCodeUsageQuery: _q}
|
||||
sbuild.label = promocodeusage.Label
|
||||
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
|
||||
return sbuild
|
||||
}
|
||||
|
||||
// Aggregate returns a PromoCodeUsageSelect configured with the given aggregations.
|
||||
func (_q *PromoCodeUsageQuery) Aggregate(fns ...AggregateFunc) *PromoCodeUsageSelect {
|
||||
return _q.Select().Aggregate(fns...)
|
||||
}
|
||||
|
||||
func (_q *PromoCodeUsageQuery) prepareQuery(ctx context.Context) error {
|
||||
for _, inter := range _q.inters {
|
||||
if inter == nil {
|
||||
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
|
||||
}
|
||||
if trv, ok := inter.(Traverser); ok {
|
||||
if err := trv.Traverse(ctx, _q); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, f := range _q.ctx.Fields {
|
||||
if !promocodeusage.ValidColumn(f) {
|
||||
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
|
||||
}
|
||||
}
|
||||
if _q.path != nil {
|
||||
prev, err := _q.path(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_q.sql = prev
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (_q *PromoCodeUsageQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*PromoCodeUsage, error) {
|
||||
var (
|
||||
nodes = []*PromoCodeUsage{}
|
||||
_spec = _q.querySpec()
|
||||
loadedTypes = [2]bool{
|
||||
_q.withPromoCode != nil,
|
||||
_q.withUser != nil,
|
||||
}
|
||||
)
|
||||
_spec.ScanValues = func(columns []string) ([]any, error) {
|
||||
return (*PromoCodeUsage).scanValues(nil, columns)
|
||||
}
|
||||
_spec.Assign = func(columns []string, values []any) error {
|
||||
node := &PromoCodeUsage{config: _q.config}
|
||||
nodes = append(nodes, node)
|
||||
node.Edges.loadedTypes = loadedTypes
|
||||
return node.assignValues(columns, values)
|
||||
}
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
for i := range hooks {
|
||||
hooks[i](ctx, _spec)
|
||||
}
|
||||
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(nodes) == 0 {
|
||||
return nodes, nil
|
||||
}
|
||||
if query := _q.withPromoCode; query != nil {
|
||||
if err := _q.loadPromoCode(ctx, query, nodes, nil,
|
||||
func(n *PromoCodeUsage, e *PromoCode) { n.Edges.PromoCode = e }); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if query := _q.withUser; query != nil {
|
||||
if err := _q.loadUser(ctx, query, nodes, nil,
|
||||
func(n *PromoCodeUsage, e *User) { n.Edges.User = e }); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
func (_q *PromoCodeUsageQuery) loadPromoCode(ctx context.Context, query *PromoCodeQuery, nodes []*PromoCodeUsage, init func(*PromoCodeUsage), assign func(*PromoCodeUsage, *PromoCode)) error {
|
||||
ids := make([]int64, 0, len(nodes))
|
||||
nodeids := make(map[int64][]*PromoCodeUsage)
|
||||
for i := range nodes {
|
||||
fk := nodes[i].PromoCodeID
|
||||
if _, ok := nodeids[fk]; !ok {
|
||||
ids = append(ids, fk)
|
||||
}
|
||||
nodeids[fk] = append(nodeids[fk], nodes[i])
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
query.Where(promocode.IDIn(ids...))
|
||||
neighbors, err := query.All(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, n := range neighbors {
|
||||
nodes, ok := nodeids[n.ID]
|
||||
if !ok {
|
||||
return fmt.Errorf(`unexpected foreign-key "promo_code_id" returned %v`, n.ID)
|
||||
}
|
||||
for i := range nodes {
|
||||
assign(nodes[i], n)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (_q *PromoCodeUsageQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*PromoCodeUsage, init func(*PromoCodeUsage), assign func(*PromoCodeUsage, *User)) error {
|
||||
ids := make([]int64, 0, len(nodes))
|
||||
nodeids := make(map[int64][]*PromoCodeUsage)
|
||||
for i := range nodes {
|
||||
fk := nodes[i].UserID
|
||||
if _, ok := nodeids[fk]; !ok {
|
||||
ids = append(ids, fk)
|
||||
}
|
||||
nodeids[fk] = append(nodeids[fk], nodes[i])
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
query.Where(user.IDIn(ids...))
|
||||
neighbors, err := query.All(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, n := range neighbors {
|
||||
nodes, ok := nodeids[n.ID]
|
||||
if !ok {
|
||||
return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID)
|
||||
}
|
||||
for i := range nodes {
|
||||
assign(nodes[i], n)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (_q *PromoCodeUsageQuery) sqlCount(ctx context.Context) (int, error) {
|
||||
_spec := _q.querySpec()
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
_spec.Node.Columns = _q.ctx.Fields
|
||||
if len(_q.ctx.Fields) > 0 {
|
||||
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
|
||||
}
|
||||
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
|
||||
}
|
||||
|
||||
func (_q *PromoCodeUsageQuery) querySpec() *sqlgraph.QuerySpec {
|
||||
_spec := sqlgraph.NewQuerySpec(promocodeusage.Table, promocodeusage.Columns, sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64))
|
||||
_spec.From = _q.sql
|
||||
if unique := _q.ctx.Unique; unique != nil {
|
||||
_spec.Unique = *unique
|
||||
} else if _q.path != nil {
|
||||
_spec.Unique = true
|
||||
}
|
||||
if fields := _q.ctx.Fields; len(fields) > 0 {
|
||||
_spec.Node.Columns = make([]string, 0, len(fields))
|
||||
_spec.Node.Columns = append(_spec.Node.Columns, promocodeusage.FieldID)
|
||||
for i := range fields {
|
||||
if fields[i] != promocodeusage.FieldID {
|
||||
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
|
||||
}
|
||||
}
|
||||
if _q.withPromoCode != nil {
|
||||
_spec.Node.AddColumnOnce(promocodeusage.FieldPromoCodeID)
|
||||
}
|
||||
if _q.withUser != nil {
|
||||
_spec.Node.AddColumnOnce(promocodeusage.FieldUserID)
|
||||
}
|
||||
}
|
||||
if ps := _q.predicates; len(ps) > 0 {
|
||||
_spec.Predicate = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
if limit := _q.ctx.Limit; limit != nil {
|
||||
_spec.Limit = *limit
|
||||
}
|
||||
if offset := _q.ctx.Offset; offset != nil {
|
||||
_spec.Offset = *offset
|
||||
}
|
||||
if ps := _q.order; len(ps) > 0 {
|
||||
_spec.Order = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
return _spec
|
||||
}
|
||||
|
||||
func (_q *PromoCodeUsageQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
builder := sql.Dialect(_q.driver.Dialect())
|
||||
t1 := builder.Table(promocodeusage.Table)
|
||||
columns := _q.ctx.Fields
|
||||
if len(columns) == 0 {
|
||||
columns = promocodeusage.Columns
|
||||
}
|
||||
selector := builder.Select(t1.Columns(columns...)...).From(t1)
|
||||
if _q.sql != nil {
|
||||
selector = _q.sql
|
||||
selector.Select(selector.Columns(columns...)...)
|
||||
}
|
||||
if _q.ctx.Unique != nil && *_q.ctx.Unique {
|
||||
selector.Distinct()
|
||||
}
|
||||
for _, m := range _q.modifiers {
|
||||
m(selector)
|
||||
}
|
||||
for _, p := range _q.predicates {
|
||||
p(selector)
|
||||
}
|
||||
for _, p := range _q.order {
|
||||
p(selector)
|
||||
}
|
||||
if offset := _q.ctx.Offset; offset != nil {
|
||||
// limit is mandatory for offset clause. We start
|
||||
// with default value, and override it below if needed.
|
||||
selector.Offset(*offset).Limit(math.MaxInt32)
|
||||
}
|
||||
if limit := _q.ctx.Limit; limit != nil {
|
||||
selector.Limit(*limit)
|
||||
}
|
||||
return selector
|
||||
}
|
||||
|
||||
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
|
||||
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
|
||||
// either committed or rolled-back.
|
||||
func (_q *PromoCodeUsageQuery) ForUpdate(opts ...sql.LockOption) *PromoCodeUsageQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForUpdate(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
|
||||
// on any rows that are read. Other sessions can read the rows, but cannot modify them
|
||||
// until your transaction commits.
|
||||
func (_q *PromoCodeUsageQuery) ForShare(opts ...sql.LockOption) *PromoCodeUsageQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForShare(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// PromoCodeUsageGroupBy is the group-by builder for PromoCodeUsage entities.
|
||||
type PromoCodeUsageGroupBy struct {
|
||||
selector
|
||||
build *PromoCodeUsageQuery
|
||||
}
|
||||
|
||||
// Aggregate adds the given aggregation functions to the group-by query.
|
||||
func (_g *PromoCodeUsageGroupBy) Aggregate(fns ...AggregateFunc) *PromoCodeUsageGroupBy {
|
||||
_g.fns = append(_g.fns, fns...)
|
||||
return _g
|
||||
}
|
||||
|
||||
// Scan applies the selector query and scans the result into the given value.
|
||||
func (_g *PromoCodeUsageGroupBy) Scan(ctx context.Context, v any) error {
|
||||
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
|
||||
if err := _g.build.prepareQuery(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return scanWithInterceptors[*PromoCodeUsageQuery, *PromoCodeUsageGroupBy](ctx, _g.build, _g, _g.build.inters, v)
|
||||
}
|
||||
|
||||
func (_g *PromoCodeUsageGroupBy) sqlScan(ctx context.Context, root *PromoCodeUsageQuery, v any) error {
|
||||
selector := root.sqlQuery(ctx).Select()
|
||||
aggregation := make([]string, 0, len(_g.fns))
|
||||
for _, fn := range _g.fns {
|
||||
aggregation = append(aggregation, fn(selector))
|
||||
}
|
||||
if len(selector.SelectedColumns()) == 0 {
|
||||
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
|
||||
for _, f := range *_g.flds {
|
||||
columns = append(columns, selector.C(f))
|
||||
}
|
||||
columns = append(columns, aggregation...)
|
||||
selector.Select(columns...)
|
||||
}
|
||||
selector.GroupBy(selector.Columns(*_g.flds...)...)
|
||||
if err := selector.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
rows := &sql.Rows{}
|
||||
query, args := selector.Query()
|
||||
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
return sql.ScanSlice(rows, v)
|
||||
}
|
||||
|
||||
// PromoCodeUsageSelect is the builder for selecting fields of PromoCodeUsage entities.
|
||||
type PromoCodeUsageSelect struct {
|
||||
*PromoCodeUsageQuery
|
||||
selector
|
||||
}
|
||||
|
||||
// Aggregate adds the given aggregation functions to the selector query.
|
||||
func (_s *PromoCodeUsageSelect) Aggregate(fns ...AggregateFunc) *PromoCodeUsageSelect {
|
||||
_s.fns = append(_s.fns, fns...)
|
||||
return _s
|
||||
}
|
||||
|
||||
// Scan applies the selector query and scans the result into the given value.
|
||||
func (_s *PromoCodeUsageSelect) Scan(ctx context.Context, v any) error {
|
||||
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
|
||||
if err := _s.prepareQuery(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return scanWithInterceptors[*PromoCodeUsageQuery, *PromoCodeUsageSelect](ctx, _s.PromoCodeUsageQuery, _s, _s.inters, v)
|
||||
}
|
||||
|
||||
func (_s *PromoCodeUsageSelect) sqlScan(ctx context.Context, root *PromoCodeUsageQuery, v any) error {
|
||||
selector := root.sqlQuery(ctx)
|
||||
aggregation := make([]string, 0, len(_s.fns))
|
||||
for _, fn := range _s.fns {
|
||||
aggregation = append(aggregation, fn(selector))
|
||||
}
|
||||
switch n := len(*_s.selector.flds); {
|
||||
case n == 0 && len(aggregation) > 0:
|
||||
selector.Select(aggregation...)
|
||||
case n != 0 && len(aggregation) > 0:
|
||||
selector.AppendSelect(aggregation...)
|
||||
}
|
||||
rows := &sql.Rows{}
|
||||
query, args := selector.Query()
|
||||
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
return sql.ScanSlice(rows, v)
|
||||
}
|
||||
510
backend/ent/promocodeusage_update.go
Normal file
510
backend/ent/promocodeusage_update.go
Normal file
@@ -0,0 +1,510 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package ent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/schema/field"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
)
|
||||
|
||||
// PromoCodeUsageUpdate is the builder for updating PromoCodeUsage entities.
|
||||
type PromoCodeUsageUpdate struct {
|
||||
config
|
||||
hooks []Hook
|
||||
mutation *PromoCodeUsageMutation
|
||||
}
|
||||
|
||||
// Where appends a list predicates to the PromoCodeUsageUpdate builder.
|
||||
func (_u *PromoCodeUsageUpdate) Where(ps ...predicate.PromoCodeUsage) *PromoCodeUsageUpdate {
|
||||
_u.mutation.Where(ps...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetPromoCodeID sets the "promo_code_id" field.
|
||||
func (_u *PromoCodeUsageUpdate) SetPromoCodeID(v int64) *PromoCodeUsageUpdate {
|
||||
_u.mutation.SetPromoCodeID(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillablePromoCodeID sets the "promo_code_id" field if the given value is not nil.
|
||||
func (_u *PromoCodeUsageUpdate) SetNillablePromoCodeID(v *int64) *PromoCodeUsageUpdate {
|
||||
if v != nil {
|
||||
_u.SetPromoCodeID(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUserID sets the "user_id" field.
|
||||
func (_u *PromoCodeUsageUpdate) SetUserID(v int64) *PromoCodeUsageUpdate {
|
||||
_u.mutation.SetUserID(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableUserID sets the "user_id" field if the given value is not nil.
|
||||
func (_u *PromoCodeUsageUpdate) SetNillableUserID(v *int64) *PromoCodeUsageUpdate {
|
||||
if v != nil {
|
||||
_u.SetUserID(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetBonusAmount sets the "bonus_amount" field.
|
||||
func (_u *PromoCodeUsageUpdate) SetBonusAmount(v float64) *PromoCodeUsageUpdate {
|
||||
_u.mutation.ResetBonusAmount()
|
||||
_u.mutation.SetBonusAmount(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableBonusAmount sets the "bonus_amount" field if the given value is not nil.
|
||||
func (_u *PromoCodeUsageUpdate) SetNillableBonusAmount(v *float64) *PromoCodeUsageUpdate {
|
||||
if v != nil {
|
||||
_u.SetBonusAmount(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddBonusAmount adds value to the "bonus_amount" field.
|
||||
func (_u *PromoCodeUsageUpdate) AddBonusAmount(v float64) *PromoCodeUsageUpdate {
|
||||
_u.mutation.AddBonusAmount(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUsedAt sets the "used_at" field.
|
||||
func (_u *PromoCodeUsageUpdate) SetUsedAt(v time.Time) *PromoCodeUsageUpdate {
|
||||
_u.mutation.SetUsedAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableUsedAt sets the "used_at" field if the given value is not nil.
|
||||
func (_u *PromoCodeUsageUpdate) SetNillableUsedAt(v *time.Time) *PromoCodeUsageUpdate {
|
||||
if v != nil {
|
||||
_u.SetUsedAt(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetPromoCode sets the "promo_code" edge to the PromoCode entity.
|
||||
func (_u *PromoCodeUsageUpdate) SetPromoCode(v *PromoCode) *PromoCodeUsageUpdate {
|
||||
return _u.SetPromoCodeID(v.ID)
|
||||
}
|
||||
|
||||
// SetUser sets the "user" edge to the User entity.
|
||||
func (_u *PromoCodeUsageUpdate) SetUser(v *User) *PromoCodeUsageUpdate {
|
||||
return _u.SetUserID(v.ID)
|
||||
}
|
||||
|
||||
// Mutation returns the PromoCodeUsageMutation object of the builder.
|
||||
func (_u *PromoCodeUsageUpdate) Mutation() *PromoCodeUsageMutation {
|
||||
return _u.mutation
|
||||
}
|
||||
|
||||
// ClearPromoCode clears the "promo_code" edge to the PromoCode entity.
|
||||
func (_u *PromoCodeUsageUpdate) ClearPromoCode() *PromoCodeUsageUpdate {
|
||||
_u.mutation.ClearPromoCode()
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearUser clears the "user" edge to the User entity.
|
||||
func (_u *PromoCodeUsageUpdate) ClearUser() *PromoCodeUsageUpdate {
|
||||
_u.mutation.ClearUser()
|
||||
return _u
|
||||
}
|
||||
|
||||
// Save executes the query and returns the number of nodes affected by the update operation.
|
||||
func (_u *PromoCodeUsageUpdate) Save(ctx context.Context) (int, error) {
|
||||
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
|
||||
}
|
||||
|
||||
// SaveX is like Save, but panics if an error occurs.
|
||||
func (_u *PromoCodeUsageUpdate) SaveX(ctx context.Context) int {
|
||||
affected, err := _u.Save(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return affected
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (_u *PromoCodeUsageUpdate) Exec(ctx context.Context) error {
|
||||
_, err := _u.Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_u *PromoCodeUsageUpdate) ExecX(ctx context.Context) {
|
||||
if err := _u.Exec(ctx); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// check runs all checks and user-defined validators on the builder.
|
||||
func (_u *PromoCodeUsageUpdate) check() error {
|
||||
if _u.mutation.PromoCodeCleared() && len(_u.mutation.PromoCodeIDs()) > 0 {
|
||||
return errors.New(`ent: clearing a required unique edge "PromoCodeUsage.promo_code"`)
|
||||
}
|
||||
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
|
||||
return errors.New(`ent: clearing a required unique edge "PromoCodeUsage.user"`)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (_u *PromoCodeUsageUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if err := _u.check(); err != nil {
|
||||
return _node, err
|
||||
}
|
||||
_spec := sqlgraph.NewUpdateSpec(promocodeusage.Table, promocodeusage.Columns, sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64))
|
||||
if ps := _u.mutation.predicates; len(ps) > 0 {
|
||||
_spec.Predicate = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
if value, ok := _u.mutation.BonusAmount(); ok {
|
||||
_spec.SetField(promocodeusage.FieldBonusAmount, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedBonusAmount(); ok {
|
||||
_spec.AddField(promocodeusage.FieldBonusAmount, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.UsedAt(); ok {
|
||||
_spec.SetField(promocodeusage.FieldUsedAt, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.PromoCodeCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2O,
|
||||
Inverse: true,
|
||||
Table: promocodeusage.PromoCodeTable,
|
||||
Columns: []string{promocodeusage.PromoCodeColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(promocode.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||
}
|
||||
if nodes := _u.mutation.PromoCodeIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2O,
|
||||
Inverse: true,
|
||||
Table: promocodeusage.PromoCodeTable,
|
||||
Columns: []string{promocodeusage.PromoCodeColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(promocode.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
for _, k := range nodes {
|
||||
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||
}
|
||||
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||
}
|
||||
if _u.mutation.UserCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2O,
|
||||
Inverse: true,
|
||||
Table: promocodeusage.UserTable,
|
||||
Columns: []string{promocodeusage.UserColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||
}
|
||||
if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2O,
|
||||
Inverse: true,
|
||||
Table: promocodeusage.UserTable,
|
||||
Columns: []string{promocodeusage.UserColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
for _, k := range nodes {
|
||||
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||
}
|
||||
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||
}
|
||||
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
|
||||
if _, ok := err.(*sqlgraph.NotFoundError); ok {
|
||||
err = &NotFoundError{promocodeusage.Label}
|
||||
} else if sqlgraph.IsConstraintError(err) {
|
||||
err = &ConstraintError{msg: err.Error(), wrap: err}
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
_u.mutation.done = true
|
||||
return _node, nil
|
||||
}
|
||||
|
||||
// PromoCodeUsageUpdateOne is the builder for updating a single PromoCodeUsage entity.
|
||||
type PromoCodeUsageUpdateOne struct {
|
||||
config
|
||||
fields []string
|
||||
hooks []Hook
|
||||
mutation *PromoCodeUsageMutation
|
||||
}
|
||||
|
||||
// SetPromoCodeID sets the "promo_code_id" field.
|
||||
func (_u *PromoCodeUsageUpdateOne) SetPromoCodeID(v int64) *PromoCodeUsageUpdateOne {
|
||||
_u.mutation.SetPromoCodeID(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillablePromoCodeID sets the "promo_code_id" field if the given value is not nil.
|
||||
func (_u *PromoCodeUsageUpdateOne) SetNillablePromoCodeID(v *int64) *PromoCodeUsageUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetPromoCodeID(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUserID sets the "user_id" field.
|
||||
func (_u *PromoCodeUsageUpdateOne) SetUserID(v int64) *PromoCodeUsageUpdateOne {
|
||||
_u.mutation.SetUserID(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableUserID sets the "user_id" field if the given value is not nil.
|
||||
func (_u *PromoCodeUsageUpdateOne) SetNillableUserID(v *int64) *PromoCodeUsageUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetUserID(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetBonusAmount sets the "bonus_amount" field.
|
||||
func (_u *PromoCodeUsageUpdateOne) SetBonusAmount(v float64) *PromoCodeUsageUpdateOne {
|
||||
_u.mutation.ResetBonusAmount()
|
||||
_u.mutation.SetBonusAmount(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableBonusAmount sets the "bonus_amount" field if the given value is not nil.
|
||||
func (_u *PromoCodeUsageUpdateOne) SetNillableBonusAmount(v *float64) *PromoCodeUsageUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetBonusAmount(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddBonusAmount adds value to the "bonus_amount" field.
|
||||
func (_u *PromoCodeUsageUpdateOne) AddBonusAmount(v float64) *PromoCodeUsageUpdateOne {
|
||||
_u.mutation.AddBonusAmount(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUsedAt sets the "used_at" field.
|
||||
func (_u *PromoCodeUsageUpdateOne) SetUsedAt(v time.Time) *PromoCodeUsageUpdateOne {
|
||||
_u.mutation.SetUsedAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableUsedAt sets the "used_at" field if the given value is not nil.
|
||||
func (_u *PromoCodeUsageUpdateOne) SetNillableUsedAt(v *time.Time) *PromoCodeUsageUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetUsedAt(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetPromoCode sets the "promo_code" edge to the PromoCode entity.
|
||||
func (_u *PromoCodeUsageUpdateOne) SetPromoCode(v *PromoCode) *PromoCodeUsageUpdateOne {
|
||||
return _u.SetPromoCodeID(v.ID)
|
||||
}
|
||||
|
||||
// SetUser sets the "user" edge to the User entity.
|
||||
func (_u *PromoCodeUsageUpdateOne) SetUser(v *User) *PromoCodeUsageUpdateOne {
|
||||
return _u.SetUserID(v.ID)
|
||||
}
|
||||
|
||||
// Mutation returns the PromoCodeUsageMutation object of the builder.
|
||||
func (_u *PromoCodeUsageUpdateOne) Mutation() *PromoCodeUsageMutation {
|
||||
return _u.mutation
|
||||
}
|
||||
|
||||
// ClearPromoCode clears the "promo_code" edge to the PromoCode entity.
|
||||
func (_u *PromoCodeUsageUpdateOne) ClearPromoCode() *PromoCodeUsageUpdateOne {
|
||||
_u.mutation.ClearPromoCode()
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearUser clears the "user" edge to the User entity.
|
||||
func (_u *PromoCodeUsageUpdateOne) ClearUser() *PromoCodeUsageUpdateOne {
|
||||
_u.mutation.ClearUser()
|
||||
return _u
|
||||
}
|
||||
|
||||
// Where appends a list predicates to the PromoCodeUsageUpdate builder.
|
||||
func (_u *PromoCodeUsageUpdateOne) Where(ps ...predicate.PromoCodeUsage) *PromoCodeUsageUpdateOne {
|
||||
_u.mutation.Where(ps...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// Select allows selecting one or more fields (columns) of the returned entity.
|
||||
// The default is selecting all fields defined in the entity schema.
|
||||
func (_u *PromoCodeUsageUpdateOne) Select(field string, fields ...string) *PromoCodeUsageUpdateOne {
|
||||
_u.fields = append([]string{field}, fields...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// Save executes the query and returns the updated PromoCodeUsage entity.
|
||||
func (_u *PromoCodeUsageUpdateOne) Save(ctx context.Context) (*PromoCodeUsage, error) {
|
||||
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
|
||||
}
|
||||
|
||||
// SaveX is like Save, but panics if an error occurs.
|
||||
func (_u *PromoCodeUsageUpdateOne) SaveX(ctx context.Context) *PromoCodeUsage {
|
||||
node, err := _u.Save(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
// Exec executes the query on the entity.
|
||||
func (_u *PromoCodeUsageUpdateOne) Exec(ctx context.Context) error {
|
||||
_, err := _u.Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_u *PromoCodeUsageUpdateOne) ExecX(ctx context.Context) {
|
||||
if err := _u.Exec(ctx); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// check runs all checks and user-defined validators on the builder.
|
||||
func (_u *PromoCodeUsageUpdateOne) check() error {
|
||||
if _u.mutation.PromoCodeCleared() && len(_u.mutation.PromoCodeIDs()) > 0 {
|
||||
return errors.New(`ent: clearing a required unique edge "PromoCodeUsage.promo_code"`)
|
||||
}
|
||||
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
|
||||
return errors.New(`ent: clearing a required unique edge "PromoCodeUsage.user"`)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (_u *PromoCodeUsageUpdateOne) sqlSave(ctx context.Context) (_node *PromoCodeUsage, err error) {
|
||||
if err := _u.check(); err != nil {
|
||||
return _node, err
|
||||
}
|
||||
_spec := sqlgraph.NewUpdateSpec(promocodeusage.Table, promocodeusage.Columns, sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64))
|
||||
id, ok := _u.mutation.ID()
|
||||
if !ok {
|
||||
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "PromoCodeUsage.id" for update`)}
|
||||
}
|
||||
_spec.Node.ID.Value = id
|
||||
if fields := _u.fields; len(fields) > 0 {
|
||||
_spec.Node.Columns = make([]string, 0, len(fields))
|
||||
_spec.Node.Columns = append(_spec.Node.Columns, promocodeusage.FieldID)
|
||||
for _, f := range fields {
|
||||
if !promocodeusage.ValidColumn(f) {
|
||||
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
|
||||
}
|
||||
if f != promocodeusage.FieldID {
|
||||
_spec.Node.Columns = append(_spec.Node.Columns, f)
|
||||
}
|
||||
}
|
||||
}
|
||||
if ps := _u.mutation.predicates; len(ps) > 0 {
|
||||
_spec.Predicate = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
if value, ok := _u.mutation.BonusAmount(); ok {
|
||||
_spec.SetField(promocodeusage.FieldBonusAmount, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedBonusAmount(); ok {
|
||||
_spec.AddField(promocodeusage.FieldBonusAmount, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.UsedAt(); ok {
|
||||
_spec.SetField(promocodeusage.FieldUsedAt, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.PromoCodeCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2O,
|
||||
Inverse: true,
|
||||
Table: promocodeusage.PromoCodeTable,
|
||||
Columns: []string{promocodeusage.PromoCodeColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(promocode.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||
}
|
||||
if nodes := _u.mutation.PromoCodeIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2O,
|
||||
Inverse: true,
|
||||
Table: promocodeusage.PromoCodeTable,
|
||||
Columns: []string{promocodeusage.PromoCodeColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(promocode.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
for _, k := range nodes {
|
||||
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||
}
|
||||
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||
}
|
||||
if _u.mutation.UserCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2O,
|
||||
Inverse: true,
|
||||
Table: promocodeusage.UserTable,
|
||||
Columns: []string{promocodeusage.UserColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||
}
|
||||
if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2O,
|
||||
Inverse: true,
|
||||
Table: promocodeusage.UserTable,
|
||||
Columns: []string{promocodeusage.UserColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
for _, k := range nodes {
|
||||
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||
}
|
||||
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||
}
|
||||
_node = &PromoCodeUsage{config: _u.config}
|
||||
_spec.Assign = _node.assignValues
|
||||
_spec.ScanValues = _node.scanValues
|
||||
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
|
||||
if _, ok := err.(*sqlgraph.NotFoundError); ok {
|
||||
err = &NotFoundError{promocodeusage.Label}
|
||||
} else if sqlgraph.IsConstraintError(err) {
|
||||
err = &ConstraintError{msg: err.Error(), wrap: err}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
_u.mutation.done = true
|
||||
return _node, nil
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"math"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/schema/field"
|
||||
@@ -25,6 +26,7 @@ type ProxyQuery struct {
|
||||
inters []Interceptor
|
||||
predicates []predicate.Proxy
|
||||
withAccounts *AccountQuery
|
||||
modifiers []func(*sql.Selector)
|
||||
// intermediate query (i.e. traversal path).
|
||||
sql *sql.Selector
|
||||
path func(context.Context) (*sql.Selector, error)
|
||||
@@ -384,6 +386,9 @@ func (_q *ProxyQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Proxy,
|
||||
node.Edges.loadedTypes = loadedTypes
|
||||
return node.assignValues(columns, values)
|
||||
}
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
for i := range hooks {
|
||||
hooks[i](ctx, _spec)
|
||||
}
|
||||
@@ -439,6 +444,9 @@ func (_q *ProxyQuery) loadAccounts(ctx context.Context, query *AccountQuery, nod
|
||||
|
||||
func (_q *ProxyQuery) sqlCount(ctx context.Context) (int, error) {
|
||||
_spec := _q.querySpec()
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
_spec.Node.Columns = _q.ctx.Fields
|
||||
if len(_q.ctx.Fields) > 0 {
|
||||
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
|
||||
@@ -501,6 +509,9 @@ func (_q *ProxyQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
if _q.ctx.Unique != nil && *_q.ctx.Unique {
|
||||
selector.Distinct()
|
||||
}
|
||||
for _, m := range _q.modifiers {
|
||||
m(selector)
|
||||
}
|
||||
for _, p := range _q.predicates {
|
||||
p(selector)
|
||||
}
|
||||
@@ -518,6 +529,32 @@ func (_q *ProxyQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
return selector
|
||||
}
|
||||
|
||||
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
|
||||
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
|
||||
// either committed or rolled-back.
|
||||
func (_q *ProxyQuery) ForUpdate(opts ...sql.LockOption) *ProxyQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForUpdate(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
|
||||
// on any rows that are read. Other sessions can read the rows, but cannot modify them
|
||||
// until your transaction commits.
|
||||
func (_q *ProxyQuery) ForShare(opts ...sql.LockOption) *ProxyQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForShare(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// ProxyGroupBy is the group-by builder for Proxy entities.
|
||||
type ProxyGroupBy struct {
|
||||
selector
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"math"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/schema/field"
|
||||
@@ -26,6 +27,7 @@ type RedeemCodeQuery struct {
|
||||
predicates []predicate.RedeemCode
|
||||
withUser *UserQuery
|
||||
withGroup *GroupQuery
|
||||
modifiers []func(*sql.Selector)
|
||||
// intermediate query (i.e. traversal path).
|
||||
sql *sql.Selector
|
||||
path func(context.Context) (*sql.Selector, error)
|
||||
@@ -420,6 +422,9 @@ func (_q *RedeemCodeQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*R
|
||||
node.Edges.loadedTypes = loadedTypes
|
||||
return node.assignValues(columns, values)
|
||||
}
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
for i := range hooks {
|
||||
hooks[i](ctx, _spec)
|
||||
}
|
||||
@@ -511,6 +516,9 @@ func (_q *RedeemCodeQuery) loadGroup(ctx context.Context, query *GroupQuery, nod
|
||||
|
||||
func (_q *RedeemCodeQuery) sqlCount(ctx context.Context) (int, error) {
|
||||
_spec := _q.querySpec()
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
_spec.Node.Columns = _q.ctx.Fields
|
||||
if len(_q.ctx.Fields) > 0 {
|
||||
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
|
||||
@@ -579,6 +587,9 @@ func (_q *RedeemCodeQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
if _q.ctx.Unique != nil && *_q.ctx.Unique {
|
||||
selector.Distinct()
|
||||
}
|
||||
for _, m := range _q.modifiers {
|
||||
m(selector)
|
||||
}
|
||||
for _, p := range _q.predicates {
|
||||
p(selector)
|
||||
}
|
||||
@@ -596,6 +607,32 @@ func (_q *RedeemCodeQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
return selector
|
||||
}
|
||||
|
||||
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
|
||||
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
|
||||
// either committed or rolled-back.
|
||||
func (_q *RedeemCodeQuery) ForUpdate(opts ...sql.LockOption) *RedeemCodeQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForUpdate(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
|
||||
// on any rows that are read. Other sessions can read the rows, but cannot modify them
|
||||
// until your transaction commits.
|
||||
func (_q *RedeemCodeQuery) ForShare(opts ...sql.LockOption) *RedeemCodeQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForShare(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// RedeemCodeGroupBy is the group-by builder for RedeemCode entities.
|
||||
type RedeemCodeGroupBy struct {
|
||||
selector
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/accountgroup"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
|
||||
"github.com/Wei-Shaw/sub2api/ent/proxy"
|
||||
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/schema"
|
||||
@@ -274,6 +276,60 @@ func init() {
|
||||
groupDescClaudeCodeOnly := groupFields[14].Descriptor()
|
||||
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
|
||||
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
|
||||
promocodeFields := schema.PromoCode{}.Fields()
|
||||
_ = promocodeFields
|
||||
// promocodeDescCode is the schema descriptor for code field.
|
||||
promocodeDescCode := promocodeFields[0].Descriptor()
|
||||
// promocode.CodeValidator is a validator for the "code" field. It is called by the builders before save.
|
||||
promocode.CodeValidator = func() func(string) error {
|
||||
validators := promocodeDescCode.Validators
|
||||
fns := [...]func(string) error{
|
||||
validators[0].(func(string) error),
|
||||
validators[1].(func(string) error),
|
||||
}
|
||||
return func(code string) error {
|
||||
for _, fn := range fns {
|
||||
if err := fn(code); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}()
|
||||
// promocodeDescBonusAmount is the schema descriptor for bonus_amount field.
|
||||
promocodeDescBonusAmount := promocodeFields[1].Descriptor()
|
||||
// promocode.DefaultBonusAmount holds the default value on creation for the bonus_amount field.
|
||||
promocode.DefaultBonusAmount = promocodeDescBonusAmount.Default.(float64)
|
||||
// promocodeDescMaxUses is the schema descriptor for max_uses field.
|
||||
promocodeDescMaxUses := promocodeFields[2].Descriptor()
|
||||
// promocode.DefaultMaxUses holds the default value on creation for the max_uses field.
|
||||
promocode.DefaultMaxUses = promocodeDescMaxUses.Default.(int)
|
||||
// promocodeDescUsedCount is the schema descriptor for used_count field.
|
||||
promocodeDescUsedCount := promocodeFields[3].Descriptor()
|
||||
// promocode.DefaultUsedCount holds the default value on creation for the used_count field.
|
||||
promocode.DefaultUsedCount = promocodeDescUsedCount.Default.(int)
|
||||
// promocodeDescStatus is the schema descriptor for status field.
|
||||
promocodeDescStatus := promocodeFields[4].Descriptor()
|
||||
// promocode.DefaultStatus holds the default value on creation for the status field.
|
||||
promocode.DefaultStatus = promocodeDescStatus.Default.(string)
|
||||
// promocode.StatusValidator is a validator for the "status" field. It is called by the builders before save.
|
||||
promocode.StatusValidator = promocodeDescStatus.Validators[0].(func(string) error)
|
||||
// promocodeDescCreatedAt is the schema descriptor for created_at field.
|
||||
promocodeDescCreatedAt := promocodeFields[7].Descriptor()
|
||||
// promocode.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||
promocode.DefaultCreatedAt = promocodeDescCreatedAt.Default.(func() time.Time)
|
||||
// promocodeDescUpdatedAt is the schema descriptor for updated_at field.
|
||||
promocodeDescUpdatedAt := promocodeFields[8].Descriptor()
|
||||
// promocode.DefaultUpdatedAt holds the default value on creation for the updated_at field.
|
||||
promocode.DefaultUpdatedAt = promocodeDescUpdatedAt.Default.(func() time.Time)
|
||||
// promocode.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
|
||||
promocode.UpdateDefaultUpdatedAt = promocodeDescUpdatedAt.UpdateDefault.(func() time.Time)
|
||||
promocodeusageFields := schema.PromoCodeUsage{}.Fields()
|
||||
_ = promocodeusageFields
|
||||
// promocodeusageDescUsedAt is the schema descriptor for used_at field.
|
||||
promocodeusageDescUsedAt := promocodeusageFields[3].Descriptor()
|
||||
// promocodeusage.DefaultUsedAt holds the default value on creation for the used_at field.
|
||||
promocodeusage.DefaultUsedAt = promocodeusageDescUsedAt.Default.(func() time.Time)
|
||||
proxyMixin := schema.Proxy{}.Mixin()
|
||||
proxyMixinHooks1 := proxyMixin[1].Hooks()
|
||||
proxy.Hooks[0] = proxyMixinHooks1[0]
|
||||
@@ -533,16 +589,20 @@ func init() {
|
||||
usagelogDescUserAgent := usagelogFields[24].Descriptor()
|
||||
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
|
||||
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
|
||||
// usagelogDescIPAddress is the schema descriptor for ip_address field.
|
||||
usagelogDescIPAddress := usagelogFields[25].Descriptor()
|
||||
// usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
|
||||
usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error)
|
||||
// usagelogDescImageCount is the schema descriptor for image_count field.
|
||||
usagelogDescImageCount := usagelogFields[25].Descriptor()
|
||||
usagelogDescImageCount := usagelogFields[26].Descriptor()
|
||||
// usagelog.DefaultImageCount holds the default value on creation for the image_count field.
|
||||
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
|
||||
// usagelogDescImageSize is the schema descriptor for image_size field.
|
||||
usagelogDescImageSize := usagelogFields[26].Descriptor()
|
||||
usagelogDescImageSize := usagelogFields[27].Descriptor()
|
||||
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
|
||||
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
|
||||
// usagelogDescCreatedAt is the schema descriptor for created_at field.
|
||||
usagelogDescCreatedAt := usagelogFields[27].Descriptor()
|
||||
usagelogDescCreatedAt := usagelogFields[28].Descriptor()
|
||||
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
|
||||
userMixin := schema.User{}.Mixin()
|
||||
|
||||
@@ -46,6 +46,12 @@ func (APIKey) Fields() []ent.Field {
|
||||
field.String("status").
|
||||
MaxLen(20).
|
||||
Default(service.StatusActive),
|
||||
field.JSON("ip_whitelist", []string{}).
|
||||
Optional().
|
||||
Comment("Allowed IPs/CIDRs, e.g. [\"192.168.1.100\", \"10.0.0.0/8\"]"),
|
||||
field.JSON("ip_blacklist", []string{}).
|
||||
Optional().
|
||||
Comment("Blocked IPs/CIDRs"),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
87
backend/ent/schema/promo_code.go
Normal file
87
backend/ent/schema/promo_code.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/entsql"
|
||||
"entgo.io/ent/schema"
|
||||
"entgo.io/ent/schema/edge"
|
||||
"entgo.io/ent/schema/field"
|
||||
"entgo.io/ent/schema/index"
|
||||
)
|
||||
|
||||
// PromoCode holds the schema definition for the PromoCode entity.
|
||||
//
|
||||
// 注册优惠码:用户注册时使用,可获得赠送余额
|
||||
// 与 RedeemCode 不同,PromoCode 支持多次使用(有使用次数限制)
|
||||
//
|
||||
// 删除策略:硬删除
|
||||
type PromoCode struct {
|
||||
ent.Schema
|
||||
}
|
||||
|
||||
func (PromoCode) Annotations() []schema.Annotation {
|
||||
return []schema.Annotation{
|
||||
entsql.Annotation{Table: "promo_codes"},
|
||||
}
|
||||
}
|
||||
|
||||
func (PromoCode) Fields() []ent.Field {
|
||||
return []ent.Field{
|
||||
field.String("code").
|
||||
MaxLen(32).
|
||||
NotEmpty().
|
||||
Unique().
|
||||
Comment("优惠码"),
|
||||
field.Float("bonus_amount").
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
|
||||
Default(0).
|
||||
Comment("赠送余额金额"),
|
||||
field.Int("max_uses").
|
||||
Default(0).
|
||||
Comment("最大使用次数,0表示无限制"),
|
||||
field.Int("used_count").
|
||||
Default(0).
|
||||
Comment("已使用次数"),
|
||||
field.String("status").
|
||||
MaxLen(20).
|
||||
Default(service.PromoCodeStatusActive).
|
||||
Comment("状态: active, disabled"),
|
||||
field.Time("expires_at").
|
||||
Optional().
|
||||
Nillable().
|
||||
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}).
|
||||
Comment("过期时间,null表示永不过期"),
|
||||
field.String("notes").
|
||||
Optional().
|
||||
Nillable().
|
||||
SchemaType(map[string]string{dialect.Postgres: "text"}).
|
||||
Comment("备注"),
|
||||
field.Time("created_at").
|
||||
Immutable().
|
||||
Default(time.Now).
|
||||
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
|
||||
field.Time("updated_at").
|
||||
Default(time.Now).
|
||||
UpdateDefault(time.Now).
|
||||
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
|
||||
}
|
||||
}
|
||||
|
||||
func (PromoCode) Edges() []ent.Edge {
|
||||
return []ent.Edge{
|
||||
edge.To("usage_records", PromoCodeUsage.Type),
|
||||
}
|
||||
}
|
||||
|
||||
func (PromoCode) Indexes() []ent.Index {
|
||||
return []ent.Index{
|
||||
// code 字段已在 Fields() 中声明 Unique(),无需重复索引
|
||||
index.Fields("status"),
|
||||
index.Fields("expires_at"),
|
||||
}
|
||||
}
|
||||
66
backend/ent/schema/promo_code_usage.go
Normal file
66
backend/ent/schema/promo_code_usage.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/entsql"
|
||||
"entgo.io/ent/schema"
|
||||
"entgo.io/ent/schema/edge"
|
||||
"entgo.io/ent/schema/field"
|
||||
"entgo.io/ent/schema/index"
|
||||
)
|
||||
|
||||
// PromoCodeUsage holds the schema definition for the PromoCodeUsage entity.
|
||||
//
|
||||
// 优惠码使用记录:记录每个用户使用优惠码的情况
|
||||
type PromoCodeUsage struct {
|
||||
ent.Schema
|
||||
}
|
||||
|
||||
func (PromoCodeUsage) Annotations() []schema.Annotation {
|
||||
return []schema.Annotation{
|
||||
entsql.Annotation{Table: "promo_code_usages"},
|
||||
}
|
||||
}
|
||||
|
||||
func (PromoCodeUsage) Fields() []ent.Field {
|
||||
return []ent.Field{
|
||||
field.Int64("promo_code_id").
|
||||
Comment("优惠码ID"),
|
||||
field.Int64("user_id").
|
||||
Comment("使用用户ID"),
|
||||
field.Float("bonus_amount").
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
|
||||
Comment("实际赠送金额"),
|
||||
field.Time("used_at").
|
||||
Default(time.Now).
|
||||
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}).
|
||||
Comment("使用时间"),
|
||||
}
|
||||
}
|
||||
|
||||
func (PromoCodeUsage) Edges() []ent.Edge {
|
||||
return []ent.Edge{
|
||||
edge.From("promo_code", PromoCode.Type).
|
||||
Ref("usage_records").
|
||||
Field("promo_code_id").
|
||||
Required().
|
||||
Unique(),
|
||||
edge.From("user", User.Type).
|
||||
Ref("promo_code_usages").
|
||||
Field("user_id").
|
||||
Required().
|
||||
Unique(),
|
||||
}
|
||||
}
|
||||
|
||||
func (PromoCodeUsage) Indexes() []ent.Index {
|
||||
return []ent.Index{
|
||||
index.Fields("promo_code_id"),
|
||||
index.Fields("user_id"),
|
||||
// 每个用户每个优惠码只能使用一次
|
||||
index.Fields("promo_code_id", "user_id").Unique(),
|
||||
}
|
||||
}
|
||||
@@ -100,6 +100,10 @@ func (UsageLog) Fields() []ent.Field {
|
||||
MaxLen(512).
|
||||
Optional().
|
||||
Nillable(),
|
||||
field.String("ip_address").
|
||||
MaxLen(45). // 支持 IPv6
|
||||
Optional().
|
||||
Nillable(),
|
||||
|
||||
// 图片生成字段(仅 gemini-3-pro-image 等图片模型使用)
|
||||
field.Int("image_count").
|
||||
|
||||
@@ -74,6 +74,7 @@ func (User) Edges() []ent.Edge {
|
||||
Through("user_allowed_groups", UserAllowedGroup.Type),
|
||||
edge.To("usage_logs", UsageLog.Type),
|
||||
edge.To("attribute_values", UserAttributeValue.Type),
|
||||
edge.To("promo_code_usages", PromoCodeUsage.Type),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"math"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/schema/field"
|
||||
@@ -22,6 +23,7 @@ type SettingQuery struct {
|
||||
order []setting.OrderOption
|
||||
inters []Interceptor
|
||||
predicates []predicate.Setting
|
||||
modifiers []func(*sql.Selector)
|
||||
// intermediate query (i.e. traversal path).
|
||||
sql *sql.Selector
|
||||
path func(context.Context) (*sql.Selector, error)
|
||||
@@ -343,6 +345,9 @@ func (_q *SettingQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Sett
|
||||
nodes = append(nodes, node)
|
||||
return node.assignValues(columns, values)
|
||||
}
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
for i := range hooks {
|
||||
hooks[i](ctx, _spec)
|
||||
}
|
||||
@@ -357,6 +362,9 @@ func (_q *SettingQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Sett
|
||||
|
||||
func (_q *SettingQuery) sqlCount(ctx context.Context) (int, error) {
|
||||
_spec := _q.querySpec()
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
_spec.Node.Columns = _q.ctx.Fields
|
||||
if len(_q.ctx.Fields) > 0 {
|
||||
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
|
||||
@@ -419,6 +427,9 @@ func (_q *SettingQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
if _q.ctx.Unique != nil && *_q.ctx.Unique {
|
||||
selector.Distinct()
|
||||
}
|
||||
for _, m := range _q.modifiers {
|
||||
m(selector)
|
||||
}
|
||||
for _, p := range _q.predicates {
|
||||
p(selector)
|
||||
}
|
||||
@@ -436,6 +447,32 @@ func (_q *SettingQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
return selector
|
||||
}
|
||||
|
||||
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
|
||||
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
|
||||
// either committed or rolled-back.
|
||||
func (_q *SettingQuery) ForUpdate(opts ...sql.LockOption) *SettingQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForUpdate(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
|
||||
// on any rows that are read. Other sessions can read the rows, but cannot modify them
|
||||
// until your transaction commits.
|
||||
func (_q *SettingQuery) ForShare(opts ...sql.LockOption) *SettingQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForShare(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// SettingGroupBy is the group-by builder for Setting entities.
|
||||
type SettingGroupBy struct {
|
||||
selector
|
||||
|
||||
@@ -22,6 +22,10 @@ type Tx struct {
|
||||
AccountGroup *AccountGroupClient
|
||||
// Group is the client for interacting with the Group builders.
|
||||
Group *GroupClient
|
||||
// PromoCode is the client for interacting with the PromoCode builders.
|
||||
PromoCode *PromoCodeClient
|
||||
// PromoCodeUsage is the client for interacting with the PromoCodeUsage builders.
|
||||
PromoCodeUsage *PromoCodeUsageClient
|
||||
// Proxy is the client for interacting with the Proxy builders.
|
||||
Proxy *ProxyClient
|
||||
// RedeemCode is the client for interacting with the RedeemCode builders.
|
||||
@@ -175,6 +179,8 @@ func (tx *Tx) init() {
|
||||
tx.Account = NewAccountClient(tx.config)
|
||||
tx.AccountGroup = NewAccountGroupClient(tx.config)
|
||||
tx.Group = NewGroupClient(tx.config)
|
||||
tx.PromoCode = NewPromoCodeClient(tx.config)
|
||||
tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config)
|
||||
tx.Proxy = NewProxyClient(tx.config)
|
||||
tx.RedeemCode = NewRedeemCodeClient(tx.config)
|
||||
tx.Setting = NewSettingClient(tx.config)
|
||||
|
||||
@@ -72,6 +72,8 @@ type UsageLog struct {
|
||||
FirstTokenMs *int `json:"first_token_ms,omitempty"`
|
||||
// UserAgent holds the value of the "user_agent" field.
|
||||
UserAgent *string `json:"user_agent,omitempty"`
|
||||
// IPAddress holds the value of the "ip_address" field.
|
||||
IPAddress *string `json:"ip_address,omitempty"`
|
||||
// ImageCount holds the value of the "image_count" field.
|
||||
ImageCount int `json:"image_count,omitempty"`
|
||||
// ImageSize holds the value of the "image_size" field.
|
||||
@@ -167,7 +169,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldImageSize:
|
||||
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize:
|
||||
values[i] = new(sql.NullString)
|
||||
case usagelog.FieldCreatedAt:
|
||||
values[i] = new(sql.NullTime)
|
||||
@@ -347,6 +349,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
|
||||
_m.UserAgent = new(string)
|
||||
*_m.UserAgent = value.String
|
||||
}
|
||||
case usagelog.FieldIPAddress:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field ip_address", values[i])
|
||||
} else if value.Valid {
|
||||
_m.IPAddress = new(string)
|
||||
*_m.IPAddress = value.String
|
||||
}
|
||||
case usagelog.FieldImageCount:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field image_count", values[i])
|
||||
@@ -512,6 +521,11 @@ func (_m *UsageLog) String() string {
|
||||
builder.WriteString(*v)
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.IPAddress; v != nil {
|
||||
builder.WriteString("ip_address=")
|
||||
builder.WriteString(*v)
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("image_count=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.ImageCount))
|
||||
builder.WriteString(", ")
|
||||
|
||||
@@ -64,6 +64,8 @@ const (
|
||||
FieldFirstTokenMs = "first_token_ms"
|
||||
// FieldUserAgent holds the string denoting the user_agent field in the database.
|
||||
FieldUserAgent = "user_agent"
|
||||
// FieldIPAddress holds the string denoting the ip_address field in the database.
|
||||
FieldIPAddress = "ip_address"
|
||||
// FieldImageCount holds the string denoting the image_count field in the database.
|
||||
FieldImageCount = "image_count"
|
||||
// FieldImageSize holds the string denoting the image_size field in the database.
|
||||
@@ -147,6 +149,7 @@ var Columns = []string{
|
||||
FieldDurationMs,
|
||||
FieldFirstTokenMs,
|
||||
FieldUserAgent,
|
||||
FieldIPAddress,
|
||||
FieldImageCount,
|
||||
FieldImageSize,
|
||||
FieldCreatedAt,
|
||||
@@ -199,6 +202,8 @@ var (
|
||||
DefaultStream bool
|
||||
// UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
|
||||
UserAgentValidator func(string) error
|
||||
// IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
|
||||
IPAddressValidator func(string) error
|
||||
// DefaultImageCount holds the default value on creation for the "image_count" field.
|
||||
DefaultImageCount int
|
||||
// ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
|
||||
@@ -340,6 +345,11 @@ func ByUserAgent(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldUserAgent, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByIPAddress orders the results by the ip_address field.
|
||||
func ByIPAddress(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldIPAddress, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByImageCount orders the results by the image_count field.
|
||||
func ByImageCount(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldImageCount, opts...).ToFunc()
|
||||
|
||||
@@ -180,6 +180,11 @@ func UserAgent(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldUserAgent, v))
|
||||
}
|
||||
|
||||
// IPAddress applies equality check predicate on the "ip_address" field. It's identical to IPAddressEQ.
|
||||
func IPAddress(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldIPAddress, v))
|
||||
}
|
||||
|
||||
// ImageCount applies equality check predicate on the "image_count" field. It's identical to ImageCountEQ.
|
||||
func ImageCount(v int) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldImageCount, v))
|
||||
@@ -1190,6 +1195,81 @@ func UserAgentContainsFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContainsFold(FieldUserAgent, v))
|
||||
}
|
||||
|
||||
// IPAddressEQ applies the EQ predicate on the "ip_address" field.
|
||||
func IPAddressEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldIPAddress, v))
|
||||
}
|
||||
|
||||
// IPAddressNEQ applies the NEQ predicate on the "ip_address" field.
|
||||
func IPAddressNEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNEQ(FieldIPAddress, v))
|
||||
}
|
||||
|
||||
// IPAddressIn applies the In predicate on the "ip_address" field.
|
||||
func IPAddressIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIn(FieldIPAddress, vs...))
|
||||
}
|
||||
|
||||
// IPAddressNotIn applies the NotIn predicate on the "ip_address" field.
|
||||
func IPAddressNotIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotIn(FieldIPAddress, vs...))
|
||||
}
|
||||
|
||||
// IPAddressGT applies the GT predicate on the "ip_address" field.
|
||||
func IPAddressGT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGT(FieldIPAddress, v))
|
||||
}
|
||||
|
||||
// IPAddressGTE applies the GTE predicate on the "ip_address" field.
|
||||
func IPAddressGTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGTE(FieldIPAddress, v))
|
||||
}
|
||||
|
||||
// IPAddressLT applies the LT predicate on the "ip_address" field.
|
||||
func IPAddressLT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLT(FieldIPAddress, v))
|
||||
}
|
||||
|
||||
// IPAddressLTE applies the LTE predicate on the "ip_address" field.
|
||||
func IPAddressLTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLTE(FieldIPAddress, v))
|
||||
}
|
||||
|
||||
// IPAddressContains applies the Contains predicate on the "ip_address" field.
|
||||
func IPAddressContains(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContains(FieldIPAddress, v))
|
||||
}
|
||||
|
||||
// IPAddressHasPrefix applies the HasPrefix predicate on the "ip_address" field.
|
||||
func IPAddressHasPrefix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasPrefix(FieldIPAddress, v))
|
||||
}
|
||||
|
||||
// IPAddressHasSuffix applies the HasSuffix predicate on the "ip_address" field.
|
||||
func IPAddressHasSuffix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasSuffix(FieldIPAddress, v))
|
||||
}
|
||||
|
||||
// IPAddressIsNil applies the IsNil predicate on the "ip_address" field.
|
||||
func IPAddressIsNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIsNull(FieldIPAddress))
|
||||
}
|
||||
|
||||
// IPAddressNotNil applies the NotNil predicate on the "ip_address" field.
|
||||
func IPAddressNotNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotNull(FieldIPAddress))
|
||||
}
|
||||
|
||||
// IPAddressEqualFold applies the EqualFold predicate on the "ip_address" field.
|
||||
func IPAddressEqualFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEqualFold(FieldIPAddress, v))
|
||||
}
|
||||
|
||||
// IPAddressContainsFold applies the ContainsFold predicate on the "ip_address" field.
|
||||
func IPAddressContainsFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContainsFold(FieldIPAddress, v))
|
||||
}
|
||||
|
||||
// ImageCountEQ applies the EQ predicate on the "image_count" field.
|
||||
func ImageCountEQ(v int) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldImageCount, v))
|
||||
|
||||
@@ -337,6 +337,20 @@ func (_c *UsageLogCreate) SetNillableUserAgent(v *string) *UsageLogCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetIPAddress sets the "ip_address" field.
|
||||
func (_c *UsageLogCreate) SetIPAddress(v string) *UsageLogCreate {
|
||||
_c.mutation.SetIPAddress(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableIPAddress sets the "ip_address" field if the given value is not nil.
|
||||
func (_c *UsageLogCreate) SetNillableIPAddress(v *string) *UsageLogCreate {
|
||||
if v != nil {
|
||||
_c.SetIPAddress(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetImageCount sets the "image_count" field.
|
||||
func (_c *UsageLogCreate) SetImageCount(v int) *UsageLogCreate {
|
||||
_c.mutation.SetImageCount(v)
|
||||
@@ -586,6 +600,11 @@ func (_c *UsageLogCreate) check() error {
|
||||
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _c.mutation.IPAddress(); ok {
|
||||
if err := usagelog.IPAddressValidator(v); err != nil {
|
||||
return &ValidationError{Name: "ip_address", err: fmt.Errorf(`ent: validator failed for field "UsageLog.ip_address": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _, ok := _c.mutation.ImageCount(); !ok {
|
||||
return &ValidationError{Name: "image_count", err: errors.New(`ent: missing required field "UsageLog.image_count"`)}
|
||||
}
|
||||
@@ -713,6 +732,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(usagelog.FieldUserAgent, field.TypeString, value)
|
||||
_node.UserAgent = &value
|
||||
}
|
||||
if value, ok := _c.mutation.IPAddress(); ok {
|
||||
_spec.SetField(usagelog.FieldIPAddress, field.TypeString, value)
|
||||
_node.IPAddress = &value
|
||||
}
|
||||
if value, ok := _c.mutation.ImageCount(); ok {
|
||||
_spec.SetField(usagelog.FieldImageCount, field.TypeInt, value)
|
||||
_node.ImageCount = value
|
||||
@@ -1288,6 +1311,24 @@ func (u *UsageLogUpsert) ClearUserAgent() *UsageLogUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetIPAddress sets the "ip_address" field.
|
||||
func (u *UsageLogUpsert) SetIPAddress(v string) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldIPAddress, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateIPAddress sets the "ip_address" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsert) UpdateIPAddress() *UsageLogUpsert {
|
||||
u.SetExcluded(usagelog.FieldIPAddress)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearIPAddress clears the value of the "ip_address" field.
|
||||
func (u *UsageLogUpsert) ClearIPAddress() *UsageLogUpsert {
|
||||
u.SetNull(usagelog.FieldIPAddress)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetImageCount sets the "image_count" field.
|
||||
func (u *UsageLogUpsert) SetImageCount(v int) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldImageCount, v)
|
||||
@@ -1866,6 +1907,27 @@ func (u *UsageLogUpsertOne) ClearUserAgent() *UsageLogUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetIPAddress sets the "ip_address" field.
|
||||
func (u *UsageLogUpsertOne) SetIPAddress(v string) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetIPAddress(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateIPAddress sets the "ip_address" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertOne) UpdateIPAddress() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateIPAddress()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearIPAddress clears the value of the "ip_address" field.
|
||||
func (u *UsageLogUpsertOne) ClearIPAddress() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearIPAddress()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImageCount sets the "image_count" field.
|
||||
func (u *UsageLogUpsertOne) SetImageCount(v int) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
@@ -2616,6 +2678,27 @@ func (u *UsageLogUpsertBulk) ClearUserAgent() *UsageLogUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetIPAddress sets the "ip_address" field.
|
||||
func (u *UsageLogUpsertBulk) SetIPAddress(v string) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetIPAddress(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateIPAddress sets the "ip_address" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertBulk) UpdateIPAddress() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateIPAddress()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearIPAddress clears the value of the "ip_address" field.
|
||||
func (u *UsageLogUpsertBulk) ClearIPAddress() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearIPAddress()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImageCount sets the "image_count" field.
|
||||
func (u *UsageLogUpsertBulk) SetImageCount(v int) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"math"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/schema/field"
|
||||
@@ -32,6 +33,7 @@ type UsageLogQuery struct {
|
||||
withAccount *AccountQuery
|
||||
withGroup *GroupQuery
|
||||
withSubscription *UserSubscriptionQuery
|
||||
modifiers []func(*sql.Selector)
|
||||
// intermediate query (i.e. traversal path).
|
||||
sql *sql.Selector
|
||||
path func(context.Context) (*sql.Selector, error)
|
||||
@@ -531,6 +533,9 @@ func (_q *UsageLogQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Usa
|
||||
node.Edges.loadedTypes = loadedTypes
|
||||
return node.assignValues(columns, values)
|
||||
}
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
for i := range hooks {
|
||||
hooks[i](ctx, _spec)
|
||||
}
|
||||
@@ -727,6 +732,9 @@ func (_q *UsageLogQuery) loadSubscription(ctx context.Context, query *UserSubscr
|
||||
|
||||
func (_q *UsageLogQuery) sqlCount(ctx context.Context) (int, error) {
|
||||
_spec := _q.querySpec()
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
_spec.Node.Columns = _q.ctx.Fields
|
||||
if len(_q.ctx.Fields) > 0 {
|
||||
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
|
||||
@@ -804,6 +812,9 @@ func (_q *UsageLogQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
if _q.ctx.Unique != nil && *_q.ctx.Unique {
|
||||
selector.Distinct()
|
||||
}
|
||||
for _, m := range _q.modifiers {
|
||||
m(selector)
|
||||
}
|
||||
for _, p := range _q.predicates {
|
||||
p(selector)
|
||||
}
|
||||
@@ -821,6 +832,32 @@ func (_q *UsageLogQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
return selector
|
||||
}
|
||||
|
||||
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
|
||||
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
|
||||
// either committed or rolled-back.
|
||||
func (_q *UsageLogQuery) ForUpdate(opts ...sql.LockOption) *UsageLogQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForUpdate(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
|
||||
// on any rows that are read. Other sessions can read the rows, but cannot modify them
|
||||
// until your transaction commits.
|
||||
func (_q *UsageLogQuery) ForShare(opts ...sql.LockOption) *UsageLogQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForShare(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// UsageLogGroupBy is the group-by builder for UsageLog entities.
|
||||
type UsageLogGroupBy struct {
|
||||
selector
|
||||
|
||||
@@ -524,6 +524,26 @@ func (_u *UsageLogUpdate) ClearUserAgent() *UsageLogUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetIPAddress sets the "ip_address" field.
|
||||
func (_u *UsageLogUpdate) SetIPAddress(v string) *UsageLogUpdate {
|
||||
_u.mutation.SetIPAddress(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableIPAddress sets the "ip_address" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdate) SetNillableIPAddress(v *string) *UsageLogUpdate {
|
||||
if v != nil {
|
||||
_u.SetIPAddress(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearIPAddress clears the value of the "ip_address" field.
|
||||
func (_u *UsageLogUpdate) ClearIPAddress() *UsageLogUpdate {
|
||||
_u.mutation.ClearIPAddress()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImageCount sets the "image_count" field.
|
||||
func (_u *UsageLogUpdate) SetImageCount(v int) *UsageLogUpdate {
|
||||
_u.mutation.ResetImageCount()
|
||||
@@ -669,6 +689,11 @@ func (_u *UsageLogUpdate) check() error {
|
||||
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.IPAddress(); ok {
|
||||
if err := usagelog.IPAddressValidator(v); err != nil {
|
||||
return &ValidationError{Name: "ip_address", err: fmt.Errorf(`ent: validator failed for field "UsageLog.ip_address": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.ImageSize(); ok {
|
||||
if err := usagelog.ImageSizeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
|
||||
@@ -815,6 +840,12 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if _u.mutation.UserAgentCleared() {
|
||||
_spec.ClearField(usagelog.FieldUserAgent, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.IPAddress(); ok {
|
||||
_spec.SetField(usagelog.FieldIPAddress, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.IPAddressCleared() {
|
||||
_spec.ClearField(usagelog.FieldIPAddress, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.ImageCount(); ok {
|
||||
_spec.SetField(usagelog.FieldImageCount, field.TypeInt, value)
|
||||
}
|
||||
@@ -1484,6 +1515,26 @@ func (_u *UsageLogUpdateOne) ClearUserAgent() *UsageLogUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetIPAddress sets the "ip_address" field.
|
||||
func (_u *UsageLogUpdateOne) SetIPAddress(v string) *UsageLogUpdateOne {
|
||||
_u.mutation.SetIPAddress(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableIPAddress sets the "ip_address" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdateOne) SetNillableIPAddress(v *string) *UsageLogUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetIPAddress(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearIPAddress clears the value of the "ip_address" field.
|
||||
func (_u *UsageLogUpdateOne) ClearIPAddress() *UsageLogUpdateOne {
|
||||
_u.mutation.ClearIPAddress()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImageCount sets the "image_count" field.
|
||||
func (_u *UsageLogUpdateOne) SetImageCount(v int) *UsageLogUpdateOne {
|
||||
_u.mutation.ResetImageCount()
|
||||
@@ -1642,6 +1693,11 @@ func (_u *UsageLogUpdateOne) check() error {
|
||||
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.IPAddress(); ok {
|
||||
if err := usagelog.IPAddressValidator(v); err != nil {
|
||||
return &ValidationError{Name: "ip_address", err: fmt.Errorf(`ent: validator failed for field "UsageLog.ip_address": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.ImageSize(); ok {
|
||||
if err := usagelog.ImageSizeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
|
||||
@@ -1805,6 +1861,12 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
|
||||
if _u.mutation.UserAgentCleared() {
|
||||
_spec.ClearField(usagelog.FieldUserAgent, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.IPAddress(); ok {
|
||||
_spec.SetField(usagelog.FieldIPAddress, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.IPAddressCleared() {
|
||||
_spec.ClearField(usagelog.FieldIPAddress, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.ImageCount(); ok {
|
||||
_spec.SetField(usagelog.FieldImageCount, field.TypeInt, value)
|
||||
}
|
||||
|
||||
@@ -61,11 +61,13 @@ type UserEdges struct {
|
||||
UsageLogs []*UsageLog `json:"usage_logs,omitempty"`
|
||||
// AttributeValues holds the value of the attribute_values edge.
|
||||
AttributeValues []*UserAttributeValue `json:"attribute_values,omitempty"`
|
||||
// PromoCodeUsages holds the value of the promo_code_usages edge.
|
||||
PromoCodeUsages []*PromoCodeUsage `json:"promo_code_usages,omitempty"`
|
||||
// UserAllowedGroups holds the value of the user_allowed_groups edge.
|
||||
UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"`
|
||||
// loadedTypes holds the information for reporting if a
|
||||
// type was loaded (or requested) in eager-loading or not.
|
||||
loadedTypes [8]bool
|
||||
loadedTypes [9]bool
|
||||
}
|
||||
|
||||
// APIKeysOrErr returns the APIKeys value or an error if the edge
|
||||
@@ -131,10 +133,19 @@ func (e UserEdges) AttributeValuesOrErr() ([]*UserAttributeValue, error) {
|
||||
return nil, &NotLoadedError{edge: "attribute_values"}
|
||||
}
|
||||
|
||||
// PromoCodeUsagesOrErr returns the PromoCodeUsages value or an error if the edge
|
||||
// was not loaded in eager-loading.
|
||||
func (e UserEdges) PromoCodeUsagesOrErr() ([]*PromoCodeUsage, error) {
|
||||
if e.loadedTypes[7] {
|
||||
return e.PromoCodeUsages, nil
|
||||
}
|
||||
return nil, &NotLoadedError{edge: "promo_code_usages"}
|
||||
}
|
||||
|
||||
// UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge
|
||||
// was not loaded in eager-loading.
|
||||
func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) {
|
||||
if e.loadedTypes[7] {
|
||||
if e.loadedTypes[8] {
|
||||
return e.UserAllowedGroups, nil
|
||||
}
|
||||
return nil, &NotLoadedError{edge: "user_allowed_groups"}
|
||||
@@ -289,6 +300,11 @@ func (_m *User) QueryAttributeValues() *UserAttributeValueQuery {
|
||||
return NewUserClient(_m.config).QueryAttributeValues(_m)
|
||||
}
|
||||
|
||||
// QueryPromoCodeUsages queries the "promo_code_usages" edge of the User entity.
|
||||
func (_m *User) QueryPromoCodeUsages() *PromoCodeUsageQuery {
|
||||
return NewUserClient(_m.config).QueryPromoCodeUsages(_m)
|
||||
}
|
||||
|
||||
// QueryUserAllowedGroups queries the "user_allowed_groups" edge of the User entity.
|
||||
func (_m *User) QueryUserAllowedGroups() *UserAllowedGroupQuery {
|
||||
return NewUserClient(_m.config).QueryUserAllowedGroups(_m)
|
||||
|
||||
@@ -51,6 +51,8 @@ const (
|
||||
EdgeUsageLogs = "usage_logs"
|
||||
// EdgeAttributeValues holds the string denoting the attribute_values edge name in mutations.
|
||||
EdgeAttributeValues = "attribute_values"
|
||||
// EdgePromoCodeUsages holds the string denoting the promo_code_usages edge name in mutations.
|
||||
EdgePromoCodeUsages = "promo_code_usages"
|
||||
// EdgeUserAllowedGroups holds the string denoting the user_allowed_groups edge name in mutations.
|
||||
EdgeUserAllowedGroups = "user_allowed_groups"
|
||||
// Table holds the table name of the user in the database.
|
||||
@@ -102,6 +104,13 @@ const (
|
||||
AttributeValuesInverseTable = "user_attribute_values"
|
||||
// AttributeValuesColumn is the table column denoting the attribute_values relation/edge.
|
||||
AttributeValuesColumn = "user_id"
|
||||
// PromoCodeUsagesTable is the table that holds the promo_code_usages relation/edge.
|
||||
PromoCodeUsagesTable = "promo_code_usages"
|
||||
// PromoCodeUsagesInverseTable is the table name for the PromoCodeUsage entity.
|
||||
// It exists in this package in order to avoid circular dependency with the "promocodeusage" package.
|
||||
PromoCodeUsagesInverseTable = "promo_code_usages"
|
||||
// PromoCodeUsagesColumn is the table column denoting the promo_code_usages relation/edge.
|
||||
PromoCodeUsagesColumn = "user_id"
|
||||
// UserAllowedGroupsTable is the table that holds the user_allowed_groups relation/edge.
|
||||
UserAllowedGroupsTable = "user_allowed_groups"
|
||||
// UserAllowedGroupsInverseTable is the table name for the UserAllowedGroup entity.
|
||||
@@ -342,6 +351,20 @@ func ByAttributeValues(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
|
||||
}
|
||||
}
|
||||
|
||||
// ByPromoCodeUsagesCount orders the results by promo_code_usages count.
|
||||
func ByPromoCodeUsagesCount(opts ...sql.OrderTermOption) OrderOption {
|
||||
return func(s *sql.Selector) {
|
||||
sqlgraph.OrderByNeighborsCount(s, newPromoCodeUsagesStep(), opts...)
|
||||
}
|
||||
}
|
||||
|
||||
// ByPromoCodeUsages orders the results by promo_code_usages terms.
|
||||
func ByPromoCodeUsages(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
|
||||
return func(s *sql.Selector) {
|
||||
sqlgraph.OrderByNeighborTerms(s, newPromoCodeUsagesStep(), append([]sql.OrderTerm{term}, terms...)...)
|
||||
}
|
||||
}
|
||||
|
||||
// ByUserAllowedGroupsCount orders the results by user_allowed_groups count.
|
||||
func ByUserAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption {
|
||||
return func(s *sql.Selector) {
|
||||
@@ -404,6 +427,13 @@ func newAttributeValuesStep() *sqlgraph.Step {
|
||||
sqlgraph.Edge(sqlgraph.O2M, false, AttributeValuesTable, AttributeValuesColumn),
|
||||
)
|
||||
}
|
||||
func newPromoCodeUsagesStep() *sqlgraph.Step {
|
||||
return sqlgraph.NewStep(
|
||||
sqlgraph.From(Table, FieldID),
|
||||
sqlgraph.To(PromoCodeUsagesInverseTable, FieldID),
|
||||
sqlgraph.Edge(sqlgraph.O2M, false, PromoCodeUsagesTable, PromoCodeUsagesColumn),
|
||||
)
|
||||
}
|
||||
func newUserAllowedGroupsStep() *sqlgraph.Step {
|
||||
return sqlgraph.NewStep(
|
||||
sqlgraph.From(Table, FieldID),
|
||||
|
||||
@@ -871,6 +871,29 @@ func HasAttributeValuesWith(preds ...predicate.UserAttributeValue) predicate.Use
|
||||
})
|
||||
}
|
||||
|
||||
// HasPromoCodeUsages applies the HasEdge predicate on the "promo_code_usages" edge.
|
||||
func HasPromoCodeUsages() predicate.User {
|
||||
return predicate.User(func(s *sql.Selector) {
|
||||
step := sqlgraph.NewStep(
|
||||
sqlgraph.From(Table, FieldID),
|
||||
sqlgraph.Edge(sqlgraph.O2M, false, PromoCodeUsagesTable, PromoCodeUsagesColumn),
|
||||
)
|
||||
sqlgraph.HasNeighbors(s, step)
|
||||
})
|
||||
}
|
||||
|
||||
// HasPromoCodeUsagesWith applies the HasEdge predicate on the "promo_code_usages" edge with a given conditions (other predicates).
|
||||
func HasPromoCodeUsagesWith(preds ...predicate.PromoCodeUsage) predicate.User {
|
||||
return predicate.User(func(s *sql.Selector) {
|
||||
step := newPromoCodeUsagesStep()
|
||||
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
|
||||
for _, p := range preds {
|
||||
p(s)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge.
|
||||
func HasUserAllowedGroups() predicate.User {
|
||||
return predicate.User(func(s *sql.Selector) {
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"entgo.io/ent/schema/field"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
|
||||
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usagelog"
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
@@ -271,6 +272,21 @@ func (_c *UserCreate) AddAttributeValues(v ...*UserAttributeValue) *UserCreate {
|
||||
return _c.AddAttributeValueIDs(ids...)
|
||||
}
|
||||
|
||||
// AddPromoCodeUsageIDs adds the "promo_code_usages" edge to the PromoCodeUsage entity by IDs.
|
||||
func (_c *UserCreate) AddPromoCodeUsageIDs(ids ...int64) *UserCreate {
|
||||
_c.mutation.AddPromoCodeUsageIDs(ids...)
|
||||
return _c
|
||||
}
|
||||
|
||||
// AddPromoCodeUsages adds the "promo_code_usages" edges to the PromoCodeUsage entity.
|
||||
func (_c *UserCreate) AddPromoCodeUsages(v ...*PromoCodeUsage) *UserCreate {
|
||||
ids := make([]int64, len(v))
|
||||
for i := range v {
|
||||
ids[i] = v[i].ID
|
||||
}
|
||||
return _c.AddPromoCodeUsageIDs(ids...)
|
||||
}
|
||||
|
||||
// Mutation returns the UserMutation object of the builder.
|
||||
func (_c *UserCreate) Mutation() *UserMutation {
|
||||
return _c.mutation
|
||||
@@ -593,6 +609,22 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
|
||||
}
|
||||
_spec.Edges = append(_spec.Edges, edge)
|
||||
}
|
||||
if nodes := _c.mutation.PromoCodeUsagesIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Inverse: false,
|
||||
Table: user.PromoCodeUsagesTable,
|
||||
Columns: []string{user.PromoCodeUsagesColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
for _, k := range nodes {
|
||||
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||
}
|
||||
_spec.Edges = append(_spec.Edges, edge)
|
||||
}
|
||||
return _node, _spec
|
||||
}
|
||||
|
||||
|
||||
@@ -9,12 +9,14 @@ import (
|
||||
"math"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/schema/field"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
|
||||
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usagelog"
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
@@ -37,7 +39,9 @@ type UserQuery struct {
|
||||
withAllowedGroups *GroupQuery
|
||||
withUsageLogs *UsageLogQuery
|
||||
withAttributeValues *UserAttributeValueQuery
|
||||
withPromoCodeUsages *PromoCodeUsageQuery
|
||||
withUserAllowedGroups *UserAllowedGroupQuery
|
||||
modifiers []func(*sql.Selector)
|
||||
// intermediate query (i.e. traversal path).
|
||||
sql *sql.Selector
|
||||
path func(context.Context) (*sql.Selector, error)
|
||||
@@ -228,6 +232,28 @@ func (_q *UserQuery) QueryAttributeValues() *UserAttributeValueQuery {
|
||||
return query
|
||||
}
|
||||
|
||||
// QueryPromoCodeUsages chains the current query on the "promo_code_usages" edge.
|
||||
func (_q *UserQuery) QueryPromoCodeUsages() *PromoCodeUsageQuery {
|
||||
query := (&PromoCodeUsageClient{config: _q.config}).Query()
|
||||
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
|
||||
if err := _q.prepareQuery(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selector := _q.sqlQuery(ctx)
|
||||
if err := selector.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
step := sqlgraph.NewStep(
|
||||
sqlgraph.From(user.Table, user.FieldID, selector),
|
||||
sqlgraph.To(promocodeusage.Table, promocodeusage.FieldID),
|
||||
sqlgraph.Edge(sqlgraph.O2M, false, user.PromoCodeUsagesTable, user.PromoCodeUsagesColumn),
|
||||
)
|
||||
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
|
||||
return fromU, nil
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
// QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge.
|
||||
func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery {
|
||||
query := (&UserAllowedGroupClient{config: _q.config}).Query()
|
||||
@@ -449,6 +475,7 @@ func (_q *UserQuery) Clone() *UserQuery {
|
||||
withAllowedGroups: _q.withAllowedGroups.Clone(),
|
||||
withUsageLogs: _q.withUsageLogs.Clone(),
|
||||
withAttributeValues: _q.withAttributeValues.Clone(),
|
||||
withPromoCodeUsages: _q.withPromoCodeUsages.Clone(),
|
||||
withUserAllowedGroups: _q.withUserAllowedGroups.Clone(),
|
||||
// clone intermediate query.
|
||||
sql: _q.sql.Clone(),
|
||||
@@ -533,6 +560,17 @@ func (_q *UserQuery) WithAttributeValues(opts ...func(*UserAttributeValueQuery))
|
||||
return _q
|
||||
}
|
||||
|
||||
// WithPromoCodeUsages tells the query-builder to eager-load the nodes that are connected to
|
||||
// the "promo_code_usages" edge. The optional arguments are used to configure the query builder of the edge.
|
||||
func (_q *UserQuery) WithPromoCodeUsages(opts ...func(*PromoCodeUsageQuery)) *UserQuery {
|
||||
query := (&PromoCodeUsageClient{config: _q.config}).Query()
|
||||
for _, opt := range opts {
|
||||
opt(query)
|
||||
}
|
||||
_q.withPromoCodeUsages = query
|
||||
return _q
|
||||
}
|
||||
|
||||
// WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to
|
||||
// the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge.
|
||||
func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery {
|
||||
@@ -622,7 +660,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
|
||||
var (
|
||||
nodes = []*User{}
|
||||
_spec = _q.querySpec()
|
||||
loadedTypes = [8]bool{
|
||||
loadedTypes = [9]bool{
|
||||
_q.withAPIKeys != nil,
|
||||
_q.withRedeemCodes != nil,
|
||||
_q.withSubscriptions != nil,
|
||||
@@ -630,6 +668,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
|
||||
_q.withAllowedGroups != nil,
|
||||
_q.withUsageLogs != nil,
|
||||
_q.withAttributeValues != nil,
|
||||
_q.withPromoCodeUsages != nil,
|
||||
_q.withUserAllowedGroups != nil,
|
||||
}
|
||||
)
|
||||
@@ -642,6 +681,9 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
|
||||
node.Edges.loadedTypes = loadedTypes
|
||||
return node.assignValues(columns, values)
|
||||
}
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
for i := range hooks {
|
||||
hooks[i](ctx, _spec)
|
||||
}
|
||||
@@ -702,6 +744,13 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if query := _q.withPromoCodeUsages; query != nil {
|
||||
if err := _q.loadPromoCodeUsages(ctx, query, nodes,
|
||||
func(n *User) { n.Edges.PromoCodeUsages = []*PromoCodeUsage{} },
|
||||
func(n *User, e *PromoCodeUsage) { n.Edges.PromoCodeUsages = append(n.Edges.PromoCodeUsages, e) }); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if query := _q.withUserAllowedGroups; query != nil {
|
||||
if err := _q.loadUserAllowedGroups(ctx, query, nodes,
|
||||
func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} },
|
||||
@@ -959,6 +1008,36 @@ func (_q *UserQuery) loadAttributeValues(ctx context.Context, query *UserAttribu
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (_q *UserQuery) loadPromoCodeUsages(ctx context.Context, query *PromoCodeUsageQuery, nodes []*User, init func(*User), assign func(*User, *PromoCodeUsage)) error {
|
||||
fks := make([]driver.Value, 0, len(nodes))
|
||||
nodeids := make(map[int64]*User)
|
||||
for i := range nodes {
|
||||
fks = append(fks, nodes[i].ID)
|
||||
nodeids[nodes[i].ID] = nodes[i]
|
||||
if init != nil {
|
||||
init(nodes[i])
|
||||
}
|
||||
}
|
||||
if len(query.ctx.Fields) > 0 {
|
||||
query.ctx.AppendFieldOnce(promocodeusage.FieldUserID)
|
||||
}
|
||||
query.Where(predicate.PromoCodeUsage(func(s *sql.Selector) {
|
||||
s.Where(sql.InValues(s.C(user.PromoCodeUsagesColumn), fks...))
|
||||
}))
|
||||
neighbors, err := query.All(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, n := range neighbors {
|
||||
fk := n.UserID
|
||||
node, ok := nodeids[fk]
|
||||
if !ok {
|
||||
return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID)
|
||||
}
|
||||
assign(node, n)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error {
|
||||
fks := make([]driver.Value, 0, len(nodes))
|
||||
nodeids := make(map[int64]*User)
|
||||
@@ -992,6 +1071,9 @@ func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllow
|
||||
|
||||
func (_q *UserQuery) sqlCount(ctx context.Context) (int, error) {
|
||||
_spec := _q.querySpec()
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
_spec.Node.Columns = _q.ctx.Fields
|
||||
if len(_q.ctx.Fields) > 0 {
|
||||
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
|
||||
@@ -1054,6 +1136,9 @@ func (_q *UserQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
if _q.ctx.Unique != nil && *_q.ctx.Unique {
|
||||
selector.Distinct()
|
||||
}
|
||||
for _, m := range _q.modifiers {
|
||||
m(selector)
|
||||
}
|
||||
for _, p := range _q.predicates {
|
||||
p(selector)
|
||||
}
|
||||
@@ -1071,6 +1156,32 @@ func (_q *UserQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
return selector
|
||||
}
|
||||
|
||||
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
|
||||
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
|
||||
// either committed or rolled-back.
|
||||
func (_q *UserQuery) ForUpdate(opts ...sql.LockOption) *UserQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForUpdate(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
|
||||
// on any rows that are read. Other sessions can read the rows, but cannot modify them
|
||||
// until your transaction commits.
|
||||
func (_q *UserQuery) ForShare(opts ...sql.LockOption) *UserQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForShare(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// UserGroupBy is the group-by builder for User entities.
|
||||
type UserGroupBy struct {
|
||||
selector
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
|
||||
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usagelog"
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
@@ -291,6 +292,21 @@ func (_u *UserUpdate) AddAttributeValues(v ...*UserAttributeValue) *UserUpdate {
|
||||
return _u.AddAttributeValueIDs(ids...)
|
||||
}
|
||||
|
||||
// AddPromoCodeUsageIDs adds the "promo_code_usages" edge to the PromoCodeUsage entity by IDs.
|
||||
func (_u *UserUpdate) AddPromoCodeUsageIDs(ids ...int64) *UserUpdate {
|
||||
_u.mutation.AddPromoCodeUsageIDs(ids...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddPromoCodeUsages adds the "promo_code_usages" edges to the PromoCodeUsage entity.
|
||||
func (_u *UserUpdate) AddPromoCodeUsages(v ...*PromoCodeUsage) *UserUpdate {
|
||||
ids := make([]int64, len(v))
|
||||
for i := range v {
|
||||
ids[i] = v[i].ID
|
||||
}
|
||||
return _u.AddPromoCodeUsageIDs(ids...)
|
||||
}
|
||||
|
||||
// Mutation returns the UserMutation object of the builder.
|
||||
func (_u *UserUpdate) Mutation() *UserMutation {
|
||||
return _u.mutation
|
||||
@@ -443,6 +459,27 @@ func (_u *UserUpdate) RemoveAttributeValues(v ...*UserAttributeValue) *UserUpdat
|
||||
return _u.RemoveAttributeValueIDs(ids...)
|
||||
}
|
||||
|
||||
// ClearPromoCodeUsages clears all "promo_code_usages" edges to the PromoCodeUsage entity.
|
||||
func (_u *UserUpdate) ClearPromoCodeUsages() *UserUpdate {
|
||||
_u.mutation.ClearPromoCodeUsages()
|
||||
return _u
|
||||
}
|
||||
|
||||
// RemovePromoCodeUsageIDs removes the "promo_code_usages" edge to PromoCodeUsage entities by IDs.
|
||||
func (_u *UserUpdate) RemovePromoCodeUsageIDs(ids ...int64) *UserUpdate {
|
||||
_u.mutation.RemovePromoCodeUsageIDs(ids...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// RemovePromoCodeUsages removes "promo_code_usages" edges to PromoCodeUsage entities.
|
||||
func (_u *UserUpdate) RemovePromoCodeUsages(v ...*PromoCodeUsage) *UserUpdate {
|
||||
ids := make([]int64, len(v))
|
||||
for i := range v {
|
||||
ids[i] = v[i].ID
|
||||
}
|
||||
return _u.RemovePromoCodeUsageIDs(ids...)
|
||||
}
|
||||
|
||||
// Save executes the query and returns the number of nodes affected by the update operation.
|
||||
func (_u *UserUpdate) Save(ctx context.Context) (int, error) {
|
||||
if err := _u.defaults(); err != nil {
|
||||
@@ -893,6 +930,51 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
}
|
||||
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||
}
|
||||
if _u.mutation.PromoCodeUsagesCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Inverse: false,
|
||||
Table: user.PromoCodeUsagesTable,
|
||||
Columns: []string{user.PromoCodeUsagesColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||
}
|
||||
if nodes := _u.mutation.RemovedPromoCodeUsagesIDs(); len(nodes) > 0 && !_u.mutation.PromoCodeUsagesCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Inverse: false,
|
||||
Table: user.PromoCodeUsagesTable,
|
||||
Columns: []string{user.PromoCodeUsagesColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
for _, k := range nodes {
|
||||
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||
}
|
||||
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||
}
|
||||
if nodes := _u.mutation.PromoCodeUsagesIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Inverse: false,
|
||||
Table: user.PromoCodeUsagesTable,
|
||||
Columns: []string{user.PromoCodeUsagesColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
for _, k := range nodes {
|
||||
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||
}
|
||||
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||
}
|
||||
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
|
||||
if _, ok := err.(*sqlgraph.NotFoundError); ok {
|
||||
err = &NotFoundError{user.Label}
|
||||
@@ -1170,6 +1252,21 @@ func (_u *UserUpdateOne) AddAttributeValues(v ...*UserAttributeValue) *UserUpdat
|
||||
return _u.AddAttributeValueIDs(ids...)
|
||||
}
|
||||
|
||||
// AddPromoCodeUsageIDs adds the "promo_code_usages" edge to the PromoCodeUsage entity by IDs.
|
||||
func (_u *UserUpdateOne) AddPromoCodeUsageIDs(ids ...int64) *UserUpdateOne {
|
||||
_u.mutation.AddPromoCodeUsageIDs(ids...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddPromoCodeUsages adds the "promo_code_usages" edges to the PromoCodeUsage entity.
|
||||
func (_u *UserUpdateOne) AddPromoCodeUsages(v ...*PromoCodeUsage) *UserUpdateOne {
|
||||
ids := make([]int64, len(v))
|
||||
for i := range v {
|
||||
ids[i] = v[i].ID
|
||||
}
|
||||
return _u.AddPromoCodeUsageIDs(ids...)
|
||||
}
|
||||
|
||||
// Mutation returns the UserMutation object of the builder.
|
||||
func (_u *UserUpdateOne) Mutation() *UserMutation {
|
||||
return _u.mutation
|
||||
@@ -1322,6 +1419,27 @@ func (_u *UserUpdateOne) RemoveAttributeValues(v ...*UserAttributeValue) *UserUp
|
||||
return _u.RemoveAttributeValueIDs(ids...)
|
||||
}
|
||||
|
||||
// ClearPromoCodeUsages clears all "promo_code_usages" edges to the PromoCodeUsage entity.
|
||||
func (_u *UserUpdateOne) ClearPromoCodeUsages() *UserUpdateOne {
|
||||
_u.mutation.ClearPromoCodeUsages()
|
||||
return _u
|
||||
}
|
||||
|
||||
// RemovePromoCodeUsageIDs removes the "promo_code_usages" edge to PromoCodeUsage entities by IDs.
|
||||
func (_u *UserUpdateOne) RemovePromoCodeUsageIDs(ids ...int64) *UserUpdateOne {
|
||||
_u.mutation.RemovePromoCodeUsageIDs(ids...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// RemovePromoCodeUsages removes "promo_code_usages" edges to PromoCodeUsage entities.
|
||||
func (_u *UserUpdateOne) RemovePromoCodeUsages(v ...*PromoCodeUsage) *UserUpdateOne {
|
||||
ids := make([]int64, len(v))
|
||||
for i := range v {
|
||||
ids[i] = v[i].ID
|
||||
}
|
||||
return _u.RemovePromoCodeUsageIDs(ids...)
|
||||
}
|
||||
|
||||
// Where appends a list predicates to the UserUpdate builder.
|
||||
func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne {
|
||||
_u.mutation.Where(ps...)
|
||||
@@ -1802,6 +1920,51 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
|
||||
}
|
||||
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||
}
|
||||
if _u.mutation.PromoCodeUsagesCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Inverse: false,
|
||||
Table: user.PromoCodeUsagesTable,
|
||||
Columns: []string{user.PromoCodeUsagesColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||
}
|
||||
if nodes := _u.mutation.RemovedPromoCodeUsagesIDs(); len(nodes) > 0 && !_u.mutation.PromoCodeUsagesCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Inverse: false,
|
||||
Table: user.PromoCodeUsagesTable,
|
||||
Columns: []string{user.PromoCodeUsagesColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
for _, k := range nodes {
|
||||
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||
}
|
||||
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||
}
|
||||
if nodes := _u.mutation.PromoCodeUsagesIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Inverse: false,
|
||||
Table: user.PromoCodeUsagesTable,
|
||||
Columns: []string{user.PromoCodeUsagesColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(promocodeusage.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
for _, k := range nodes {
|
||||
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||
}
|
||||
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||
}
|
||||
_node = &User{config: _u.config}
|
||||
_spec.Assign = _node.assignValues
|
||||
_spec.ScanValues = _node.scanValues
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"math"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||
@@ -25,6 +26,7 @@ type UserAllowedGroupQuery struct {
|
||||
predicates []predicate.UserAllowedGroup
|
||||
withUser *UserQuery
|
||||
withGroup *GroupQuery
|
||||
modifiers []func(*sql.Selector)
|
||||
// intermediate query (i.e. traversal path).
|
||||
sql *sql.Selector
|
||||
path func(context.Context) (*sql.Selector, error)
|
||||
@@ -347,6 +349,9 @@ func (_q *UserAllowedGroupQuery) sqlAll(ctx context.Context, hooks ...queryHook)
|
||||
node.Edges.loadedTypes = loadedTypes
|
||||
return node.assignValues(columns, values)
|
||||
}
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
for i := range hooks {
|
||||
hooks[i](ctx, _spec)
|
||||
}
|
||||
@@ -432,6 +437,9 @@ func (_q *UserAllowedGroupQuery) loadGroup(ctx context.Context, query *GroupQuer
|
||||
|
||||
func (_q *UserAllowedGroupQuery) sqlCount(ctx context.Context) (int, error) {
|
||||
_spec := _q.querySpec()
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
_spec.Unique = false
|
||||
_spec.Node.Columns = nil
|
||||
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
|
||||
@@ -495,6 +503,9 @@ func (_q *UserAllowedGroupQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
if _q.ctx.Unique != nil && *_q.ctx.Unique {
|
||||
selector.Distinct()
|
||||
}
|
||||
for _, m := range _q.modifiers {
|
||||
m(selector)
|
||||
}
|
||||
for _, p := range _q.predicates {
|
||||
p(selector)
|
||||
}
|
||||
@@ -512,6 +523,32 @@ func (_q *UserAllowedGroupQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
return selector
|
||||
}
|
||||
|
||||
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
|
||||
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
|
||||
// either committed or rolled-back.
|
||||
func (_q *UserAllowedGroupQuery) ForUpdate(opts ...sql.LockOption) *UserAllowedGroupQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForUpdate(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
|
||||
// on any rows that are read. Other sessions can read the rows, but cannot modify them
|
||||
// until your transaction commits.
|
||||
func (_q *UserAllowedGroupQuery) ForShare(opts ...sql.LockOption) *UserAllowedGroupQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForShare(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// UserAllowedGroupGroupBy is the group-by builder for UserAllowedGroup entities.
|
||||
type UserAllowedGroupGroupBy struct {
|
||||
selector
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"math"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/schema/field"
|
||||
@@ -25,6 +26,7 @@ type UserAttributeDefinitionQuery struct {
|
||||
inters []Interceptor
|
||||
predicates []predicate.UserAttributeDefinition
|
||||
withValues *UserAttributeValueQuery
|
||||
modifiers []func(*sql.Selector)
|
||||
// intermediate query (i.e. traversal path).
|
||||
sql *sql.Selector
|
||||
path func(context.Context) (*sql.Selector, error)
|
||||
@@ -384,6 +386,9 @@ func (_q *UserAttributeDefinitionQuery) sqlAll(ctx context.Context, hooks ...que
|
||||
node.Edges.loadedTypes = loadedTypes
|
||||
return node.assignValues(columns, values)
|
||||
}
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
for i := range hooks {
|
||||
hooks[i](ctx, _spec)
|
||||
}
|
||||
@@ -436,6 +441,9 @@ func (_q *UserAttributeDefinitionQuery) loadValues(ctx context.Context, query *U
|
||||
|
||||
func (_q *UserAttributeDefinitionQuery) sqlCount(ctx context.Context) (int, error) {
|
||||
_spec := _q.querySpec()
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
_spec.Node.Columns = _q.ctx.Fields
|
||||
if len(_q.ctx.Fields) > 0 {
|
||||
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
|
||||
@@ -498,6 +506,9 @@ func (_q *UserAttributeDefinitionQuery) sqlQuery(ctx context.Context) *sql.Selec
|
||||
if _q.ctx.Unique != nil && *_q.ctx.Unique {
|
||||
selector.Distinct()
|
||||
}
|
||||
for _, m := range _q.modifiers {
|
||||
m(selector)
|
||||
}
|
||||
for _, p := range _q.predicates {
|
||||
p(selector)
|
||||
}
|
||||
@@ -515,6 +526,32 @@ func (_q *UserAttributeDefinitionQuery) sqlQuery(ctx context.Context) *sql.Selec
|
||||
return selector
|
||||
}
|
||||
|
||||
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
|
||||
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
|
||||
// either committed or rolled-back.
|
||||
func (_q *UserAttributeDefinitionQuery) ForUpdate(opts ...sql.LockOption) *UserAttributeDefinitionQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForUpdate(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
|
||||
// on any rows that are read. Other sessions can read the rows, but cannot modify them
|
||||
// until your transaction commits.
|
||||
func (_q *UserAttributeDefinitionQuery) ForShare(opts ...sql.LockOption) *UserAttributeDefinitionQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForShare(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// UserAttributeDefinitionGroupBy is the group-by builder for UserAttributeDefinition entities.
|
||||
type UserAttributeDefinitionGroupBy struct {
|
||||
selector
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"math"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/schema/field"
|
||||
@@ -26,6 +27,7 @@ type UserAttributeValueQuery struct {
|
||||
predicates []predicate.UserAttributeValue
|
||||
withUser *UserQuery
|
||||
withDefinition *UserAttributeDefinitionQuery
|
||||
modifiers []func(*sql.Selector)
|
||||
// intermediate query (i.e. traversal path).
|
||||
sql *sql.Selector
|
||||
path func(context.Context) (*sql.Selector, error)
|
||||
@@ -420,6 +422,9 @@ func (_q *UserAttributeValueQuery) sqlAll(ctx context.Context, hooks ...queryHoo
|
||||
node.Edges.loadedTypes = loadedTypes
|
||||
return node.assignValues(columns, values)
|
||||
}
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
for i := range hooks {
|
||||
hooks[i](ctx, _spec)
|
||||
}
|
||||
@@ -505,6 +510,9 @@ func (_q *UserAttributeValueQuery) loadDefinition(ctx context.Context, query *Us
|
||||
|
||||
func (_q *UserAttributeValueQuery) sqlCount(ctx context.Context) (int, error) {
|
||||
_spec := _q.querySpec()
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
_spec.Node.Columns = _q.ctx.Fields
|
||||
if len(_q.ctx.Fields) > 0 {
|
||||
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
|
||||
@@ -573,6 +581,9 @@ func (_q *UserAttributeValueQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
if _q.ctx.Unique != nil && *_q.ctx.Unique {
|
||||
selector.Distinct()
|
||||
}
|
||||
for _, m := range _q.modifiers {
|
||||
m(selector)
|
||||
}
|
||||
for _, p := range _q.predicates {
|
||||
p(selector)
|
||||
}
|
||||
@@ -590,6 +601,32 @@ func (_q *UserAttributeValueQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
return selector
|
||||
}
|
||||
|
||||
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
|
||||
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
|
||||
// either committed or rolled-back.
|
||||
func (_q *UserAttributeValueQuery) ForUpdate(opts ...sql.LockOption) *UserAttributeValueQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForUpdate(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
|
||||
// on any rows that are read. Other sessions can read the rows, but cannot modify them
|
||||
// until your transaction commits.
|
||||
func (_q *UserAttributeValueQuery) ForShare(opts ...sql.LockOption) *UserAttributeValueQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForShare(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// UserAttributeValueGroupBy is the group-by builder for UserAttributeValue entities.
|
||||
type UserAttributeValueGroupBy struct {
|
||||
selector
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"math"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/schema/field"
|
||||
@@ -30,6 +31,7 @@ type UserSubscriptionQuery struct {
|
||||
withGroup *GroupQuery
|
||||
withAssignedByUser *UserQuery
|
||||
withUsageLogs *UsageLogQuery
|
||||
modifiers []func(*sql.Selector)
|
||||
// intermediate query (i.e. traversal path).
|
||||
sql *sql.Selector
|
||||
path func(context.Context) (*sql.Selector, error)
|
||||
@@ -494,6 +496,9 @@ func (_q *UserSubscriptionQuery) sqlAll(ctx context.Context, hooks ...queryHook)
|
||||
node.Edges.loadedTypes = loadedTypes
|
||||
return node.assignValues(columns, values)
|
||||
}
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
for i := range hooks {
|
||||
hooks[i](ctx, _spec)
|
||||
}
|
||||
@@ -657,6 +662,9 @@ func (_q *UserSubscriptionQuery) loadUsageLogs(ctx context.Context, query *Usage
|
||||
|
||||
func (_q *UserSubscriptionQuery) sqlCount(ctx context.Context) (int, error) {
|
||||
_spec := _q.querySpec()
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
_spec.Node.Columns = _q.ctx.Fields
|
||||
if len(_q.ctx.Fields) > 0 {
|
||||
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
|
||||
@@ -728,6 +736,9 @@ func (_q *UserSubscriptionQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
if _q.ctx.Unique != nil && *_q.ctx.Unique {
|
||||
selector.Distinct()
|
||||
}
|
||||
for _, m := range _q.modifiers {
|
||||
m(selector)
|
||||
}
|
||||
for _, p := range _q.predicates {
|
||||
p(selector)
|
||||
}
|
||||
@@ -745,6 +756,32 @@ func (_q *UserSubscriptionQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
return selector
|
||||
}
|
||||
|
||||
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
|
||||
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
|
||||
// either committed or rolled-back.
|
||||
func (_q *UserSubscriptionQuery) ForUpdate(opts ...sql.LockOption) *UserSubscriptionQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForUpdate(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
|
||||
// on any rows that are read. Other sessions can read the rows, but cannot modify them
|
||||
// until your transaction commits.
|
||||
func (_q *UserSubscriptionQuery) ForShare(opts ...sql.LockOption) *UserSubscriptionQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForShare(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// UserSubscriptionGroupBy is the group-by builder for UserSubscription entities.
|
||||
type UserSubscriptionGroupBy struct {
|
||||
selector
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -35,24 +36,25 @@ const (
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
CORS CORSConfig `mapstructure:"cors"`
|
||||
Security SecurityConfig `mapstructure:"security"`
|
||||
Billing BillingConfig `mapstructure:"billing"`
|
||||
Turnstile TurnstileConfig `mapstructure:"turnstile"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
JWT JWTConfig `mapstructure:"jwt"`
|
||||
Default DefaultConfig `mapstructure:"default"`
|
||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||
Pricing PricingConfig `mapstructure:"pricing"`
|
||||
Gateway GatewayConfig `mapstructure:"gateway"`
|
||||
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
|
||||
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
||||
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
|
||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||
Gemini GeminiConfig `mapstructure:"gemini"`
|
||||
Update UpdateConfig `mapstructure:"update"`
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
CORS CORSConfig `mapstructure:"cors"`
|
||||
Security SecurityConfig `mapstructure:"security"`
|
||||
Billing BillingConfig `mapstructure:"billing"`
|
||||
Turnstile TurnstileConfig `mapstructure:"turnstile"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
JWT JWTConfig `mapstructure:"jwt"`
|
||||
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
|
||||
Default DefaultConfig `mapstructure:"default"`
|
||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||
Pricing PricingConfig `mapstructure:"pricing"`
|
||||
Gateway GatewayConfig `mapstructure:"gateway"`
|
||||
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
|
||||
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
||||
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
|
||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||
Gemini GeminiConfig `mapstructure:"gemini"`
|
||||
Update UpdateConfig `mapstructure:"update"`
|
||||
}
|
||||
|
||||
// UpdateConfig 在线更新相关配置
|
||||
@@ -322,6 +324,30 @@ type TurnstileConfig struct {
|
||||
Required bool `mapstructure:"required"`
|
||||
}
|
||||
|
||||
// LinuxDoConnectConfig 用于 LinuxDo Connect OAuth 登录(终端用户 SSO)。
|
||||
//
|
||||
// 注意:这与上游账号的 OAuth(例如 OpenAI/Gemini 账号接入)不是一回事。
|
||||
// 这里是用于登录 Sub2API 本身的用户体系。
|
||||
type LinuxDoConnectConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
ClientID string `mapstructure:"client_id"`
|
||||
ClientSecret string `mapstructure:"client_secret"`
|
||||
AuthorizeURL string `mapstructure:"authorize_url"`
|
||||
TokenURL string `mapstructure:"token_url"`
|
||||
UserInfoURL string `mapstructure:"userinfo_url"`
|
||||
Scopes string `mapstructure:"scopes"`
|
||||
RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记)
|
||||
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/linuxdo/callback)
|
||||
TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none
|
||||
UsePKCE bool `mapstructure:"use_pkce"`
|
||||
|
||||
// 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。
|
||||
// 为空时,服务端会尝试一组常见字段名。
|
||||
UserInfoEmailPath string `mapstructure:"userinfo_email_path"`
|
||||
UserInfoIDPath string `mapstructure:"userinfo_id_path"`
|
||||
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
|
||||
}
|
||||
|
||||
type DefaultConfig struct {
|
||||
AdminEmail string `mapstructure:"admin_email"`
|
||||
AdminPassword string `mapstructure:"admin_password"`
|
||||
@@ -388,6 +414,18 @@ func Load() (*Config, error) {
|
||||
cfg.Server.Mode = "debug"
|
||||
}
|
||||
cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
|
||||
cfg.LinuxDo.ClientID = strings.TrimSpace(cfg.LinuxDo.ClientID)
|
||||
cfg.LinuxDo.ClientSecret = strings.TrimSpace(cfg.LinuxDo.ClientSecret)
|
||||
cfg.LinuxDo.AuthorizeURL = strings.TrimSpace(cfg.LinuxDo.AuthorizeURL)
|
||||
cfg.LinuxDo.TokenURL = strings.TrimSpace(cfg.LinuxDo.TokenURL)
|
||||
cfg.LinuxDo.UserInfoURL = strings.TrimSpace(cfg.LinuxDo.UserInfoURL)
|
||||
cfg.LinuxDo.Scopes = strings.TrimSpace(cfg.LinuxDo.Scopes)
|
||||
cfg.LinuxDo.RedirectURL = strings.TrimSpace(cfg.LinuxDo.RedirectURL)
|
||||
cfg.LinuxDo.FrontendRedirectURL = strings.TrimSpace(cfg.LinuxDo.FrontendRedirectURL)
|
||||
cfg.LinuxDo.TokenAuthMethod = strings.ToLower(strings.TrimSpace(cfg.LinuxDo.TokenAuthMethod))
|
||||
cfg.LinuxDo.UserInfoEmailPath = strings.TrimSpace(cfg.LinuxDo.UserInfoEmailPath)
|
||||
cfg.LinuxDo.UserInfoIDPath = strings.TrimSpace(cfg.LinuxDo.UserInfoIDPath)
|
||||
cfg.LinuxDo.UserInfoUsernamePath = strings.TrimSpace(cfg.LinuxDo.UserInfoUsernamePath)
|
||||
cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
|
||||
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
|
||||
cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
|
||||
@@ -426,6 +464,81 @@ func Load() (*Config, error) {
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// ValidateAbsoluteHTTPURL 校验一个绝对 http(s) URL(禁止 fragment)。
|
||||
func ValidateAbsoluteHTTPURL(raw string) error {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return fmt.Errorf("empty url")
|
||||
}
|
||||
u, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !u.IsAbs() {
|
||||
return fmt.Errorf("must be absolute")
|
||||
}
|
||||
if !isHTTPScheme(u.Scheme) {
|
||||
return fmt.Errorf("unsupported scheme: %s", u.Scheme)
|
||||
}
|
||||
if strings.TrimSpace(u.Host) == "" {
|
||||
return fmt.Errorf("missing host")
|
||||
}
|
||||
if u.Fragment != "" {
|
||||
return fmt.Errorf("must not include fragment")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateFrontendRedirectURL 校验前端回调地址:
|
||||
// - 允许同源相对路径(以 / 开头)
|
||||
// - 或绝对 http(s) URL(禁止 fragment)
|
||||
func ValidateFrontendRedirectURL(raw string) error {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return fmt.Errorf("empty url")
|
||||
}
|
||||
if strings.ContainsAny(raw, "\r\n") {
|
||||
return fmt.Errorf("contains invalid characters")
|
||||
}
|
||||
if strings.HasPrefix(raw, "/") {
|
||||
if strings.HasPrefix(raw, "//") {
|
||||
return fmt.Errorf("must not start with //")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
u, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !u.IsAbs() {
|
||||
return fmt.Errorf("must be absolute http(s) url or relative path")
|
||||
}
|
||||
if !isHTTPScheme(u.Scheme) {
|
||||
return fmt.Errorf("unsupported scheme: %s", u.Scheme)
|
||||
}
|
||||
if strings.TrimSpace(u.Host) == "" {
|
||||
return fmt.Errorf("missing host")
|
||||
}
|
||||
if u.Fragment != "" {
|
||||
return fmt.Errorf("must not include fragment")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isHTTPScheme(scheme string) bool {
|
||||
return strings.EqualFold(scheme, "http") || strings.EqualFold(scheme, "https")
|
||||
}
|
||||
|
||||
func warnIfInsecureURL(field, raw string) {
|
||||
u, err := url.Parse(strings.TrimSpace(raw))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if strings.EqualFold(u.Scheme, "http") {
|
||||
log.Printf("Warning: %s uses http scheme; use https in production to avoid token leakage.", field)
|
||||
}
|
||||
}
|
||||
|
||||
func setDefaults() {
|
||||
viper.SetDefault("run_mode", RunModeStandard)
|
||||
|
||||
@@ -475,6 +588,22 @@ func setDefaults() {
|
||||
// Turnstile
|
||||
viper.SetDefault("turnstile.required", false)
|
||||
|
||||
// LinuxDo Connect OAuth 登录(终端用户 SSO)
|
||||
viper.SetDefault("linuxdo_connect.enabled", false)
|
||||
viper.SetDefault("linuxdo_connect.client_id", "")
|
||||
viper.SetDefault("linuxdo_connect.client_secret", "")
|
||||
viper.SetDefault("linuxdo_connect.authorize_url", "https://connect.linux.do/oauth2/authorize")
|
||||
viper.SetDefault("linuxdo_connect.token_url", "https://connect.linux.do/oauth2/token")
|
||||
viper.SetDefault("linuxdo_connect.userinfo_url", "https://connect.linux.do/api/user")
|
||||
viper.SetDefault("linuxdo_connect.scopes", "user")
|
||||
viper.SetDefault("linuxdo_connect.redirect_url", "")
|
||||
viper.SetDefault("linuxdo_connect.frontend_redirect_url", "/auth/linuxdo/callback")
|
||||
viper.SetDefault("linuxdo_connect.token_auth_method", "client_secret_post")
|
||||
viper.SetDefault("linuxdo_connect.use_pkce", false)
|
||||
viper.SetDefault("linuxdo_connect.userinfo_email_path", "")
|
||||
viper.SetDefault("linuxdo_connect.userinfo_id_path", "")
|
||||
viper.SetDefault("linuxdo_connect.userinfo_username_path", "")
|
||||
|
||||
// Database
|
||||
viper.SetDefault("database.host", "localhost")
|
||||
viper.SetDefault("database.port", 5432)
|
||||
@@ -544,7 +673,7 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
|
||||
viper.SetDefault("gateway.stream_data_interval_timeout", 180)
|
||||
viper.SetDefault("gateway.stream_keepalive_interval", 10)
|
||||
viper.SetDefault("gateway.max_line_size", 10*1024*1024)
|
||||
viper.SetDefault("gateway.max_line_size", 40*1024*1024)
|
||||
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
|
||||
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 45*time.Second)
|
||||
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
|
||||
@@ -586,6 +715,60 @@ func (c *Config) Validate() error {
|
||||
if c.Security.CSP.Enabled && strings.TrimSpace(c.Security.CSP.Policy) == "" {
|
||||
return fmt.Errorf("security.csp.policy is required when CSP is enabled")
|
||||
}
|
||||
if c.LinuxDo.Enabled {
|
||||
if strings.TrimSpace(c.LinuxDo.ClientID) == "" {
|
||||
return fmt.Errorf("linuxdo_connect.client_id is required when linuxdo_connect.enabled=true")
|
||||
}
|
||||
if strings.TrimSpace(c.LinuxDo.AuthorizeURL) == "" {
|
||||
return fmt.Errorf("linuxdo_connect.authorize_url is required when linuxdo_connect.enabled=true")
|
||||
}
|
||||
if strings.TrimSpace(c.LinuxDo.TokenURL) == "" {
|
||||
return fmt.Errorf("linuxdo_connect.token_url is required when linuxdo_connect.enabled=true")
|
||||
}
|
||||
if strings.TrimSpace(c.LinuxDo.UserInfoURL) == "" {
|
||||
return fmt.Errorf("linuxdo_connect.userinfo_url is required when linuxdo_connect.enabled=true")
|
||||
}
|
||||
if strings.TrimSpace(c.LinuxDo.RedirectURL) == "" {
|
||||
return fmt.Errorf("linuxdo_connect.redirect_url is required when linuxdo_connect.enabled=true")
|
||||
}
|
||||
method := strings.ToLower(strings.TrimSpace(c.LinuxDo.TokenAuthMethod))
|
||||
switch method {
|
||||
case "", "client_secret_post", "client_secret_basic", "none":
|
||||
default:
|
||||
return fmt.Errorf("linuxdo_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none")
|
||||
}
|
||||
if method == "none" && !c.LinuxDo.UsePKCE {
|
||||
return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.token_auth_method=none")
|
||||
}
|
||||
if (method == "" || method == "client_secret_post" || method == "client_secret_basic") && strings.TrimSpace(c.LinuxDo.ClientSecret) == "" {
|
||||
return fmt.Errorf("linuxdo_connect.client_secret is required when linuxdo_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic")
|
||||
}
|
||||
if strings.TrimSpace(c.LinuxDo.FrontendRedirectURL) == "" {
|
||||
return fmt.Errorf("linuxdo_connect.frontend_redirect_url is required when linuxdo_connect.enabled=true")
|
||||
}
|
||||
|
||||
if err := ValidateAbsoluteHTTPURL(c.LinuxDo.AuthorizeURL); err != nil {
|
||||
return fmt.Errorf("linuxdo_connect.authorize_url invalid: %w", err)
|
||||
}
|
||||
if err := ValidateAbsoluteHTTPURL(c.LinuxDo.TokenURL); err != nil {
|
||||
return fmt.Errorf("linuxdo_connect.token_url invalid: %w", err)
|
||||
}
|
||||
if err := ValidateAbsoluteHTTPURL(c.LinuxDo.UserInfoURL); err != nil {
|
||||
return fmt.Errorf("linuxdo_connect.userinfo_url invalid: %w", err)
|
||||
}
|
||||
if err := ValidateAbsoluteHTTPURL(c.LinuxDo.RedirectURL); err != nil {
|
||||
return fmt.Errorf("linuxdo_connect.redirect_url invalid: %w", err)
|
||||
}
|
||||
if err := ValidateFrontendRedirectURL(c.LinuxDo.FrontendRedirectURL); err != nil {
|
||||
return fmt.Errorf("linuxdo_connect.frontend_redirect_url invalid: %w", err)
|
||||
}
|
||||
|
||||
warnIfInsecureURL("linuxdo_connect.authorize_url", c.LinuxDo.AuthorizeURL)
|
||||
warnIfInsecureURL("linuxdo_connect.token_url", c.LinuxDo.TokenURL)
|
||||
warnIfInsecureURL("linuxdo_connect.userinfo_url", c.LinuxDo.UserInfoURL)
|
||||
warnIfInsecureURL("linuxdo_connect.redirect_url", c.LinuxDo.RedirectURL)
|
||||
warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL)
|
||||
}
|
||||
if c.Billing.CircuitBreaker.Enabled {
|
||||
if c.Billing.CircuitBreaker.FailureThreshold <= 0 {
|
||||
return fmt.Errorf("billing.circuit_breaker.failure_threshold must be positive")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -90,3 +91,53 @@ func TestLoadDefaultSecurityToggles(t *testing.T) {
|
||||
t.Fatalf("ResponseHeaders.Enabled = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
cfg.LinuxDo.Enabled = true
|
||||
cfg.LinuxDo.ClientID = "test-client"
|
||||
cfg.LinuxDo.ClientSecret = "test-secret"
|
||||
cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback"
|
||||
cfg.LinuxDo.TokenAuthMethod = "client_secret_post"
|
||||
cfg.LinuxDo.UsePKCE = false
|
||||
|
||||
cfg.LinuxDo.FrontendRedirectURL = "javascript:alert(1)"
|
||||
err = cfg.Validate()
|
||||
if err == nil {
|
||||
t.Fatalf("Validate() expected error for javascript scheme, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "linuxdo_connect.frontend_redirect_url") {
|
||||
t.Fatalf("Validate() expected frontend_redirect_url error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
cfg.LinuxDo.Enabled = true
|
||||
cfg.LinuxDo.ClientID = "test-client"
|
||||
cfg.LinuxDo.ClientSecret = ""
|
||||
cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback"
|
||||
cfg.LinuxDo.FrontendRedirectURL = "/auth/linuxdo/callback"
|
||||
cfg.LinuxDo.TokenAuthMethod = "none"
|
||||
cfg.LinuxDo.UsePKCE = false
|
||||
|
||||
err = cfg.Validate()
|
||||
if err == nil {
|
||||
t.Fatalf("Validate() expected error when token_auth_method=none and use_pkce=false, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "linuxdo_connect.use_pkce") {
|
||||
t.Fatalf("Validate() expected use_pkce error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -116,6 +116,7 @@ type BulkUpdateAccountsRequest struct {
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
|
||||
Schedulable *bool `json:"schedulable"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
@@ -136,6 +137,11 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
accountType := c.Query("type")
|
||||
status := c.Query("status")
|
||||
search := c.Query("search")
|
||||
// 标准化和验证 search 参数
|
||||
search = strings.TrimSpace(search)
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
}
|
||||
|
||||
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search)
|
||||
if err != nil {
|
||||
@@ -655,6 +661,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
req.Concurrency != nil ||
|
||||
req.Priority != nil ||
|
||||
req.Status != "" ||
|
||||
req.Schedulable != nil ||
|
||||
req.GroupIDs != nil ||
|
||||
len(req.Credentials) > 0 ||
|
||||
len(req.Extra) > 0
|
||||
@@ -671,6 +678,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
Status: req.Status,
|
||||
Schedulable: req.Schedulable,
|
||||
GroupIDs: req.GroupIDs,
|
||||
Credentials: req.Credentials,
|
||||
Extra: req.Extra,
|
||||
|
||||
@@ -2,6 +2,7 @@ package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
@@ -67,6 +68,12 @@ func (h *GroupHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
platform := c.Query("platform")
|
||||
status := c.Query("status")
|
||||
search := c.Query("search")
|
||||
// 标准化和验证 search 参数
|
||||
search = strings.TrimSpace(search)
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
}
|
||||
isExclusiveStr := c.Query("is_exclusive")
|
||||
|
||||
var isExclusive *bool
|
||||
@@ -75,7 +82,7 @@ func (h *GroupHandler) List(c *gin.Context) {
|
||||
isExclusive = &val
|
||||
}
|
||||
|
||||
groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, isExclusive)
|
||||
groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, search, isExclusive)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
209
backend/internal/handler/admin/promo_handler.go
Normal file
209
backend/internal/handler/admin/promo_handler.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// PromoHandler handles admin promo code management
|
||||
type PromoHandler struct {
|
||||
promoService *service.PromoService
|
||||
}
|
||||
|
||||
// NewPromoHandler creates a new admin promo handler
|
||||
func NewPromoHandler(promoService *service.PromoService) *PromoHandler {
|
||||
return &PromoHandler{
|
||||
promoService: promoService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreatePromoCodeRequest represents create promo code request
|
||||
type CreatePromoCodeRequest struct {
|
||||
Code string `json:"code"` // 可选,为空则自动生成
|
||||
BonusAmount float64 `json:"bonus_amount" binding:"required,min=0"` // 赠送余额
|
||||
MaxUses int `json:"max_uses" binding:"min=0"` // 最大使用次数,0=无限
|
||||
ExpiresAt *int64 `json:"expires_at"` // 过期时间戳(秒)
|
||||
Notes string `json:"notes"` // 备注
|
||||
}
|
||||
|
||||
// UpdatePromoCodeRequest represents update promo code request
|
||||
type UpdatePromoCodeRequest struct {
|
||||
Code *string `json:"code"`
|
||||
BonusAmount *float64 `json:"bonus_amount" binding:"omitempty,min=0"`
|
||||
MaxUses *int `json:"max_uses" binding:"omitempty,min=0"`
|
||||
Status *string `json:"status" binding:"omitempty,oneof=active disabled"`
|
||||
ExpiresAt *int64 `json:"expires_at"`
|
||||
Notes *string `json:"notes"`
|
||||
}
|
||||
|
||||
// List handles listing all promo codes with pagination
|
||||
// GET /api/v1/admin/promo-codes
|
||||
func (h *PromoHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
status := c.Query("status")
|
||||
search := strings.TrimSpace(c.Query("search"))
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
}
|
||||
|
||||
params := pagination.PaginationParams{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
}
|
||||
|
||||
codes, paginationResult, err := h.promoService.List(c.Request.Context(), params, status, search)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.PromoCode, 0, len(codes))
|
||||
for i := range codes {
|
||||
out = append(out, *dto.PromoCodeFromService(&codes[i]))
|
||||
}
|
||||
response.Paginated(c, out, paginationResult.Total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetByID handles getting a promo code by ID
|
||||
// GET /api/v1/admin/promo-codes/:id
|
||||
func (h *PromoHandler) GetByID(c *gin.Context) {
|
||||
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid promo code ID")
|
||||
return
|
||||
}
|
||||
|
||||
code, err := h.promoService.GetByID(c.Request.Context(), codeID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.PromoCodeFromService(code))
|
||||
}
|
||||
|
||||
// Create handles creating a new promo code
|
||||
// POST /api/v1/admin/promo-codes
|
||||
func (h *PromoHandler) Create(c *gin.Context) {
|
||||
var req CreatePromoCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
input := &service.CreatePromoCodeInput{
|
||||
Code: req.Code,
|
||||
BonusAmount: req.BonusAmount,
|
||||
MaxUses: req.MaxUses,
|
||||
Notes: req.Notes,
|
||||
}
|
||||
|
||||
if req.ExpiresAt != nil {
|
||||
t := time.Unix(*req.ExpiresAt, 0)
|
||||
input.ExpiresAt = &t
|
||||
}
|
||||
|
||||
code, err := h.promoService.Create(c.Request.Context(), input)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.PromoCodeFromService(code))
|
||||
}
|
||||
|
||||
// Update handles updating a promo code
|
||||
// PUT /api/v1/admin/promo-codes/:id
|
||||
func (h *PromoHandler) Update(c *gin.Context) {
|
||||
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid promo code ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdatePromoCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
input := &service.UpdatePromoCodeInput{
|
||||
Code: req.Code,
|
||||
BonusAmount: req.BonusAmount,
|
||||
MaxUses: req.MaxUses,
|
||||
Status: req.Status,
|
||||
Notes: req.Notes,
|
||||
}
|
||||
|
||||
if req.ExpiresAt != nil {
|
||||
if *req.ExpiresAt == 0 {
|
||||
// 0 表示清除过期时间
|
||||
input.ExpiresAt = nil
|
||||
} else {
|
||||
t := time.Unix(*req.ExpiresAt, 0)
|
||||
input.ExpiresAt = &t
|
||||
}
|
||||
}
|
||||
|
||||
code, err := h.promoService.Update(c.Request.Context(), codeID, input)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.PromoCodeFromService(code))
|
||||
}
|
||||
|
||||
// Delete handles deleting a promo code
|
||||
// DELETE /api/v1/admin/promo-codes/:id
|
||||
func (h *PromoHandler) Delete(c *gin.Context) {
|
||||
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid promo code ID")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.promoService.Delete(c.Request.Context(), codeID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Promo code deleted successfully"})
|
||||
}
|
||||
|
||||
// GetUsages handles getting usage records for a promo code
|
||||
// GET /api/v1/admin/promo-codes/:id/usages
|
||||
func (h *PromoHandler) GetUsages(c *gin.Context) {
|
||||
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid promo code ID")
|
||||
return
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
params := pagination.PaginationParams{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
}
|
||||
|
||||
usages, paginationResult, err := h.promoService.ListUsages(c.Request.Context(), codeID, params)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.PromoCodeUsage, 0, len(usages))
|
||||
for i := range usages {
|
||||
out = append(out, *dto.PromoCodeUsageFromService(&usages[i]))
|
||||
}
|
||||
response.Paginated(c, out, paginationResult.Total, page, pageSize)
|
||||
}
|
||||
@@ -51,6 +51,11 @@ func (h *ProxyHandler) List(c *gin.Context) {
|
||||
protocol := c.Query("protocol")
|
||||
status := c.Query("status")
|
||||
search := c.Query("search")
|
||||
// 标准化和验证 search 参数
|
||||
search = strings.TrimSpace(search)
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
}
|
||||
|
||||
proxies, total, err := h.adminService.ListProxiesWithAccountCount(c.Request.Context(), page, pageSize, protocol, status, search)
|
||||
if err != nil {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/csv"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
@@ -41,6 +42,11 @@ func (h *RedeemHandler) List(c *gin.Context) {
|
||||
codeType := c.Query("type")
|
||||
status := c.Query("status")
|
||||
search := c.Query("search")
|
||||
// 标准化和验证 search 参数
|
||||
search = strings.TrimSpace(search)
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
}
|
||||
|
||||
codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search)
|
||||
if err != nil {
|
||||
|
||||
@@ -2,8 +2,10 @@ package admin
|
||||
|
||||
import (
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
@@ -38,33 +40,37 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
SMTPHost: settings.SMTPHost,
|
||||
SMTPPort: settings.SMTPPort,
|
||||
SMTPUsername: settings.SMTPUsername,
|
||||
SMTPPasswordConfigured: settings.SMTPPasswordConfigured,
|
||||
SMTPFrom: settings.SMTPFrom,
|
||||
SMTPFromName: settings.SMTPFromName,
|
||||
SMTPUseTLS: settings.SMTPUseTLS,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
DefaultBalance: settings.DefaultBalance,
|
||||
EnableModelFallback: settings.EnableModelFallback,
|
||||
FallbackModelAnthropic: settings.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: settings.FallbackModelOpenAI,
|
||||
FallbackModelGemini: settings.FallbackModelGemini,
|
||||
FallbackModelAntigravity: settings.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: settings.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: settings.IdentityPatchPrompt,
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
SMTPHost: settings.SMTPHost,
|
||||
SMTPPort: settings.SMTPPort,
|
||||
SMTPUsername: settings.SMTPUsername,
|
||||
SMTPPasswordConfigured: settings.SMTPPasswordConfigured,
|
||||
SMTPFrom: settings.SMTPFrom,
|
||||
SMTPFromName: settings.SMTPFromName,
|
||||
SMTPUseTLS: settings.SMTPUseTLS,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured,
|
||||
LinuxDoConnectEnabled: settings.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: settings.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured,
|
||||
LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
DefaultBalance: settings.DefaultBalance,
|
||||
EnableModelFallback: settings.EnableModelFallback,
|
||||
FallbackModelAnthropic: settings.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: settings.FallbackModelOpenAI,
|
||||
FallbackModelGemini: settings.FallbackModelGemini,
|
||||
FallbackModelAntigravity: settings.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: settings.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: settings.IdentityPatchPrompt,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -88,6 +94,12 @@ type UpdateSettingsRequest struct {
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
TurnstileSecretKey string `json:"turnstile_secret_key"`
|
||||
|
||||
// LinuxDo Connect OAuth 登录(终端用户 SSO)
|
||||
LinuxDoConnectEnabled bool `json:"linuxdo_connect_enabled"`
|
||||
LinuxDoConnectClientID string `json:"linuxdo_connect_client_id"`
|
||||
LinuxDoConnectClientSecret string `json:"linuxdo_connect_client_secret"`
|
||||
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
|
||||
|
||||
// OEM设置
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
@@ -165,34 +177,67 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// LinuxDo Connect 参数验证
|
||||
if req.LinuxDoConnectEnabled {
|
||||
req.LinuxDoConnectClientID = strings.TrimSpace(req.LinuxDoConnectClientID)
|
||||
req.LinuxDoConnectClientSecret = strings.TrimSpace(req.LinuxDoConnectClientSecret)
|
||||
req.LinuxDoConnectRedirectURL = strings.TrimSpace(req.LinuxDoConnectRedirectURL)
|
||||
|
||||
if req.LinuxDoConnectClientID == "" {
|
||||
response.BadRequest(c, "LinuxDo Client ID is required when enabled")
|
||||
return
|
||||
}
|
||||
if req.LinuxDoConnectRedirectURL == "" {
|
||||
response.BadRequest(c, "LinuxDo Redirect URL is required when enabled")
|
||||
return
|
||||
}
|
||||
if err := config.ValidateAbsoluteHTTPURL(req.LinuxDoConnectRedirectURL); err != nil {
|
||||
response.BadRequest(c, "LinuxDo Redirect URL must be an absolute http(s) URL")
|
||||
return
|
||||
}
|
||||
|
||||
// 如果未提供 client_secret,则保留现有值(如有)。
|
||||
if req.LinuxDoConnectClientSecret == "" {
|
||||
if previousSettings.LinuxDoConnectClientSecret == "" {
|
||||
response.BadRequest(c, "LinuxDo Client Secret is required when enabled")
|
||||
return
|
||||
}
|
||||
req.LinuxDoConnectClientSecret = previousSettings.LinuxDoConnectClientSecret
|
||||
}
|
||||
}
|
||||
|
||||
settings := &service.SystemSettings{
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
SMTPPassword: req.SMTPPassword,
|
||||
SMTPFrom: req.SMTPFrom,
|
||||
SMTPFromName: req.SMTPFromName,
|
||||
SMTPUseTLS: req.SMTPUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
APIBaseURL: req.APIBaseURL,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocURL: req.DocURL,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||
FallbackModelGemini: req.FallbackModelGemini,
|
||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: req.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
SMTPPassword: req.SMTPPassword,
|
||||
SMTPFrom: req.SMTPFrom,
|
||||
SMTPFromName: req.SMTPFromName,
|
||||
SMTPUseTLS: req.SMTPUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
|
||||
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
APIBaseURL: req.APIBaseURL,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocURL: req.DocURL,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||
FallbackModelGemini: req.FallbackModelGemini,
|
||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: req.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
}
|
||||
|
||||
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
|
||||
@@ -210,33 +255,37 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: updatedSettings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
||||
SMTPHost: updatedSettings.SMTPHost,
|
||||
SMTPPort: updatedSettings.SMTPPort,
|
||||
SMTPUsername: updatedSettings.SMTPUsername,
|
||||
SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured,
|
||||
SMTPFrom: updatedSettings.SMTPFrom,
|
||||
SMTPFromName: updatedSettings.SMTPFromName,
|
||||
SMTPUseTLS: updatedSettings.SMTPUseTLS,
|
||||
TurnstileEnabled: updatedSettings.TurnstileEnabled,
|
||||
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
|
||||
TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured,
|
||||
SiteName: updatedSettings.SiteName,
|
||||
SiteLogo: updatedSettings.SiteLogo,
|
||||
SiteSubtitle: updatedSettings.SiteSubtitle,
|
||||
APIBaseURL: updatedSettings.APIBaseURL,
|
||||
ContactInfo: updatedSettings.ContactInfo,
|
||||
DocURL: updatedSettings.DocURL,
|
||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||
DefaultBalance: updatedSettings.DefaultBalance,
|
||||
EnableModelFallback: updatedSettings.EnableModelFallback,
|
||||
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
|
||||
FallbackModelGemini: updatedSettings.FallbackModelGemini,
|
||||
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
|
||||
RegistrationEnabled: updatedSettings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
||||
SMTPHost: updatedSettings.SMTPHost,
|
||||
SMTPPort: updatedSettings.SMTPPort,
|
||||
SMTPUsername: updatedSettings.SMTPUsername,
|
||||
SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured,
|
||||
SMTPFrom: updatedSettings.SMTPFrom,
|
||||
SMTPFromName: updatedSettings.SMTPFromName,
|
||||
SMTPUseTLS: updatedSettings.SMTPUseTLS,
|
||||
TurnstileEnabled: updatedSettings.TurnstileEnabled,
|
||||
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
|
||||
TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured,
|
||||
LinuxDoConnectEnabled: updatedSettings.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured,
|
||||
LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL,
|
||||
SiteName: updatedSettings.SiteName,
|
||||
SiteLogo: updatedSettings.SiteLogo,
|
||||
SiteSubtitle: updatedSettings.SiteSubtitle,
|
||||
APIBaseURL: updatedSettings.APIBaseURL,
|
||||
ContactInfo: updatedSettings.ContactInfo,
|
||||
DocURL: updatedSettings.DocURL,
|
||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||
DefaultBalance: updatedSettings.DefaultBalance,
|
||||
EnableModelFallback: updatedSettings.EnableModelFallback,
|
||||
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
|
||||
FallbackModelGemini: updatedSettings.FallbackModelGemini,
|
||||
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -298,6 +347,18 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if req.TurnstileSecretKey != "" {
|
||||
changed = append(changed, "turnstile_secret_key")
|
||||
}
|
||||
if before.LinuxDoConnectEnabled != after.LinuxDoConnectEnabled {
|
||||
changed = append(changed, "linuxdo_connect_enabled")
|
||||
}
|
||||
if before.LinuxDoConnectClientID != after.LinuxDoConnectClientID {
|
||||
changed = append(changed, "linuxdo_connect_client_id")
|
||||
}
|
||||
if req.LinuxDoConnectClientSecret != "" {
|
||||
changed = append(changed, "linuxdo_connect_client_secret")
|
||||
}
|
||||
if before.LinuxDoConnectRedirectURL != after.LinuxDoConnectRedirectURL {
|
||||
changed = append(changed, "linuxdo_connect_redirect_url")
|
||||
}
|
||||
if before.SiteName != after.SiteName {
|
||||
changed = append(changed, "site_name")
|
||||
}
|
||||
@@ -337,6 +398,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.FallbackModelAntigravity != after.FallbackModelAntigravity {
|
||||
changed = append(changed, "fallback_model_antigravity")
|
||||
}
|
||||
if before.EnableIdentityPatch != after.EnableIdentityPatch {
|
||||
changed = append(changed, "enable_identity_patch")
|
||||
}
|
||||
if before.IdentityPatchPrompt != after.IdentityPatchPrompt {
|
||||
changed = append(changed, "identity_patch_prompt")
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
@@ -63,10 +64,17 @@ type UpdateBalanceRequest struct {
|
||||
func (h *UserHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
|
||||
search := c.Query("search")
|
||||
// 标准化和验证 search 参数
|
||||
search = strings.TrimSpace(search)
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
}
|
||||
|
||||
filters := service.UserListFilters{
|
||||
Status: c.Query("status"),
|
||||
Role: c.Query("role"),
|
||||
Search: c.Query("search"),
|
||||
Search: search,
|
||||
Attributes: parseAttributeFilters(c),
|
||||
}
|
||||
|
||||
|
||||
@@ -27,16 +27,20 @@ func NewAPIKeyHandler(apiKeyService *service.APIKeyService) *APIKeyHandler {
|
||||
|
||||
// CreateAPIKeyRequest represents the create API key request payload
|
||||
type CreateAPIKeyRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
GroupID *int64 `json:"group_id"` // nullable
|
||||
CustomKey *string `json:"custom_key"` // 可选的自定义key
|
||||
Name string `json:"name" binding:"required"`
|
||||
GroupID *int64 `json:"group_id"` // nullable
|
||||
CustomKey *string `json:"custom_key"` // 可选的自定义key
|
||||
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
|
||||
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
|
||||
}
|
||||
|
||||
// UpdateAPIKeyRequest represents the update API key request payload
|
||||
type UpdateAPIKeyRequest struct {
|
||||
Name string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
Name string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
|
||||
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
|
||||
}
|
||||
|
||||
// List handles listing user's API keys with pagination
|
||||
@@ -110,9 +114,11 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
|
||||
}
|
||||
|
||||
svcReq := service.CreateAPIKeyRequest{
|
||||
Name: req.Name,
|
||||
GroupID: req.GroupID,
|
||||
CustomKey: req.CustomKey,
|
||||
Name: req.Name,
|
||||
GroupID: req.GroupID,
|
||||
CustomKey: req.CustomKey,
|
||||
IPWhitelist: req.IPWhitelist,
|
||||
IPBlacklist: req.IPBlacklist,
|
||||
}
|
||||
key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq)
|
||||
if err != nil {
|
||||
@@ -144,7 +150,10 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
svcReq := service.UpdateAPIKeyRequest{}
|
||||
svcReq := service.UpdateAPIKeyRequest{
|
||||
IPWhitelist: req.IPWhitelist,
|
||||
IPBlacklist: req.IPBlacklist,
|
||||
}
|
||||
if req.Name != "" {
|
||||
svcReq.Name = &req.Name
|
||||
}
|
||||
|
||||
@@ -12,17 +12,21 @@ import (
|
||||
|
||||
// AuthHandler handles authentication-related requests
|
||||
type AuthHandler struct {
|
||||
cfg *config.Config
|
||||
authService *service.AuthService
|
||||
userService *service.UserService
|
||||
cfg *config.Config
|
||||
authService *service.AuthService
|
||||
userService *service.UserService
|
||||
settingSvc *service.SettingService
|
||||
promoService *service.PromoService
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new AuthHandler
|
||||
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService) *AuthHandler {
|
||||
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
cfg: cfg,
|
||||
authService: authService,
|
||||
userService: userService,
|
||||
cfg: cfg,
|
||||
authService: authService,
|
||||
userService: userService,
|
||||
settingSvc: settingService,
|
||||
promoService: promoService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,6 +36,7 @@ type RegisterRequest struct {
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
VerifyCode string `json:"verify_code"`
|
||||
TurnstileToken string `json:"turnstile_token"`
|
||||
PromoCode string `json:"promo_code"` // 注册优惠码
|
||||
}
|
||||
|
||||
// SendVerifyCodeRequest 发送验证码请求
|
||||
@@ -77,7 +82,7 @@ func (h *AuthHandler) Register(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode)
|
||||
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -172,3 +177,63 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
|
||||
|
||||
response.Success(c, UserResponse{User: dto.UserFromService(user), RunMode: runMode})
|
||||
}
|
||||
|
||||
// ValidatePromoCodeRequest 验证优惠码请求
|
||||
type ValidatePromoCodeRequest struct {
|
||||
Code string `json:"code" binding:"required"`
|
||||
}
|
||||
|
||||
// ValidatePromoCodeResponse 验证优惠码响应
|
||||
type ValidatePromoCodeResponse struct {
|
||||
Valid bool `json:"valid"`
|
||||
BonusAmount float64 `json:"bonus_amount,omitempty"`
|
||||
ErrorCode string `json:"error_code,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// ValidatePromoCode 验证优惠码(公开接口,注册前调用)
|
||||
// POST /api/v1/auth/validate-promo-code
|
||||
func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
|
||||
var req ValidatePromoCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
promoCode, err := h.promoService.ValidatePromoCode(c.Request.Context(), req.Code)
|
||||
if err != nil {
|
||||
// 根据错误类型返回对应的错误码
|
||||
errorCode := "PROMO_CODE_INVALID"
|
||||
switch err {
|
||||
case service.ErrPromoCodeNotFound:
|
||||
errorCode = "PROMO_CODE_NOT_FOUND"
|
||||
case service.ErrPromoCodeExpired:
|
||||
errorCode = "PROMO_CODE_EXPIRED"
|
||||
case service.ErrPromoCodeDisabled:
|
||||
errorCode = "PROMO_CODE_DISABLED"
|
||||
case service.ErrPromoCodeMaxUsed:
|
||||
errorCode = "PROMO_CODE_MAX_USED"
|
||||
case service.ErrPromoCodeAlreadyUsed:
|
||||
errorCode = "PROMO_CODE_ALREADY_USED"
|
||||
}
|
||||
|
||||
response.Success(c, ValidatePromoCodeResponse{
|
||||
Valid: false,
|
||||
ErrorCode: errorCode,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if promoCode == nil {
|
||||
response.Success(c, ValidatePromoCodeResponse{
|
||||
Valid: false,
|
||||
ErrorCode: "PROMO_CODE_INVALID",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, ValidatePromoCodeResponse{
|
||||
Valid: true,
|
||||
BonusAmount: promoCode.BonusAmount,
|
||||
})
|
||||
}
|
||||
|
||||
679
backend/internal/handler/auth_linuxdo_oauth.go
Normal file
679
backend/internal/handler/auth_linuxdo_oauth.go
Normal file
@@ -0,0 +1,679 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/imroc/req/v3"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
linuxDoOAuthCookiePath = "/api/v1/auth/oauth/linuxdo"
|
||||
linuxDoOAuthStateCookieName = "linuxdo_oauth_state"
|
||||
linuxDoOAuthVerifierCookie = "linuxdo_oauth_verifier"
|
||||
linuxDoOAuthRedirectCookie = "linuxdo_oauth_redirect"
|
||||
linuxDoOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes
|
||||
linuxDoOAuthDefaultRedirectTo = "/dashboard"
|
||||
linuxDoOAuthDefaultFrontendCB = "/auth/linuxdo/callback"
|
||||
|
||||
linuxDoOAuthMaxRedirectLen = 2048
|
||||
linuxDoOAuthMaxFragmentValueLen = 512
|
||||
linuxDoOAuthMaxSubjectLen = 64 - len("linuxdo-")
|
||||
)
|
||||
|
||||
type linuxDoTokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
}
|
||||
|
||||
type linuxDoTokenExchangeError struct {
|
||||
StatusCode int
|
||||
ProviderError string
|
||||
ProviderDescription string
|
||||
Body string
|
||||
}
|
||||
|
||||
func (e *linuxDoTokenExchangeError) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
parts := []string{fmt.Sprintf("token exchange status=%d", e.StatusCode)}
|
||||
if strings.TrimSpace(e.ProviderError) != "" {
|
||||
parts = append(parts, "error="+strings.TrimSpace(e.ProviderError))
|
||||
}
|
||||
if strings.TrimSpace(e.ProviderDescription) != "" {
|
||||
parts = append(parts, "error_description="+strings.TrimSpace(e.ProviderDescription))
|
||||
}
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
// LinuxDoOAuthStart 启动 LinuxDo Connect OAuth 登录流程。
|
||||
// GET /api/v1/auth/oauth/linuxdo/start?redirect=/dashboard
|
||||
func (h *AuthHandler) LinuxDoOAuthStart(c *gin.Context) {
|
||||
cfg, err := h.getLinuxDoOAuthConfig(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
state, err := oauth.GenerateState()
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err))
|
||||
return
|
||||
}
|
||||
|
||||
redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect"))
|
||||
if redirectTo == "" {
|
||||
redirectTo = linuxDoOAuthDefaultRedirectTo
|
||||
}
|
||||
|
||||
secureCookie := isRequestHTTPS(c)
|
||||
setCookie(c, linuxDoOAuthStateCookieName, encodeCookieValue(state), linuxDoOAuthCookieMaxAgeSec, secureCookie)
|
||||
setCookie(c, linuxDoOAuthRedirectCookie, encodeCookieValue(redirectTo), linuxDoOAuthCookieMaxAgeSec, secureCookie)
|
||||
|
||||
codeChallenge := ""
|
||||
if cfg.UsePKCE {
|
||||
verifier, err := oauth.GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err))
|
||||
return
|
||||
}
|
||||
codeChallenge = oauth.GenerateCodeChallenge(verifier)
|
||||
setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie)
|
||||
}
|
||||
|
||||
redirectURI := strings.TrimSpace(cfg.RedirectURL)
|
||||
if redirectURI == "" {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url not configured"))
|
||||
return
|
||||
}
|
||||
|
||||
authURL, err := buildLinuxDoAuthorizeURL(cfg, state, codeChallenge, redirectURI)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err))
|
||||
return
|
||||
}
|
||||
|
||||
c.Redirect(http.StatusFound, authURL)
|
||||
}
|
||||
|
||||
// LinuxDoOAuthCallback 处理 OAuth 回调:创建/登录用户,然后重定向到前端。
|
||||
// GET /api/v1/auth/oauth/linuxdo/callback?code=...&state=...
|
||||
func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
|
||||
cfg, cfgErr := h.getLinuxDoOAuthConfig(c.Request.Context())
|
||||
if cfgErr != nil {
|
||||
response.ErrorFrom(c, cfgErr)
|
||||
return
|
||||
}
|
||||
|
||||
frontendCallback := strings.TrimSpace(cfg.FrontendRedirectURL)
|
||||
if frontendCallback == "" {
|
||||
frontendCallback = linuxDoOAuthDefaultFrontendCB
|
||||
}
|
||||
|
||||
if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" {
|
||||
redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description"))
|
||||
return
|
||||
}
|
||||
|
||||
code := strings.TrimSpace(c.Query("code"))
|
||||
state := strings.TrimSpace(c.Query("state"))
|
||||
if code == "" || state == "" {
|
||||
redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "")
|
||||
return
|
||||
}
|
||||
|
||||
secureCookie := isRequestHTTPS(c)
|
||||
defer func() {
|
||||
clearCookie(c, linuxDoOAuthStateCookieName, secureCookie)
|
||||
clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie)
|
||||
clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie)
|
||||
}()
|
||||
|
||||
expectedState, err := readCookieDecoded(c, linuxDoOAuthStateCookieName)
|
||||
if err != nil || expectedState == "" || state != expectedState {
|
||||
redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "")
|
||||
return
|
||||
}
|
||||
|
||||
redirectTo, _ := readCookieDecoded(c, linuxDoOAuthRedirectCookie)
|
||||
redirectTo = sanitizeFrontendRedirectPath(redirectTo)
|
||||
if redirectTo == "" {
|
||||
redirectTo = linuxDoOAuthDefaultRedirectTo
|
||||
}
|
||||
|
||||
codeVerifier := ""
|
||||
if cfg.UsePKCE {
|
||||
codeVerifier, _ = readCookieDecoded(c, linuxDoOAuthVerifierCookie)
|
||||
if codeVerifier == "" {
|
||||
redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
redirectURI := strings.TrimSpace(cfg.RedirectURL)
|
||||
if redirectURI == "" {
|
||||
redirectOAuthError(c, frontendCallback, "config_error", "oauth redirect url not configured", "")
|
||||
return
|
||||
}
|
||||
|
||||
tokenResp, err := linuxDoExchangeCode(c.Request.Context(), cfg, code, redirectURI, codeVerifier)
|
||||
if err != nil {
|
||||
description := ""
|
||||
var exchangeErr *linuxDoTokenExchangeError
|
||||
if errors.As(err, &exchangeErr) && exchangeErr != nil {
|
||||
log.Printf(
|
||||
"[LinuxDo OAuth] token exchange failed: status=%d provider_error=%q provider_description=%q body=%s",
|
||||
exchangeErr.StatusCode,
|
||||
exchangeErr.ProviderError,
|
||||
exchangeErr.ProviderDescription,
|
||||
truncateLogValue(exchangeErr.Body, 2048),
|
||||
)
|
||||
description = exchangeErr.Error()
|
||||
} else {
|
||||
log.Printf("[LinuxDo OAuth] token exchange failed: %v", err)
|
||||
description = err.Error()
|
||||
}
|
||||
redirectOAuthError(c, frontendCallback, "token_exchange_failed", "failed to exchange oauth code", singleLine(description))
|
||||
return
|
||||
}
|
||||
|
||||
email, username, subject, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp)
|
||||
if err != nil {
|
||||
log.Printf("[LinuxDo OAuth] userinfo fetch failed: %v", err)
|
||||
redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch user info", "")
|
||||
return
|
||||
}
|
||||
|
||||
// 安全考虑:不要把第三方返回的 email 直接映射到本地账号(可能与本地邮箱用户冲突导致账号被接管)。
|
||||
// 统一使用基于 subject 的稳定合成邮箱来做账号绑定。
|
||||
if subject != "" {
|
||||
email = linuxDoSyntheticEmail(subject)
|
||||
}
|
||||
|
||||
jwtToken, _, err := h.authService.LoginOrRegisterOAuth(c.Request.Context(), email, username)
|
||||
if err != nil {
|
||||
// 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。
|
||||
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
|
||||
return
|
||||
}
|
||||
|
||||
fragment := url.Values{}
|
||||
fragment.Set("access_token", jwtToken)
|
||||
fragment.Set("token_type", "Bearer")
|
||||
fragment.Set("redirect", redirectTo)
|
||||
redirectWithFragment(c, frontendCallback, fragment)
|
||||
}
|
||||
|
||||
func (h *AuthHandler) getLinuxDoOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) {
|
||||
if h != nil && h.settingSvc != nil {
|
||||
return h.settingSvc.GetLinuxDoConnectOAuthConfig(ctx)
|
||||
}
|
||||
if h == nil || h.cfg == nil {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded")
|
||||
}
|
||||
if !h.cfg.LinuxDo.Enabled {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
|
||||
}
|
||||
return h.cfg.LinuxDo, nil
|
||||
}
|
||||
|
||||
func linuxDoExchangeCode(
|
||||
ctx context.Context,
|
||||
cfg config.LinuxDoConnectConfig,
|
||||
code string,
|
||||
redirectURI string,
|
||||
codeVerifier string,
|
||||
) (*linuxDoTokenResponse, error) {
|
||||
client := req.C().SetTimeout(30 * time.Second)
|
||||
|
||||
form := url.Values{}
|
||||
form.Set("grant_type", "authorization_code")
|
||||
form.Set("client_id", cfg.ClientID)
|
||||
form.Set("code", code)
|
||||
form.Set("redirect_uri", redirectURI)
|
||||
if cfg.UsePKCE {
|
||||
form.Set("code_verifier", codeVerifier)
|
||||
}
|
||||
|
||||
r := client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Accept", "application/json")
|
||||
|
||||
switch strings.ToLower(strings.TrimSpace(cfg.TokenAuthMethod)) {
|
||||
case "", "client_secret_post":
|
||||
form.Set("client_secret", cfg.ClientSecret)
|
||||
case "client_secret_basic":
|
||||
r.SetBasicAuth(cfg.ClientID, cfg.ClientSecret)
|
||||
case "none":
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported token_auth_method: %s", cfg.TokenAuthMethod)
|
||||
}
|
||||
|
||||
resp, err := r.SetFormDataFromValues(form).Post(cfg.TokenURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request token: %w", err)
|
||||
}
|
||||
body := strings.TrimSpace(resp.String())
|
||||
if !resp.IsSuccessState() {
|
||||
providerErr, providerDesc := parseOAuthProviderError(body)
|
||||
return nil, &linuxDoTokenExchangeError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ProviderError: providerErr,
|
||||
ProviderDescription: providerDesc,
|
||||
Body: body,
|
||||
}
|
||||
}
|
||||
|
||||
tokenResp, ok := parseLinuxDoTokenResponse(body)
|
||||
if !ok || strings.TrimSpace(tokenResp.AccessToken) == "" {
|
||||
return nil, &linuxDoTokenExchangeError{
|
||||
StatusCode: resp.StatusCode,
|
||||
Body: body,
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(tokenResp.TokenType) == "" {
|
||||
tokenResp.TokenType = "Bearer"
|
||||
}
|
||||
return tokenResp, nil
|
||||
}
|
||||
|
||||
func linuxDoFetchUserInfo(
|
||||
ctx context.Context,
|
||||
cfg config.LinuxDoConnectConfig,
|
||||
token *linuxDoTokenResponse,
|
||||
) (email string, username string, subject string, err error) {
|
||||
client := req.C().SetTimeout(30 * time.Second)
|
||||
authorization, err := buildBearerAuthorization(token.TokenType, token.AccessToken)
|
||||
if err != nil {
|
||||
return "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Accept", "application/json").
|
||||
SetHeader("Authorization", authorization).
|
||||
Get(cfg.UserInfoURL)
|
||||
if err != nil {
|
||||
return "", "", "", fmt.Errorf("request userinfo: %w", err)
|
||||
}
|
||||
if !resp.IsSuccessState() {
|
||||
return "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return linuxDoParseUserInfo(resp.String(), cfg)
|
||||
}
|
||||
|
||||
func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, err error) {
|
||||
email = firstNonEmpty(
|
||||
getGJSON(body, cfg.UserInfoEmailPath),
|
||||
getGJSON(body, "email"),
|
||||
getGJSON(body, "user.email"),
|
||||
getGJSON(body, "data.email"),
|
||||
getGJSON(body, "attributes.email"),
|
||||
)
|
||||
username = firstNonEmpty(
|
||||
getGJSON(body, cfg.UserInfoUsernamePath),
|
||||
getGJSON(body, "username"),
|
||||
getGJSON(body, "preferred_username"),
|
||||
getGJSON(body, "name"),
|
||||
getGJSON(body, "user.username"),
|
||||
getGJSON(body, "user.name"),
|
||||
)
|
||||
subject = firstNonEmpty(
|
||||
getGJSON(body, cfg.UserInfoIDPath),
|
||||
getGJSON(body, "sub"),
|
||||
getGJSON(body, "id"),
|
||||
getGJSON(body, "user_id"),
|
||||
getGJSON(body, "uid"),
|
||||
getGJSON(body, "user.id"),
|
||||
)
|
||||
|
||||
subject = strings.TrimSpace(subject)
|
||||
if subject == "" {
|
||||
return "", "", "", errors.New("userinfo missing id field")
|
||||
}
|
||||
if !isSafeLinuxDoSubject(subject) {
|
||||
return "", "", "", errors.New("userinfo returned invalid id field")
|
||||
}
|
||||
|
||||
email = strings.TrimSpace(email)
|
||||
if email == "" {
|
||||
// LinuxDo Connect 的 userinfo 可能不提供 email。为兼容现有用户模型(email 必填且唯一),使用稳定的合成邮箱。
|
||||
email = linuxDoSyntheticEmail(subject)
|
||||
}
|
||||
|
||||
username = strings.TrimSpace(username)
|
||||
if username == "" {
|
||||
username = "linuxdo_" + subject
|
||||
}
|
||||
|
||||
return email, username, subject, nil
|
||||
}
|
||||
|
||||
func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, codeChallenge string, redirectURI string) (string, error) {
|
||||
u, err := url.Parse(cfg.AuthorizeURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse authorize_url: %w", err)
|
||||
}
|
||||
|
||||
q := u.Query()
|
||||
q.Set("response_type", "code")
|
||||
q.Set("client_id", cfg.ClientID)
|
||||
q.Set("redirect_uri", redirectURI)
|
||||
if strings.TrimSpace(cfg.Scopes) != "" {
|
||||
q.Set("scope", cfg.Scopes)
|
||||
}
|
||||
q.Set("state", state)
|
||||
if cfg.UsePKCE {
|
||||
q.Set("code_challenge", codeChallenge)
|
||||
q.Set("code_challenge_method", "S256")
|
||||
}
|
||||
|
||||
u.RawQuery = q.Encode()
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
func redirectOAuthError(c *gin.Context, frontendCallback string, code string, message string, description string) {
|
||||
fragment := url.Values{}
|
||||
fragment.Set("error", truncateFragmentValue(code))
|
||||
if strings.TrimSpace(message) != "" {
|
||||
fragment.Set("error_message", truncateFragmentValue(message))
|
||||
}
|
||||
if strings.TrimSpace(description) != "" {
|
||||
fragment.Set("error_description", truncateFragmentValue(description))
|
||||
}
|
||||
redirectWithFragment(c, frontendCallback, fragment)
|
||||
}
|
||||
|
||||
func redirectWithFragment(c *gin.Context, frontendCallback string, fragment url.Values) {
|
||||
u, err := url.Parse(frontendCallback)
|
||||
if err != nil {
|
||||
// 兜底:尽力跳转到默认页面,避免卡死在回调页。
|
||||
c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo)
|
||||
return
|
||||
}
|
||||
if u.Scheme != "" && !strings.EqualFold(u.Scheme, "http") && !strings.EqualFold(u.Scheme, "https") {
|
||||
c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo)
|
||||
return
|
||||
}
|
||||
u.Fragment = fragment.Encode()
|
||||
c.Header("Cache-Control", "no-store")
|
||||
c.Header("Pragma", "no-cache")
|
||||
c.Redirect(http.StatusFound, u.String())
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, v := range values {
|
||||
v = strings.TrimSpace(v)
|
||||
if v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func parseOAuthProviderError(body string) (providerErr string, providerDesc string) {
|
||||
body = strings.TrimSpace(body)
|
||||
if body == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
providerErr = firstNonEmpty(
|
||||
getGJSON(body, "error"),
|
||||
getGJSON(body, "code"),
|
||||
getGJSON(body, "error.code"),
|
||||
)
|
||||
providerDesc = firstNonEmpty(
|
||||
getGJSON(body, "error_description"),
|
||||
getGJSON(body, "error.message"),
|
||||
getGJSON(body, "message"),
|
||||
getGJSON(body, "detail"),
|
||||
)
|
||||
|
||||
if providerErr != "" || providerDesc != "" {
|
||||
return providerErr, providerDesc
|
||||
}
|
||||
|
||||
values, err := url.ParseQuery(body)
|
||||
if err != nil {
|
||||
return "", ""
|
||||
}
|
||||
providerErr = firstNonEmpty(values.Get("error"), values.Get("code"))
|
||||
providerDesc = firstNonEmpty(values.Get("error_description"), values.Get("error_message"), values.Get("message"))
|
||||
return providerErr, providerDesc
|
||||
}
|
||||
|
||||
func parseLinuxDoTokenResponse(body string) (*linuxDoTokenResponse, bool) {
|
||||
body = strings.TrimSpace(body)
|
||||
if body == "" {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
accessToken := strings.TrimSpace(getGJSON(body, "access_token"))
|
||||
if accessToken != "" {
|
||||
tokenType := strings.TrimSpace(getGJSON(body, "token_type"))
|
||||
refreshToken := strings.TrimSpace(getGJSON(body, "refresh_token"))
|
||||
scope := strings.TrimSpace(getGJSON(body, "scope"))
|
||||
expiresIn := gjson.Get(body, "expires_in").Int()
|
||||
return &linuxDoTokenResponse{
|
||||
AccessToken: accessToken,
|
||||
TokenType: tokenType,
|
||||
ExpiresIn: expiresIn,
|
||||
RefreshToken: refreshToken,
|
||||
Scope: scope,
|
||||
}, true
|
||||
}
|
||||
|
||||
values, err := url.ParseQuery(body)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
accessToken = strings.TrimSpace(values.Get("access_token"))
|
||||
if accessToken == "" {
|
||||
return nil, false
|
||||
}
|
||||
expiresIn := int64(0)
|
||||
if raw := strings.TrimSpace(values.Get("expires_in")); raw != "" {
|
||||
if v, err := strconv.ParseInt(raw, 10, 64); err == nil {
|
||||
expiresIn = v
|
||||
}
|
||||
}
|
||||
return &linuxDoTokenResponse{
|
||||
AccessToken: accessToken,
|
||||
TokenType: strings.TrimSpace(values.Get("token_type")),
|
||||
ExpiresIn: expiresIn,
|
||||
RefreshToken: strings.TrimSpace(values.Get("refresh_token")),
|
||||
Scope: strings.TrimSpace(values.Get("scope")),
|
||||
}, true
|
||||
}
|
||||
|
||||
func getGJSON(body string, path string) string {
|
||||
path = strings.TrimSpace(path)
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
res := gjson.Get(body, path)
|
||||
if !res.Exists() {
|
||||
return ""
|
||||
}
|
||||
return res.String()
|
||||
}
|
||||
|
||||
func truncateLogValue(value string, maxLen int) string {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" || maxLen <= 0 {
|
||||
return ""
|
||||
}
|
||||
if len(value) <= maxLen {
|
||||
return value
|
||||
}
|
||||
value = value[:maxLen]
|
||||
for !utf8.ValidString(value) {
|
||||
value = value[:len(value)-1]
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func singleLine(value string) string {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return ""
|
||||
}
|
||||
return strings.Join(strings.Fields(value), " ")
|
||||
}
|
||||
|
||||
func sanitizeFrontendRedirectPath(path string) string {
|
||||
path = strings.TrimSpace(path)
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
if len(path) > linuxDoOAuthMaxRedirectLen {
|
||||
return ""
|
||||
}
|
||||
// 只允许同源相对路径(避免开放重定向)。
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
return ""
|
||||
}
|
||||
if strings.HasPrefix(path, "//") {
|
||||
return ""
|
||||
}
|
||||
if strings.Contains(path, "://") {
|
||||
return ""
|
||||
}
|
||||
if strings.ContainsAny(path, "\r\n") {
|
||||
return ""
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func isRequestHTTPS(c *gin.Context) bool {
|
||||
if c.Request.TLS != nil {
|
||||
return true
|
||||
}
|
||||
proto := strings.ToLower(strings.TrimSpace(c.GetHeader("X-Forwarded-Proto")))
|
||||
return proto == "https"
|
||||
}
|
||||
|
||||
func encodeCookieValue(value string) string {
|
||||
return base64.RawURLEncoding.EncodeToString([]byte(value))
|
||||
}
|
||||
|
||||
func decodeCookieValue(value string) (string, error) {
|
||||
raw, err := base64.RawURLEncoding.DecodeString(value)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(raw), nil
|
||||
}
|
||||
|
||||
func readCookieDecoded(c *gin.Context, name string) (string, error) {
|
||||
ck, err := c.Request.Cookie(name)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return decodeCookieValue(ck.Value)
|
||||
}
|
||||
|
||||
func setCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) {
|
||||
http.SetCookie(c.Writer, &http.Cookie{
|
||||
Name: name,
|
||||
Value: value,
|
||||
Path: linuxDoOAuthCookiePath,
|
||||
MaxAge: maxAgeSec,
|
||||
HttpOnly: true,
|
||||
Secure: secure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
}
|
||||
|
||||
func clearCookie(c *gin.Context, name string, secure bool) {
|
||||
http.SetCookie(c.Writer, &http.Cookie{
|
||||
Name: name,
|
||||
Value: "",
|
||||
Path: linuxDoOAuthCookiePath,
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
Secure: secure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
}
|
||||
|
||||
func truncateFragmentValue(value string) string {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return ""
|
||||
}
|
||||
if len(value) > linuxDoOAuthMaxFragmentValueLen {
|
||||
value = value[:linuxDoOAuthMaxFragmentValueLen]
|
||||
for !utf8.ValidString(value) {
|
||||
value = value[:len(value)-1]
|
||||
}
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func buildBearerAuthorization(tokenType, accessToken string) (string, error) {
|
||||
tokenType = strings.TrimSpace(tokenType)
|
||||
if tokenType == "" {
|
||||
tokenType = "Bearer"
|
||||
}
|
||||
if !strings.EqualFold(tokenType, "Bearer") {
|
||||
return "", fmt.Errorf("unsupported token_type: %s", tokenType)
|
||||
}
|
||||
|
||||
accessToken = strings.TrimSpace(accessToken)
|
||||
if accessToken == "" {
|
||||
return "", errors.New("missing access_token")
|
||||
}
|
||||
if strings.ContainsAny(accessToken, " \t\r\n") {
|
||||
return "", errors.New("access_token contains whitespace")
|
||||
}
|
||||
return "Bearer " + accessToken, nil
|
||||
}
|
||||
|
||||
func isSafeLinuxDoSubject(subject string) bool {
|
||||
subject = strings.TrimSpace(subject)
|
||||
if subject == "" || len(subject) > linuxDoOAuthMaxSubjectLen {
|
||||
return false
|
||||
}
|
||||
for _, r := range subject {
|
||||
switch {
|
||||
case r >= '0' && r <= '9':
|
||||
case r >= 'a' && r <= 'z':
|
||||
case r >= 'A' && r <= 'Z':
|
||||
case r == '_' || r == '-':
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func linuxDoSyntheticEmail(subject string) string {
|
||||
subject = strings.TrimSpace(subject)
|
||||
if subject == "" {
|
||||
return ""
|
||||
}
|
||||
return "linuxdo-" + subject + service.LinuxDoConnectSyntheticEmailDomain
|
||||
}
|
||||
108
backend/internal/handler/auth_linuxdo_oauth_test.go
Normal file
108
backend/internal/handler/auth_linuxdo_oauth_test.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSanitizeFrontendRedirectPath(t *testing.T) {
|
||||
require.Equal(t, "/dashboard", sanitizeFrontendRedirectPath("/dashboard"))
|
||||
require.Equal(t, "/dashboard", sanitizeFrontendRedirectPath(" /dashboard "))
|
||||
require.Equal(t, "", sanitizeFrontendRedirectPath("dashboard"))
|
||||
require.Equal(t, "", sanitizeFrontendRedirectPath("//evil.com"))
|
||||
require.Equal(t, "", sanitizeFrontendRedirectPath("https://evil.com"))
|
||||
require.Equal(t, "", sanitizeFrontendRedirectPath("/\nfoo"))
|
||||
|
||||
long := "/" + strings.Repeat("a", linuxDoOAuthMaxRedirectLen)
|
||||
require.Equal(t, "", sanitizeFrontendRedirectPath(long))
|
||||
}
|
||||
|
||||
func TestBuildBearerAuthorization(t *testing.T) {
|
||||
auth, err := buildBearerAuthorization("", "token123")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Bearer token123", auth)
|
||||
|
||||
auth, err = buildBearerAuthorization("bearer", "token123")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Bearer token123", auth)
|
||||
|
||||
_, err = buildBearerAuthorization("MAC", "token123")
|
||||
require.Error(t, err)
|
||||
|
||||
_, err = buildBearerAuthorization("Bearer", "token 123")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestLinuxDoParseUserInfoParsesIDAndUsername(t *testing.T) {
|
||||
cfg := config.LinuxDoConnectConfig{
|
||||
UserInfoURL: "https://connect.linux.do/api/user",
|
||||
}
|
||||
|
||||
email, username, subject, err := linuxDoParseUserInfo(`{"id":123,"username":"alice"}`, cfg)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "123", subject)
|
||||
require.Equal(t, "alice", username)
|
||||
require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email)
|
||||
}
|
||||
|
||||
func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) {
|
||||
cfg := config.LinuxDoConnectConfig{
|
||||
UserInfoURL: "https://connect.linux.do/api/user",
|
||||
}
|
||||
|
||||
email, username, subject, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "123", subject)
|
||||
require.Equal(t, "linuxdo_123", username)
|
||||
require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email)
|
||||
}
|
||||
|
||||
func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) {
|
||||
cfg := config.LinuxDoConnectConfig{
|
||||
UserInfoURL: "https://connect.linux.do/api/user",
|
||||
}
|
||||
|
||||
_, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg)
|
||||
require.Error(t, err)
|
||||
|
||||
tooLong := strings.Repeat("a", linuxDoOAuthMaxSubjectLen+1)
|
||||
_, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestParseOAuthProviderErrorJSON(t *testing.T) {
|
||||
code, desc := parseOAuthProviderError(`{"error":"invalid_client","error_description":"bad secret"}`)
|
||||
require.Equal(t, "invalid_client", code)
|
||||
require.Equal(t, "bad secret", desc)
|
||||
}
|
||||
|
||||
func TestParseOAuthProviderErrorForm(t *testing.T) {
|
||||
code, desc := parseOAuthProviderError("error=invalid_request&error_description=Missing+code_verifier")
|
||||
require.Equal(t, "invalid_request", code)
|
||||
require.Equal(t, "Missing code_verifier", desc)
|
||||
}
|
||||
|
||||
func TestParseLinuxDoTokenResponseJSON(t *testing.T) {
|
||||
token, ok := parseLinuxDoTokenResponse(`{"access_token":"t1","token_type":"Bearer","expires_in":3600,"scope":"user"}`)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "t1", token.AccessToken)
|
||||
require.Equal(t, "Bearer", token.TokenType)
|
||||
require.Equal(t, int64(3600), token.ExpiresIn)
|
||||
require.Equal(t, "user", token.Scope)
|
||||
}
|
||||
|
||||
func TestParseLinuxDoTokenResponseForm(t *testing.T) {
|
||||
token, ok := parseLinuxDoTokenResponse("access_token=t2&token_type=bearer&expires_in=60")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "t2", token.AccessToken)
|
||||
require.Equal(t, "bearer", token.TokenType)
|
||||
require.Equal(t, int64(60), token.ExpiresIn)
|
||||
}
|
||||
|
||||
func TestSingleLineStripsWhitespace(t *testing.T) {
|
||||
require.Equal(t, "hello world", singleLine("hello\r\nworld"))
|
||||
require.Equal(t, "", singleLine("\n\t\r"))
|
||||
}
|
||||
@@ -53,16 +53,18 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
|
||||
return nil
|
||||
}
|
||||
return &APIKey{
|
||||
ID: k.ID,
|
||||
UserID: k.UserID,
|
||||
Key: k.Key,
|
||||
Name: k.Name,
|
||||
GroupID: k.GroupID,
|
||||
Status: k.Status,
|
||||
CreatedAt: k.CreatedAt,
|
||||
UpdatedAt: k.UpdatedAt,
|
||||
User: UserFromServiceShallow(k.User),
|
||||
Group: GroupFromServiceShallow(k.Group),
|
||||
ID: k.ID,
|
||||
UserID: k.UserID,
|
||||
Key: k.Key,
|
||||
Name: k.Name,
|
||||
GroupID: k.GroupID,
|
||||
Status: k.Status,
|
||||
IPWhitelist: k.IPWhitelist,
|
||||
IPBlacklist: k.IPBlacklist,
|
||||
CreatedAt: k.CreatedAt,
|
||||
UpdatedAt: k.UpdatedAt,
|
||||
User: UserFromServiceShallow(k.User),
|
||||
Group: GroupFromServiceShallow(k.Group),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -250,11 +252,12 @@ func AccountSummaryFromService(a *service.Account) *AccountSummary {
|
||||
|
||||
// usageLogFromServiceBase is a helper that converts service UsageLog to DTO.
|
||||
// The account parameter allows caller to control what Account info is included.
|
||||
func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary) *UsageLog {
|
||||
// The includeIPAddress parameter controls whether to include the IP address (admin-only).
|
||||
func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, includeIPAddress bool) *UsageLog {
|
||||
if l == nil {
|
||||
return nil
|
||||
}
|
||||
return &UsageLog{
|
||||
result := &UsageLog{
|
||||
ID: l.ID,
|
||||
UserID: l.UserID,
|
||||
APIKeyID: l.APIKeyID,
|
||||
@@ -290,21 +293,26 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary) *Usag
|
||||
Group: GroupFromServiceShallow(l.Group),
|
||||
Subscription: UserSubscriptionFromService(l.Subscription),
|
||||
}
|
||||
// IP 地址仅对管理员可见
|
||||
if includeIPAddress {
|
||||
result.IPAddress = l.IPAddress
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// UsageLogFromService converts a service UsageLog to DTO for regular users.
|
||||
// It excludes Account details - users should not see account information.
|
||||
// It excludes Account details and IP address - users should not see these.
|
||||
func UsageLogFromService(l *service.UsageLog) *UsageLog {
|
||||
return usageLogFromServiceBase(l, nil)
|
||||
return usageLogFromServiceBase(l, nil, false)
|
||||
}
|
||||
|
||||
// UsageLogFromServiceAdmin converts a service UsageLog to DTO for admin users.
|
||||
// It includes minimal Account info (ID, Name only).
|
||||
// It includes minimal Account info (ID, Name only) and IP address.
|
||||
func UsageLogFromServiceAdmin(l *service.UsageLog) *UsageLog {
|
||||
if l == nil {
|
||||
return nil
|
||||
}
|
||||
return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account))
|
||||
return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account), true)
|
||||
}
|
||||
|
||||
func SettingFromService(s *service.Setting) *Setting {
|
||||
@@ -362,3 +370,35 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult
|
||||
Errors: r.Errors,
|
||||
}
|
||||
}
|
||||
|
||||
func PromoCodeFromService(pc *service.PromoCode) *PromoCode {
|
||||
if pc == nil {
|
||||
return nil
|
||||
}
|
||||
return &PromoCode{
|
||||
ID: pc.ID,
|
||||
Code: pc.Code,
|
||||
BonusAmount: pc.BonusAmount,
|
||||
MaxUses: pc.MaxUses,
|
||||
UsedCount: pc.UsedCount,
|
||||
Status: pc.Status,
|
||||
ExpiresAt: pc.ExpiresAt,
|
||||
Notes: pc.Notes,
|
||||
CreatedAt: pc.CreatedAt,
|
||||
UpdatedAt: pc.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func PromoCodeUsageFromService(u *service.PromoCodeUsage) *PromoCodeUsage {
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
return &PromoCodeUsage{
|
||||
ID: u.ID,
|
||||
PromoCodeID: u.PromoCodeID,
|
||||
UserID: u.UserID,
|
||||
BonusAmount: u.BonusAmount,
|
||||
UsedAt: u.UsedAt,
|
||||
User: UserFromServiceShallow(u.User),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,6 +17,11 @@ type SystemSettings struct {
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
TurnstileSecretKeyConfigured bool `json:"turnstile_secret_key_configured"`
|
||||
|
||||
LinuxDoConnectEnabled bool `json:"linuxdo_connect_enabled"`
|
||||
LinuxDoConnectClientID string `json:"linuxdo_connect_client_id"`
|
||||
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
|
||||
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
|
||||
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
@@ -50,5 +55,6 @@ type PublicSettings struct {
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
@@ -20,14 +20,16 @@ type User struct {
|
||||
}
|
||||
|
||||
type APIKey struct {
|
||||
ID int64 `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
Key string `json:"key"`
|
||||
Name string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
ID int64 `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
Key string `json:"key"`
|
||||
Name string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Status string `json:"status"`
|
||||
IPWhitelist []string `json:"ip_whitelist"`
|
||||
IPBlacklist []string `json:"ip_blacklist"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
@@ -187,6 +189,9 @@ type UsageLog struct {
|
||||
// User-Agent
|
||||
UserAgent *string `json:"user_agent"`
|
||||
|
||||
// IP 地址(仅管理员可见)
|
||||
IPAddress *string `json:"ip_address,omitempty"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
@@ -245,3 +250,28 @@ type BulkAssignResult struct {
|
||||
Subscriptions []UserSubscription `json:"subscriptions"`
|
||||
Errors []string `json:"errors"`
|
||||
}
|
||||
|
||||
// PromoCode 注册优惠码
|
||||
type PromoCode struct {
|
||||
ID int64 `json:"id"`
|
||||
Code string `json:"code"`
|
||||
BonusAmount float64 `json:"bonus_amount"`
|
||||
MaxUses int `json:"max_uses"`
|
||||
UsedCount int `json:"used_count"`
|
||||
Status string `json:"status"`
|
||||
ExpiresAt *time.Time `json:"expires_at"`
|
||||
Notes string `json:"notes"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// PromoCodeUsage 优惠码使用记录
|
||||
type PromoCodeUsage struct {
|
||||
ID int64 `json:"id"`
|
||||
PromoCodeID int64 `json:"promo_code_id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
BonusAmount float64 `json:"bonus_amount"`
|
||||
UsedAt time.Time `json:"used_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -114,6 +115,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 获取 User-Agent
|
||||
userAgent := c.Request.UserAgent()
|
||||
|
||||
// 获取客户端 IP
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
// 0. 检查wait队列是否已满
|
||||
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||
@@ -273,7 +277,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 异步记录使用量(subscription已在函数开头获取)
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua string) {
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua string, cip string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
@@ -283,10 +287,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
IPAddress: cip,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account, userAgent)
|
||||
}(result, account, userAgent, clientIP)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -401,7 +406,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 异步记录使用量(subscription已在函数开头获取)
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua string) {
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua string, cip string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
@@ -411,10 +416,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
IPAddress: cip,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account, userAgent)
|
||||
}(result, account, userAgent, clientIP)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
@@ -167,6 +168,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
// 获取 User-Agent
|
||||
userAgent := c.Request.UserAgent()
|
||||
|
||||
// 获取客户端 IP
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
// For Gemini native API, do not send Claude-style ping frames.
|
||||
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone, 0)
|
||||
|
||||
@@ -307,7 +311,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 6) record usage async
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua string) {
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua string, cip string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
@@ -317,10 +321,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
IPAddress: cip,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account, userAgent)
|
||||
}(result, account, userAgent, clientIP)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ type AdminHandlers struct {
|
||||
AntigravityOAuth *admin.AntigravityOAuthHandler
|
||||
Proxy *admin.ProxyHandler
|
||||
Redeem *admin.RedeemHandler
|
||||
Promo *admin.PromoHandler
|
||||
Setting *admin.SettingHandler
|
||||
System *admin.SystemHandler
|
||||
Subscription *admin.SubscriptionHandler
|
||||
|
||||
@@ -8,9 +8,12 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
@@ -93,6 +96,24 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
|
||||
// 获取客户端 IP
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
if !openai.IsCodexCLIRequest(userAgent) {
|
||||
existingInstructions, _ := reqBody["instructions"].(string)
|
||||
if strings.TrimSpace(existingInstructions) == "" {
|
||||
if instructions := strings.TrimSpace(service.GetOpenCodeInstructions()); instructions != "" {
|
||||
reqBody["instructions"] = instructions
|
||||
// Re-serialize body
|
||||
body, err = json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Track if we've started streaming (for error handling)
|
||||
streamStarted := false
|
||||
|
||||
@@ -231,7 +252,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Async record usage
|
||||
go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua string) {
|
||||
go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua string, cip string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
@@ -241,10 +262,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
IPAddress: cip,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account, userAgent)
|
||||
}(result, account, userAgent, clientIP)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,6 +42,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
Version: h.version,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ func ProvideAdminHandlers(
|
||||
antigravityOAuthHandler *admin.AntigravityOAuthHandler,
|
||||
proxyHandler *admin.ProxyHandler,
|
||||
redeemHandler *admin.RedeemHandler,
|
||||
promoHandler *admin.PromoHandler,
|
||||
settingHandler *admin.SettingHandler,
|
||||
systemHandler *admin.SystemHandler,
|
||||
subscriptionHandler *admin.SubscriptionHandler,
|
||||
@@ -36,6 +37,7 @@ func ProvideAdminHandlers(
|
||||
AntigravityOAuth: antigravityOAuthHandler,
|
||||
Proxy: proxyHandler,
|
||||
Redeem: redeemHandler,
|
||||
Promo: promoHandler,
|
||||
Setting: settingHandler,
|
||||
System: systemHandler,
|
||||
Subscription: subscriptionHandler,
|
||||
@@ -105,6 +107,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewAntigravityOAuthHandler,
|
||||
admin.NewProxyHandler,
|
||||
admin.NewRedeemHandler,
|
||||
admin.NewPromoHandler,
|
||||
admin.NewSettingHandler,
|
||||
ProvideSystemHandler,
|
||||
admin.NewSubscriptionHandler,
|
||||
|
||||
60
backend/internal/middleware/rate_limiter.go
Normal file
60
backend/internal/middleware/rate_limiter.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// RateLimiter Redis 速率限制器
|
||||
type RateLimiter struct {
|
||||
redis *redis.Client
|
||||
prefix string
|
||||
}
|
||||
|
||||
// NewRateLimiter 创建速率限制器实例
|
||||
func NewRateLimiter(redisClient *redis.Client) *RateLimiter {
|
||||
return &RateLimiter{
|
||||
redis: redisClient,
|
||||
prefix: "rate_limit:",
|
||||
}
|
||||
}
|
||||
|
||||
// Limit 返回速率限制中间件
|
||||
// key: 限制类型标识
|
||||
// limit: 时间窗口内最大请求数
|
||||
// window: 时间窗口
|
||||
func (r *RateLimiter) Limit(key string, limit int, window time.Duration) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ip := c.ClientIP()
|
||||
redisKey := r.prefix + key + ":" + ip
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// 使用 INCR 原子操作增加计数
|
||||
count, err := r.redis.Incr(ctx, redisKey).Result()
|
||||
if err != nil {
|
||||
// Redis 错误时放行,避免影响正常服务
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 首次访问时设置过期时间
|
||||
if count == 1 {
|
||||
r.redis.Expire(ctx, redisKey, window)
|
||||
}
|
||||
|
||||
// 超过限制
|
||||
if count > int64(limit) {
|
||||
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "rate limit exceeded",
|
||||
"message": "Too many requests, please try again later",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -9,4 +9,6 @@ const (
|
||||
ForcePlatform Key = "ctx_force_platform"
|
||||
// IsClaudeCodeClient 是否为 Claude Code 客户端,由中间件设置
|
||||
IsClaudeCodeClient Key = "ctx_is_claude_code_client"
|
||||
// Group 认证后的分组信息,由 API Key 认证中间件设置
|
||||
Group Key = "ctx_group"
|
||||
)
|
||||
|
||||
168
backend/internal/pkg/ip/ip.go
Normal file
168
backend/internal/pkg/ip/ip.go
Normal file
@@ -0,0 +1,168 @@
|
||||
// Package ip 提供客户端 IP 地址提取工具。
|
||||
package ip
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// GetClientIP 从 Gin Context 中提取客户端真实 IP 地址。
|
||||
// 按以下优先级检查 Header:
|
||||
// 1. CF-Connecting-IP (Cloudflare)
|
||||
// 2. X-Real-IP (Nginx)
|
||||
// 3. X-Forwarded-For (取第一个非私有 IP)
|
||||
// 4. c.ClientIP() (Gin 内置方法)
|
||||
func GetClientIP(c *gin.Context) string {
|
||||
// 1. Cloudflare
|
||||
if ip := c.GetHeader("CF-Connecting-IP"); ip != "" {
|
||||
return normalizeIP(ip)
|
||||
}
|
||||
|
||||
// 2. Nginx X-Real-IP
|
||||
if ip := c.GetHeader("X-Real-IP"); ip != "" {
|
||||
return normalizeIP(ip)
|
||||
}
|
||||
|
||||
// 3. X-Forwarded-For (多个 IP 时取第一个公网 IP)
|
||||
if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
|
||||
ips := strings.Split(xff, ",")
|
||||
for _, ip := range ips {
|
||||
ip = strings.TrimSpace(ip)
|
||||
if ip != "" && !isPrivateIP(ip) {
|
||||
return normalizeIP(ip)
|
||||
}
|
||||
}
|
||||
// 如果都是私有 IP,返回第一个
|
||||
if len(ips) > 0 {
|
||||
return normalizeIP(strings.TrimSpace(ips[0]))
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Gin 内置方法
|
||||
return normalizeIP(c.ClientIP())
|
||||
}
|
||||
|
||||
// normalizeIP 规范化 IP 地址,去除端口号和空格。
|
||||
func normalizeIP(ip string) string {
|
||||
ip = strings.TrimSpace(ip)
|
||||
// 移除端口号(如 "192.168.1.1:8080" -> "192.168.1.1")
|
||||
if host, _, err := net.SplitHostPort(ip); err == nil {
|
||||
return host
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
// isPrivateIP 检查 IP 是否为私有地址。
|
||||
func isPrivateIP(ipStr string) bool {
|
||||
ip := net.ParseIP(ipStr)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 私有 IP 范围
|
||||
privateBlocks := []string{
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/12",
|
||||
"192.168.0.0/16",
|
||||
"127.0.0.0/8",
|
||||
"::1/128",
|
||||
"fc00::/7",
|
||||
}
|
||||
|
||||
for _, block := range privateBlocks {
|
||||
_, cidr, err := net.ParseCIDR(block)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if cidr.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// MatchesPattern 检查 IP 是否匹配指定的模式(支持单个 IP 或 CIDR)。
|
||||
// pattern 可以是:
|
||||
// - 单个 IP: "192.168.1.100"
|
||||
// - CIDR 范围: "192.168.1.0/24"
|
||||
func MatchesPattern(clientIP, pattern string) bool {
|
||||
ip := net.ParseIP(clientIP)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 尝试解析为 CIDR
|
||||
if strings.Contains(pattern, "/") {
|
||||
_, cidr, err := net.ParseCIDR(pattern)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return cidr.Contains(ip)
|
||||
}
|
||||
|
||||
// 作为单个 IP 处理
|
||||
patternIP := net.ParseIP(pattern)
|
||||
if patternIP == nil {
|
||||
return false
|
||||
}
|
||||
return ip.Equal(patternIP)
|
||||
}
|
||||
|
||||
// MatchesAnyPattern 检查 IP 是否匹配任意一个模式。
|
||||
func MatchesAnyPattern(clientIP string, patterns []string) bool {
|
||||
for _, pattern := range patterns {
|
||||
if MatchesPattern(clientIP, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// CheckIPRestriction 检查 IP 是否被 API Key 的 IP 限制允许。
|
||||
// 返回值:(是否允许, 拒绝原因)
|
||||
// 逻辑:
|
||||
// 1. 先检查黑名单,如果在黑名单中则直接拒绝
|
||||
// 2. 如果白名单不为空,IP 必须在白名单中
|
||||
// 3. 如果白名单为空,允许访问(除非被黑名单拒绝)
|
||||
func CheckIPRestriction(clientIP string, whitelist, blacklist []string) (bool, string) {
|
||||
// 规范化 IP
|
||||
clientIP = normalizeIP(clientIP)
|
||||
if clientIP == "" {
|
||||
return false, "access denied"
|
||||
}
|
||||
|
||||
// 1. 检查黑名单
|
||||
if len(blacklist) > 0 && MatchesAnyPattern(clientIP, blacklist) {
|
||||
return false, "access denied"
|
||||
}
|
||||
|
||||
// 2. 检查白名单(如果设置了白名单,IP 必须在其中)
|
||||
if len(whitelist) > 0 && !MatchesAnyPattern(clientIP, whitelist) {
|
||||
return false, "access denied"
|
||||
}
|
||||
|
||||
return true, ""
|
||||
}
|
||||
|
||||
// ValidateIPPattern 验证 IP 或 CIDR 格式是否有效。
|
||||
func ValidateIPPattern(pattern string) bool {
|
||||
if strings.Contains(pattern, "/") {
|
||||
_, _, err := net.ParseCIDR(pattern)
|
||||
return err == nil
|
||||
}
|
||||
return net.ParseIP(pattern) != nil
|
||||
}
|
||||
|
||||
// ValidateIPPatterns 验证多个 IP 或 CIDR 格式。
|
||||
// 返回无效的模式列表。
|
||||
func ValidateIPPatterns(patterns []string) []string {
|
||||
var invalid []string
|
||||
for _, p := range patterns {
|
||||
if !ValidateIPPattern(p) {
|
||||
invalid = append(invalid, p)
|
||||
}
|
||||
}
|
||||
return invalid
|
||||
}
|
||||
@@ -675,6 +675,40 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
|
||||
now := time.Now().UTC()
|
||||
payload := map[string]string{
|
||||
"rate_limited_at": now.Format(time.RFC3339),
|
||||
"rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339),
|
||||
}
|
||||
raw, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
path := "{antigravity_quota_scopes," + string(scope) + "}"
|
||||
client := clientFromContext(ctx, r.client)
|
||||
result, err := client.ExecContext(
|
||||
ctx,
|
||||
"UPDATE accounts SET extra = jsonb_set(COALESCE(extra, '{}'::jsonb), $1::text[], $2::jsonb, true), updated_at = NOW() WHERE id = $3 AND deleted_at IS NULL",
|
||||
path,
|
||||
raw,
|
||||
id,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||
_, err := r.client.Account.Update().
|
||||
Where(dbaccount.IDEQ(id)).
|
||||
@@ -718,6 +752,27 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
result, err := client.ExecContext(
|
||||
ctx,
|
||||
"UPDATE accounts SET extra = COALESCE(extra, '{}'::jsonb) - 'antigravity_quota_scopes', updated_at = NOW() WHERE id = $1 AND deleted_at IS NULL",
|
||||
id,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||
builder := r.client.Account.Update().
|
||||
Where(dbaccount.IDEQ(id)).
|
||||
@@ -831,6 +886,11 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
|
||||
args = append(args, *updates.Status)
|
||||
idx++
|
||||
}
|
||||
if updates.Schedulable != nil {
|
||||
setClauses = append(setClauses, "schedulable = $"+itoa(idx))
|
||||
args = append(args, *updates.Schedulable)
|
||||
idx++
|
||||
}
|
||||
// JSONB 需要合并而非覆盖,使用 raw SQL 保持旧行为。
|
||||
if len(updates.Credentials) > 0 {
|
||||
payload, err := json.Marshal(updates.Credentials)
|
||||
|
||||
@@ -26,13 +26,21 @@ func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery {
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) error {
|
||||
created, err := r.client.APIKey.Create().
|
||||
builder := r.client.APIKey.Create().
|
||||
SetUserID(key.UserID).
|
||||
SetKey(key.Key).
|
||||
SetName(key.Name).
|
||||
SetStatus(key.Status).
|
||||
SetNillableGroupID(key.GroupID).
|
||||
Save(ctx)
|
||||
SetNillableGroupID(key.GroupID)
|
||||
|
||||
if len(key.IPWhitelist) > 0 {
|
||||
builder.SetIPWhitelist(key.IPWhitelist)
|
||||
}
|
||||
if len(key.IPBlacklist) > 0 {
|
||||
builder.SetIPBlacklist(key.IPBlacklist)
|
||||
}
|
||||
|
||||
created, err := builder.Save(ctx)
|
||||
if err == nil {
|
||||
key.ID = created.ID
|
||||
key.CreatedAt = created.CreatedAt
|
||||
@@ -108,6 +116,18 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
|
||||
builder.ClearGroupID()
|
||||
}
|
||||
|
||||
// IP 限制字段
|
||||
if len(key.IPWhitelist) > 0 {
|
||||
builder.SetIPWhitelist(key.IPWhitelist)
|
||||
} else {
|
||||
builder.ClearIPWhitelist()
|
||||
}
|
||||
if len(key.IPBlacklist) > 0 {
|
||||
builder.SetIPBlacklist(key.IPBlacklist)
|
||||
} else {
|
||||
builder.ClearIPBlacklist()
|
||||
}
|
||||
|
||||
affected, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -268,14 +288,16 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
|
||||
return nil
|
||||
}
|
||||
out := &service.APIKey{
|
||||
ID: m.ID,
|
||||
UserID: m.UserID,
|
||||
Key: m.Key,
|
||||
Name: m.Name,
|
||||
Status: m.Status,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
GroupID: m.GroupID,
|
||||
ID: m.ID,
|
||||
UserID: m.UserID,
|
||||
Key: m.Key,
|
||||
Name: m.Name,
|
||||
Status: m.Status,
|
||||
IPWhitelist: m.IPWhitelist,
|
||||
IPBlacklist: m.IPBlacklist,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
GroupID: m.GroupID,
|
||||
}
|
||||
if m.Edges.User != nil {
|
||||
out.User = userEntityToService(m.Edges.User)
|
||||
@@ -317,6 +339,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
RateMultiplier: g.RateMultiplier,
|
||||
IsExclusive: g.IsExclusive,
|
||||
Status: g.Status,
|
||||
Hydrated: true,
|
||||
SubscriptionType: g.SubscriptionType,
|
||||
DailyLimitUSD: g.DailyLimitUsd,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUsd,
|
||||
|
||||
@@ -60,6 +60,17 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
||||
}
|
||||
|
||||
func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group, error) {
|
||||
out, err := r.GetByIDLite(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
count, _ := r.GetAccountCount(ctx, out.ID)
|
||||
out.AccountCount = count
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) {
|
||||
// AccountCount is intentionally not loaded here; use GetByID when needed.
|
||||
m, err := r.client.Group.Query().
|
||||
Where(group.IDEQ(id)).
|
||||
Only(ctx)
|
||||
@@ -67,10 +78,7 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
|
||||
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
|
||||
}
|
||||
|
||||
out := groupEntityToService(m)
|
||||
count, _ := r.GetAccountCount(ctx, out.ID)
|
||||
out.AccountCount = count
|
||||
return out, nil
|
||||
return groupEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) error {
|
||||
@@ -112,10 +120,10 @@ func (r *groupRepository) Delete(ctx context.Context, id int64) error {
|
||||
}
|
||||
|
||||
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", nil)
|
||||
return r.ListWithFilters(ctx, params, "", "", "", nil)
|
||||
}
|
||||
|
||||
func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
q := r.client.Group.Query()
|
||||
|
||||
if platform != "" {
|
||||
@@ -124,6 +132,12 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
|
||||
if status != "" {
|
||||
q = q.Where(group.StatusEQ(status))
|
||||
}
|
||||
if search != "" {
|
||||
q = q.Where(group.Or(
|
||||
group.NameContainsFold(search),
|
||||
group.DescriptionContainsFold(search),
|
||||
))
|
||||
}
|
||||
if isExclusive != nil {
|
||||
q = q.Where(group.IsExclusiveEQ(*isExclusive))
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
@@ -19,6 +21,20 @@ type GroupRepoSuite struct {
|
||||
repo *groupRepository
|
||||
}
|
||||
|
||||
type forbidSQLExecutor struct {
|
||||
called bool
|
||||
}
|
||||
|
||||
func (s *forbidSQLExecutor) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
|
||||
s.called = true
|
||||
return nil, errors.New("unexpected sql exec")
|
||||
}
|
||||
|
||||
func (s *forbidSQLExecutor) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
|
||||
s.called = true
|
||||
return nil, errors.New("unexpected sql query")
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
tx := testEntTx(s.T())
|
||||
@@ -57,6 +73,26 @@ func (s *GroupRepoSuite) TestGetByID_NotFound() {
|
||||
s.Require().ErrorIs(err, service.ErrGroupNotFound)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestGetByIDLite_DoesNotUseAccountCount() {
|
||||
group := &service.Group{
|
||||
Name: "lite-group",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, group))
|
||||
|
||||
spy := &forbidSQLExecutor{}
|
||||
repo := newGroupRepositoryWithSQL(s.tx.Client(), spy)
|
||||
|
||||
got, err := repo.GetByIDLite(s.ctx, group.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(group.ID, got.ID)
|
||||
s.Require().False(spy.called, "expected no direct sql executor usage")
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestUpdate() {
|
||||
group := &service.Group{
|
||||
Name: "original",
|
||||
@@ -131,6 +167,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() {
|
||||
pagination.PaginationParams{Page: 1, PageSize: 10},
|
||||
service.PlatformOpenAI,
|
||||
"",
|
||||
"",
|
||||
nil,
|
||||
)
|
||||
s.Require().NoError(err, "ListWithFilters base")
|
||||
@@ -152,7 +189,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() {
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", nil)
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", "", nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(groups, len(baseGroups)+1)
|
||||
// Verify all groups are OpenAI platform
|
||||
@@ -179,7 +216,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Status() {
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, nil)
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, "", nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(groups, 1)
|
||||
s.Require().Equal(service.StatusDisabled, groups[0].Status)
|
||||
@@ -204,12 +241,117 @@ func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
|
||||
}))
|
||||
|
||||
isExclusive := true
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", &isExclusive)
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", &isExclusive)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(groups, 1)
|
||||
s.Require().True(groups[0].IsExclusive)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListWithFilters_Search() {
|
||||
newRepo := func() (*groupRepository, context.Context) {
|
||||
tx := testEntTx(s.T())
|
||||
return newGroupRepositoryWithSQL(tx.Client(), tx), context.Background()
|
||||
}
|
||||
|
||||
containsID := func(groups []service.Group, id int64) bool {
|
||||
for i := range groups {
|
||||
if groups[i].ID == id {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
mustCreate := func(repo *groupRepository, ctx context.Context, g *service.Group) *service.Group {
|
||||
s.Require().NoError(repo.Create(ctx, g))
|
||||
s.Require().NotZero(g.ID)
|
||||
return g
|
||||
}
|
||||
|
||||
newGroup := func(name string) *service.Group {
|
||||
return &service.Group{
|
||||
Name: name,
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
}
|
||||
|
||||
s.Run("search_name_should_match", func() {
|
||||
repo, ctx := newRepo()
|
||||
|
||||
target := mustCreate(repo, ctx, newGroup("it-group-search-name-target"))
|
||||
other := mustCreate(repo, ctx, newGroup("it-group-search-name-other"))
|
||||
|
||||
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "name-target", nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().True(containsID(groups, target.ID), "expected target group to match by name")
|
||||
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
|
||||
})
|
||||
|
||||
s.Run("search_description_should_match", func() {
|
||||
repo, ctx := newRepo()
|
||||
|
||||
target := newGroup("it-group-search-desc-target")
|
||||
target.Description = "something about desc-needle in here"
|
||||
target = mustCreate(repo, ctx, target)
|
||||
|
||||
other := newGroup("it-group-search-desc-other")
|
||||
other.Description = "nothing to see here"
|
||||
other = mustCreate(repo, ctx, other)
|
||||
|
||||
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "desc-needle", nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().True(containsID(groups, target.ID), "expected target group to match by description")
|
||||
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
|
||||
})
|
||||
|
||||
s.Run("search_nonexistent_should_return_empty", func() {
|
||||
repo, ctx := newRepo()
|
||||
|
||||
_ = mustCreate(repo, ctx, newGroup("it-group-search-nonexistent-baseline"))
|
||||
|
||||
search := s.T().Name() + "__no_such_group__"
|
||||
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", search, nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Empty(groups)
|
||||
})
|
||||
|
||||
s.Run("search_should_be_case_insensitive", func() {
|
||||
repo, ctx := newRepo()
|
||||
|
||||
target := mustCreate(repo, ctx, newGroup("MiXeDCaSe-Needle"))
|
||||
other := mustCreate(repo, ctx, newGroup("it-group-search-case-other"))
|
||||
|
||||
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "mixedcase-needle", nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().True(containsID(groups, target.ID), "expected case-insensitive match")
|
||||
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
|
||||
})
|
||||
|
||||
s.Run("search_should_escape_like_wildcards", func() {
|
||||
repo, ctx := newRepo()
|
||||
|
||||
percentTarget := mustCreate(repo, ctx, newGroup("it-group-search-100%-target"))
|
||||
percentOther := mustCreate(repo, ctx, newGroup("it-group-search-100X-other"))
|
||||
|
||||
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "100%", nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().True(containsID(groups, percentTarget.ID), "expected literal %% match")
|
||||
s.Require().False(containsID(groups, percentOther.ID), "expected %% not to act as wildcard")
|
||||
|
||||
underscoreTarget := mustCreate(repo, ctx, newGroup("it-group-search-ab_cd-target"))
|
||||
underscoreOther := mustCreate(repo, ctx, newGroup("it-group-search-abXcd-other"))
|
||||
|
||||
groups, _, err = repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "ab_cd", nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().True(containsID(groups, underscoreTarget.ID), "expected literal _ match")
|
||||
s.Require().False(containsID(groups, underscoreOther.ID), "expected _ not to act as wildcard")
|
||||
})
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
|
||||
g1 := &service.Group{
|
||||
Name: "g1",
|
||||
@@ -244,7 +386,7 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
|
||||
s.Require().NoError(err)
|
||||
|
||||
isExclusive := true
|
||||
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, &isExclusive)
|
||||
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, "", &isExclusive)
|
||||
s.Require().NoError(err, "ListWithFilters")
|
||||
s.Require().Equal(int64(1), page.Total)
|
||||
s.Require().Len(groups, 1)
|
||||
|
||||
273
backend/internal/repository/promo_code_repo.go
Normal file
273
backend/internal/repository/promo_code_repo.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type promoCodeRepository struct {
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func NewPromoCodeRepository(client *dbent.Client) service.PromoCodeRepository {
|
||||
return &promoCodeRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) Create(ctx context.Context, code *service.PromoCode) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
builder := client.PromoCode.Create().
|
||||
SetCode(code.Code).
|
||||
SetBonusAmount(code.BonusAmount).
|
||||
SetMaxUses(code.MaxUses).
|
||||
SetUsedCount(code.UsedCount).
|
||||
SetStatus(code.Status).
|
||||
SetNotes(code.Notes)
|
||||
|
||||
if code.ExpiresAt != nil {
|
||||
builder.SetExpiresAt(*code.ExpiresAt)
|
||||
}
|
||||
|
||||
created, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
code.ID = created.ID
|
||||
code.CreatedAt = created.CreatedAt
|
||||
code.UpdatedAt = created.UpdatedAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) GetByID(ctx context.Context, id int64) (*service.PromoCode, error) {
|
||||
m, err := r.client.PromoCode.Query().
|
||||
Where(promocode.IDEQ(id)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrPromoCodeNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return promoCodeEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) GetByCode(ctx context.Context, code string) (*service.PromoCode, error) {
|
||||
m, err := r.client.PromoCode.Query().
|
||||
Where(promocode.CodeEqualFold(code)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrPromoCodeNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return promoCodeEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) GetByCodeForUpdate(ctx context.Context, code string) (*service.PromoCode, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
m, err := client.PromoCode.Query().
|
||||
Where(promocode.CodeEqualFold(code)).
|
||||
ForUpdate().
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrPromoCodeNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return promoCodeEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) Update(ctx context.Context, code *service.PromoCode) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
builder := client.PromoCode.UpdateOneID(code.ID).
|
||||
SetCode(code.Code).
|
||||
SetBonusAmount(code.BonusAmount).
|
||||
SetMaxUses(code.MaxUses).
|
||||
SetUsedCount(code.UsedCount).
|
||||
SetStatus(code.Status).
|
||||
SetNotes(code.Notes)
|
||||
|
||||
if code.ExpiresAt != nil {
|
||||
builder.SetExpiresAt(*code.ExpiresAt)
|
||||
} else {
|
||||
builder.ClearExpiresAt()
|
||||
}
|
||||
|
||||
updated, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return service.ErrPromoCodeNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
code.UpdatedAt = updated.UpdatedAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) Delete(ctx context.Context, id int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
_, err := client.PromoCode.Delete().Where(promocode.IDEQ(id)).Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.PromoCode, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "")
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, search string) ([]service.PromoCode, *pagination.PaginationResult, error) {
|
||||
q := r.client.PromoCode.Query()
|
||||
|
||||
if status != "" {
|
||||
q = q.Where(promocode.StatusEQ(status))
|
||||
}
|
||||
if search != "" {
|
||||
q = q.Where(promocode.CodeContainsFold(search))
|
||||
}
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
codes, err := q.
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(promocode.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outCodes := promoCodeEntitiesToService(codes)
|
||||
|
||||
return outCodes, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) CreateUsage(ctx context.Context, usage *service.PromoCodeUsage) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
created, err := client.PromoCodeUsage.Create().
|
||||
SetPromoCodeID(usage.PromoCodeID).
|
||||
SetUserID(usage.UserID).
|
||||
SetBonusAmount(usage.BonusAmount).
|
||||
SetUsedAt(usage.UsedAt).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
usage.ID = created.ID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) GetUsageByPromoCodeAndUser(ctx context.Context, promoCodeID, userID int64) (*service.PromoCodeUsage, error) {
|
||||
m, err := r.client.PromoCodeUsage.Query().
|
||||
Where(
|
||||
promocodeusage.PromoCodeIDEQ(promoCodeID),
|
||||
promocodeusage.UserIDEQ(userID),
|
||||
).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return promoCodeUsageEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) ListUsagesByPromoCode(ctx context.Context, promoCodeID int64, params pagination.PaginationParams) ([]service.PromoCodeUsage, *pagination.PaginationResult, error) {
|
||||
q := r.client.PromoCodeUsage.Query().
|
||||
Where(promocodeusage.PromoCodeIDEQ(promoCodeID))
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
usages, err := q.
|
||||
WithUser().
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(promocodeusage.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outUsages := promoCodeUsageEntitiesToService(usages)
|
||||
|
||||
return outUsages, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) IncrementUsedCount(ctx context.Context, id int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
_, err := client.PromoCode.UpdateOneID(id).
|
||||
AddUsedCount(1).
|
||||
Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// Entity to Service conversions
|
||||
|
||||
func promoCodeEntityToService(m *dbent.PromoCode) *service.PromoCode {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.PromoCode{
|
||||
ID: m.ID,
|
||||
Code: m.Code,
|
||||
BonusAmount: m.BonusAmount,
|
||||
MaxUses: m.MaxUses,
|
||||
UsedCount: m.UsedCount,
|
||||
Status: m.Status,
|
||||
ExpiresAt: m.ExpiresAt,
|
||||
Notes: derefString(m.Notes),
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func promoCodeEntitiesToService(models []*dbent.PromoCode) []service.PromoCode {
|
||||
out := make([]service.PromoCode, 0, len(models))
|
||||
for i := range models {
|
||||
if s := promoCodeEntityToService(models[i]); s != nil {
|
||||
out = append(out, *s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func promoCodeUsageEntityToService(m *dbent.PromoCodeUsage) *service.PromoCodeUsage {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
out := &service.PromoCodeUsage{
|
||||
ID: m.ID,
|
||||
PromoCodeID: m.PromoCodeID,
|
||||
UserID: m.UserID,
|
||||
BonusAmount: m.BonusAmount,
|
||||
UsedAt: m.UsedAt,
|
||||
}
|
||||
if m.Edges.User != nil {
|
||||
out.User = userEntityToService(m.Edges.User)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func promoCodeUsageEntitiesToService(models []*dbent.PromoCodeUsage) []service.PromoCodeUsage {
|
||||
out := make([]service.PromoCodeUsage, 0, len(models))
|
||||
for i := range models {
|
||||
if s := promoCodeUsageEntityToService(models[i]); s != nil {
|
||||
out = append(out, *s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -22,7 +22,7 @@ import (
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, image_count, image_size, created_at"
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at"
|
||||
|
||||
type usageLogRepository struct {
|
||||
client *dbent.Client
|
||||
@@ -110,6 +110,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
duration_ms,
|
||||
first_token_ms,
|
||||
user_agent,
|
||||
ip_address,
|
||||
image_count,
|
||||
image_size,
|
||||
created_at
|
||||
@@ -119,7 +120,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
$8, $9, $10, $11,
|
||||
$12, $13,
|
||||
$14, $15, $16, $17, $18, $19,
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id, created_at
|
||||
@@ -130,6 +131,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
duration := nullInt(log.DurationMs)
|
||||
firstToken := nullInt(log.FirstTokenMs)
|
||||
userAgent := nullString(log.UserAgent)
|
||||
ipAddress := nullString(log.IPAddress)
|
||||
imageSize := nullString(log.ImageSize)
|
||||
|
||||
var requestIDArg any
|
||||
@@ -163,6 +165,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
duration,
|
||||
firstToken,
|
||||
userAgent,
|
||||
ipAddress,
|
||||
log.ImageCount,
|
||||
imageSize,
|
||||
createdAt,
|
||||
@@ -1873,6 +1876,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
durationMs sql.NullInt64
|
||||
firstTokenMs sql.NullInt64
|
||||
userAgent sql.NullString
|
||||
ipAddress sql.NullString
|
||||
imageCount int
|
||||
imageSize sql.NullString
|
||||
createdAt time.Time
|
||||
@@ -1905,6 +1909,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
&durationMs,
|
||||
&firstTokenMs,
|
||||
&userAgent,
|
||||
&ipAddress,
|
||||
&imageCount,
|
||||
&imageSize,
|
||||
&createdAt,
|
||||
@@ -1959,6 +1964,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
if userAgent.Valid {
|
||||
log.UserAgent = &userAgent.String
|
||||
}
|
||||
if ipAddress.Valid {
|
||||
log.IPAddress = &ipAddress.String
|
||||
}
|
||||
if imageSize.Valid {
|
||||
log.ImageSize = &imageSize.String
|
||||
}
|
||||
|
||||
@@ -45,6 +45,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewAccountRepository,
|
||||
NewProxyRepository,
|
||||
NewRedeemCodeRepository,
|
||||
NewPromoCodeRepository,
|
||||
NewUsageLogRepository,
|
||||
NewSettingRepository,
|
||||
NewUserSubscriptionRepository,
|
||||
|
||||
@@ -82,6 +82,8 @@ func TestAPIContracts(t *testing.T) {
|
||||
"name": "Key One",
|
||||
"group_id": null,
|
||||
"status": "active",
|
||||
"ip_whitelist": null,
|
||||
"ip_blacklist": null,
|
||||
"created_at": "2025-01-02T03:04:05Z",
|
||||
"updated_at": "2025-01-02T03:04:05Z"
|
||||
}
|
||||
@@ -116,6 +118,8 @@ func TestAPIContracts(t *testing.T) {
|
||||
"name": "Key One",
|
||||
"group_id": null,
|
||||
"status": "active",
|
||||
"ip_whitelist": null,
|
||||
"ip_blacklist": null,
|
||||
"created_at": "2025-01-02T03:04:05Z",
|
||||
"updated_at": "2025-01-02T03:04:05Z"
|
||||
}
|
||||
@@ -304,6 +308,10 @@ func TestAPIContracts(t *testing.T) {
|
||||
"turnstile_enabled": true,
|
||||
"turnstile_site_key": "site-key",
|
||||
"turnstile_secret_key_configured": true,
|
||||
"linuxdo_connect_enabled": false,
|
||||
"linuxdo_connect_client_id": "",
|
||||
"linuxdo_connect_client_secret_configured": false,
|
||||
"linuxdo_connect_redirect_url": "",
|
||||
"site_name": "Sub2API",
|
||||
"site_logo": "",
|
||||
"site_subtitle": "Subtitle",
|
||||
@@ -390,7 +398,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
settingRepo := newStubSettingRepo()
|
||||
settingService := service.NewSettingService(settingRepo, cfg)
|
||||
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService)
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil)
|
||||
@@ -567,6 +575,10 @@ func (stubGroupRepo) GetByID(ctx context.Context, id int64) (*service.Group, err
|
||||
return nil, service.ErrGroupNotFound
|
||||
}
|
||||
|
||||
func (stubGroupRepo) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) {
|
||||
return nil, service.ErrGroupNotFound
|
||||
}
|
||||
|
||||
func (stubGroupRepo) Update(ctx context.Context, group *service.Group) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
@@ -583,7 +595,7 @@ func (stubGroupRepo) List(ctx context.Context, params pagination.PaginationParam
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/wire"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// ProviderSet 提供服务器层的依赖
|
||||
@@ -30,6 +31,7 @@ func ProvideRouter(
|
||||
apiKeyAuth middleware2.APIKeyAuthMiddleware,
|
||||
apiKeyService *service.APIKeyService,
|
||||
subscriptionService *service.SubscriptionService,
|
||||
redisClient *redis.Client,
|
||||
) *gin.Engine {
|
||||
if cfg.Server.Mode == "release" {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
@@ -47,7 +49,7 @@ func ProvideRouter(
|
||||
}
|
||||
}
|
||||
|
||||
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg)
|
||||
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg, redisClient)
|
||||
}
|
||||
|
||||
// ProvideHTTPServer 提供 HTTP 服务器
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -71,6 +74,17 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
return
|
||||
}
|
||||
|
||||
// 检查 IP 限制(白名单/黑名单)
|
||||
// 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制
|
||||
if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 {
|
||||
clientIP := ip.GetClientIP(c)
|
||||
allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist)
|
||||
if !allowed {
|
||||
AbortWithError(c, 403, "ACCESS_DENIED", "Access denied")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 检查关联的用户
|
||||
if apiKey.User == nil {
|
||||
AbortWithError(c, 401, "USER_NOT_FOUND", "User associated with API key not found")
|
||||
@@ -91,6 +105,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
setGroupContext(c, apiKey.Group)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
@@ -149,6 +164,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
setGroupContext(c, apiKey.Group)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
@@ -173,3 +189,14 @@ func GetSubscriptionFromContext(c *gin.Context) (*service.UserSubscription, bool
|
||||
subscription, ok := value.(*service.UserSubscription)
|
||||
return subscription, ok
|
||||
}
|
||||
|
||||
func setGroupContext(c *gin.Context, group *service.Group) {
|
||||
if !service.IsGroupContextValid(group) {
|
||||
return
|
||||
}
|
||||
if existing, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group); ok && existing != nil && existing.ID == group.ID && service.IsGroupContextValid(existing) {
|
||||
return
|
||||
}
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.Group, group)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
|
||||
@@ -63,6 +63,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
setGroupContext(c, apiKey.Group)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
@@ -102,6 +103,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
setGroupContext(c, apiKey.Group)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
@@ -133,6 +134,70 @@ func TestApiKeyAuthWithSubscriptionGoogle_QueryApiKeyRejected(t *testing.T) {
|
||||
require.Equal(t, "INVALID_ARGUMENT", resp.Error.Status)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogleSetsGroupContext(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
group := &service.Group{
|
||||
ID: 99,
|
||||
Name: "g1",
|
||||
Status: service.StatusActive,
|
||||
Platform: service.PlatformGemini,
|
||||
Hydrated: true,
|
||||
}
|
||||
user := &service.User{
|
||||
ID: 7,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 100,
|
||||
UserID: user.ID,
|
||||
Key: "test-key",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
Group: group,
|
||||
}
|
||||
apiKey.GroupID = &group.ID
|
||||
|
||||
apiKeyService := service.NewAPIKeyService(
|
||||
fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
},
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
&config.Config{RunMode: config.RunModeSimple},
|
||||
)
|
||||
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
r := gin.New()
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) {
|
||||
groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group)
|
||||
if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"ok": false})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_QueryKeyAllowedOnV1Beta(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -25,6 +26,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
ID: 42,
|
||||
Name: "sub",
|
||||
Status: service.StatusActive,
|
||||
Hydrated: true,
|
||||
SubscriptionType: service.SubscriptionTypeSubscription,
|
||||
DailyLimitUSD: &limit,
|
||||
}
|
||||
@@ -110,6 +112,129 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestAPIKeyAuthSetsGroupContext(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
group := &service.Group{
|
||||
ID: 101,
|
||||
Name: "g1",
|
||||
Status: service.StatusActive,
|
||||
Platform: service.PlatformAnthropic,
|
||||
Hydrated: true,
|
||||
}
|
||||
user := &service.User{
|
||||
ID: 7,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 100,
|
||||
UserID: user.ID,
|
||||
Key: "test-key",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
Group: group,
|
||||
}
|
||||
apiKey.GroupID = &group.ID
|
||||
|
||||
apiKeyRepo := &stubApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
router := gin.New()
|
||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
|
||||
router.GET("/t", func(c *gin.Context) {
|
||||
groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group)
|
||||
if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"ok": false})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
group := &service.Group{
|
||||
ID: 101,
|
||||
Name: "g1",
|
||||
Status: service.StatusActive,
|
||||
Platform: service.PlatformAnthropic,
|
||||
Hydrated: true,
|
||||
}
|
||||
user := &service.User{
|
||||
ID: 7,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 100,
|
||||
UserID: user.ID,
|
||||
Key: "test-key",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
Group: group,
|
||||
}
|
||||
apiKey.GroupID = &group.ID
|
||||
|
||||
apiKeyRepo := &stubApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
router := gin.New()
|
||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
|
||||
|
||||
invalidGroup := &service.Group{
|
||||
ID: group.ID,
|
||||
Platform: group.Platform,
|
||||
Status: group.Status,
|
||||
}
|
||||
router.GET("/t", func(c *gin.Context) {
|
||||
groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group)
|
||||
if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID || !groupFromCtx.Hydrated || groupFromCtx == invalidGroup {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"ok": false})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
req = req.WithContext(context.WithValue(req.Context(), ctxkey.Group, invalidGroup))
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
|
||||
router := gin.New()
|
||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user