mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-25 17:14:45 +08:00
Merge branch 'main' of https://github.com/mt21625457/aicodex2api
This commit is contained in:
368
Linux DO Connect.md
Normal file
368
Linux DO Connect.md
Normal file
@@ -0,0 +1,368 @@
|
|||||||
|
# Linux DO Connect
|
||||||
|
|
||||||
|
OAuth(Open Authorization)是一个开放的网络授权标准,目前最新版本为 OAuth 2.0。我们日常使用的第三方登录(如 Google 账号登录)就采用了该标准。OAuth 允许用户授权第三方应用访问存储在其他服务提供商(如 Google)上的信息,无需在不同平台上重复填写注册信息。用户授权后,平台可以直接访问用户的账户信息进行身份验证,而用户无需向第三方应用提供密码。
|
||||||
|
|
||||||
|
目前系统已实现完整的 OAuth2 授权码(code)方式鉴权,但界面等配套功能还在持续完善中。让我们一起打造一个更完善的共享方案。
|
||||||
|
|
||||||
|
## 基本介绍
|
||||||
|
|
||||||
|
这是一套标准的 OAuth2 鉴权系统,可以让开发者共享论坛的用户基本信息。
|
||||||
|
|
||||||
|
- 可获取字段:
|
||||||
|
|
||||||
|
| 参数 | 说明 |
|
||||||
|
| ----------------- | ------------------------------- |
|
||||||
|
| `id` | 用户唯一标识(不可变) |
|
||||||
|
| `username` | 论坛用户名 |
|
||||||
|
| `name` | 论坛用户昵称(可变) |
|
||||||
|
| `avatar_template` | 用户头像模板URL(支持多种尺寸) |
|
||||||
|
| `active` | 账号活跃状态 |
|
||||||
|
| `trust_level` | 信任等级(0-4) |
|
||||||
|
| `silenced` | 禁言状态 |
|
||||||
|
| `external_ids` | 外部ID关联信息 |
|
||||||
|
| `api_key` | API访问密钥 |
|
||||||
|
|
||||||
|
通过这些信息,公益网站/接口可以实现:
|
||||||
|
|
||||||
|
1. 基于 `id` 的服务频率限制
|
||||||
|
2. 基于 `trust_level` 的服务额度分配
|
||||||
|
3. 基于用户信息的滥用举报机制
|
||||||
|
|
||||||
|
## 相关端点
|
||||||
|
|
||||||
|
- Authorize 端点: `https://connect.linux.do/oauth2/authorize`
|
||||||
|
- Token 端点:`https://connect.linux.do/oauth2/token`
|
||||||
|
- 用户信息 端点:`https://connect.linux.do/api/user`
|
||||||
|
|
||||||
|
## 申请使用
|
||||||
|
|
||||||
|
- 访问 [Connect.Linux.Do](https://connect.linux.do/) 申请接入你的应用。
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
- 点击 **`我的应用接入`** - **`申请新接入`**,填写相关信息。其中 **`回调地址`** 是你的应用接收用户信息的地址。
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
- 申请成功后,你将获得 **`Client Id`** 和 **`Client Secret`**,这是你应用的唯一身份凭证。
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
## 接入 Linux Do
|
||||||
|
|
||||||
|
JavaScript
|
||||||
|
```JavaScript
|
||||||
|
// 安装第三方请求库(或使用原生的 Fetch API),本例中使用 axios
|
||||||
|
// npm install axios
|
||||||
|
|
||||||
|
// 通过 OAuth2 获取 Linux Do 用户信息的参考流程
|
||||||
|
const axios = require('axios');
|
||||||
|
const readline = require('readline');
|
||||||
|
|
||||||
|
// 配置信息(建议通过环境变量配置,避免使用硬编码)
|
||||||
|
const CLIENT_ID = '你的 Client ID';
|
||||||
|
const CLIENT_SECRET = '你的 Client Secret';
|
||||||
|
const REDIRECT_URI = '你的回调地址';
|
||||||
|
const AUTH_URL = 'https://connect.linux.do/oauth2/authorize';
|
||||||
|
const TOKEN_URL = 'https://connect.linux.do/oauth2/token';
|
||||||
|
const USER_INFO_URL = 'https://connect.linux.do/api/user';
|
||||||
|
|
||||||
|
// 第一步:生成授权 URL
|
||||||
|
function getAuthUrl() {
|
||||||
|
const params = new URLSearchParams({
|
||||||
|
client_id: CLIENT_ID,
|
||||||
|
redirect_uri: REDIRECT_URI,
|
||||||
|
response_type: 'code',
|
||||||
|
scope: 'user'
|
||||||
|
});
|
||||||
|
|
||||||
|
return `${AUTH_URL}?${params.toString()}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 第二步:获取 code 参数
|
||||||
|
function getCode() {
|
||||||
|
return new Promise((resolve) => {
|
||||||
|
// 本例中使用终端输入来模拟流程,仅供本地测试
|
||||||
|
// 请在实际应用中替换为真实的处理逻辑
|
||||||
|
const rl = readline.createInterface({ input: process.stdin, output: process.stdout });
|
||||||
|
rl.question('从回调 URL 中提取出 code,粘贴到此处并按回车:', (answer) => {
|
||||||
|
rl.close();
|
||||||
|
resolve(answer.trim());
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// 第三步:使用 code 参数获取访问令牌
|
||||||
|
async function getAccessToken(code) {
|
||||||
|
try {
|
||||||
|
const form = new URLSearchParams({
|
||||||
|
client_id: CLIENT_ID,
|
||||||
|
client_secret: CLIENT_SECRET,
|
||||||
|
code: code,
|
||||||
|
redirect_uri: REDIRECT_URI,
|
||||||
|
grant_type: 'authorization_code'
|
||||||
|
}).toString();
|
||||||
|
|
||||||
|
const response = await axios.post(TOKEN_URL, form, {
|
||||||
|
// 提醒:需正确配置请求头,否则无法正常获取访问令牌
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/x-www-form-urlencoded',
|
||||||
|
'Accept': 'application/json'
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
return response.data;
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`获取访问令牌失败:${error.response ? JSON.stringify(error.response.data) : error.message}`);
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 第四步:使用访问令牌获取用户信息
|
||||||
|
async function getUserInfo(accessToken) {
|
||||||
|
try {
|
||||||
|
const response = await axios.get(USER_INFO_URL, {
|
||||||
|
headers: {
|
||||||
|
Authorization: `Bearer ${accessToken}`
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
return response.data;
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`获取用户信息失败:${error.response ? JSON.stringify(error.response.data) : error.message}`);
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 主流程
|
||||||
|
async function main() {
|
||||||
|
// 1. 生成授权 URL,前端引导用户访问授权页
|
||||||
|
const authUrl = getAuthUrl();
|
||||||
|
console.log(`请访问此 URL 授权:${authUrl}
|
||||||
|
`);
|
||||||
|
|
||||||
|
// 2. 用户授权后,从回调 URL 获取 code 参数
|
||||||
|
const code = await getCode();
|
||||||
|
|
||||||
|
try {
|
||||||
|
// 3. 使用 code 参数获取访问令牌
|
||||||
|
const tokenData = await getAccessToken(code);
|
||||||
|
const accessToken = tokenData.access_token;
|
||||||
|
|
||||||
|
// 4. 使用访问令牌获取用户信息
|
||||||
|
if (accessToken) {
|
||||||
|
const userInfo = await getUserInfo(accessToken);
|
||||||
|
console.log(`
|
||||||
|
获取用户信息成功:${JSON.stringify(userInfo, null, 2)}`);
|
||||||
|
} else {
|
||||||
|
console.log(`
|
||||||
|
获取访问令牌失败:${JSON.stringify(tokenData)}`);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error('发生错误:', error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
Python
|
||||||
|
```python
|
||||||
|
# 安装第三方请求库,本例中使用 requests
|
||||||
|
# pip install requests
|
||||||
|
|
||||||
|
# 通过 OAuth2 获取 Linux Do 用户信息的参考流程
|
||||||
|
import requests
|
||||||
|
import json
|
||||||
|
|
||||||
|
# 配置信息(建议通过环境变量配置,避免使用硬编码)
|
||||||
|
CLIENT_ID = '你的 Client ID'
|
||||||
|
CLIENT_SECRET = '你的 Client Secret'
|
||||||
|
REDIRECT_URI = '你的回调地址'
|
||||||
|
AUTH_URL = 'https://connect.linux.do/oauth2/authorize'
|
||||||
|
TOKEN_URL = 'https://connect.linux.do/oauth2/token'
|
||||||
|
USER_INFO_URL = 'https://connect.linux.do/api/user'
|
||||||
|
|
||||||
|
# 第一步:生成授权 URL
|
||||||
|
def get_auth_url():
|
||||||
|
params = {
|
||||||
|
'client_id': CLIENT_ID,
|
||||||
|
'redirect_uri': REDIRECT_URI,
|
||||||
|
'response_type': 'code',
|
||||||
|
'scope': 'user'
|
||||||
|
}
|
||||||
|
auth_url = f"{AUTH_URL}?{'&'.join(f'{k}={v}' for k, v in params.items())}"
|
||||||
|
return auth_url
|
||||||
|
|
||||||
|
# 第二步:获取 code 参数
|
||||||
|
def get_code():
|
||||||
|
# 本例中使用终端输入来模拟流程,仅供本地测试
|
||||||
|
# 请在实际应用中替换为真实的处理逻辑
|
||||||
|
return input('从回调 URL 中提取出 code,粘贴到此处并按回车:').strip()
|
||||||
|
|
||||||
|
# 第三步:使用 code 参数获取访问令牌
|
||||||
|
def get_access_token(code):
|
||||||
|
try:
|
||||||
|
data = {
|
||||||
|
'client_id': CLIENT_ID,
|
||||||
|
'client_secret': CLIENT_SECRET,
|
||||||
|
'code': code,
|
||||||
|
'redirect_uri': REDIRECT_URI,
|
||||||
|
'grant_type': 'authorization_code'
|
||||||
|
}
|
||||||
|
# 提醒:需正确配置请求头,否则无法正常获取访问令牌
|
||||||
|
headers = {
|
||||||
|
'Content-Type': 'application/x-www-form-urlencoded',
|
||||||
|
'Accept': 'application/json'
|
||||||
|
}
|
||||||
|
response = requests.post(TOKEN_URL, data=data, headers=headers)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
print(f"获取访问令牌失败:{e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 第四步:使用访问令牌获取用户信息
|
||||||
|
def get_user_info(access_token):
|
||||||
|
try:
|
||||||
|
headers = {
|
||||||
|
'Authorization': f'Bearer {access_token}'
|
||||||
|
}
|
||||||
|
response = requests.get(USER_INFO_URL, headers=headers)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
print(f"获取用户信息失败:{e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 主流程
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# 1. 生成授权 URL,前端引导用户访问授权页
|
||||||
|
auth_url = get_auth_url()
|
||||||
|
print(f'请访问此 URL 授权:{auth_url}
|
||||||
|
')
|
||||||
|
|
||||||
|
# 2. 用户授权后,从回调 URL 获取 code 参数
|
||||||
|
code = get_code()
|
||||||
|
|
||||||
|
# 3. 使用 code 参数获取访问令牌
|
||||||
|
token_data = get_access_token(code)
|
||||||
|
if token_data:
|
||||||
|
access_token = token_data.get('access_token')
|
||||||
|
|
||||||
|
# 4. 使用访问令牌获取用户信息
|
||||||
|
if access_token:
|
||||||
|
user_info = get_user_info(access_token)
|
||||||
|
if user_info:
|
||||||
|
print(f"
|
||||||
|
获取用户信息成功:{json.dumps(user_info, indent=2)}")
|
||||||
|
else:
|
||||||
|
print("
|
||||||
|
获取用户信息失败")
|
||||||
|
else:
|
||||||
|
print(f"
|
||||||
|
获取访问令牌失败:{json.dumps(token_data, indent=2)}")
|
||||||
|
else:
|
||||||
|
print("
|
||||||
|
获取访问令牌失败")
|
||||||
|
```
|
||||||
|
PHP
|
||||||
|
```php
|
||||||
|
// 通过 OAuth2 获取 Linux Do 用户信息的参考流程
|
||||||
|
|
||||||
|
// 配置信息
|
||||||
|
$CLIENT_ID = '你的 Client ID';
|
||||||
|
$CLIENT_SECRET = '你的 Client Secret';
|
||||||
|
$REDIRECT_URI = '你的回调地址';
|
||||||
|
$AUTH_URL = 'https://connect.linux.do/oauth2/authorize';
|
||||||
|
$TOKEN_URL = 'https://connect.linux.do/oauth2/token';
|
||||||
|
$USER_INFO_URL = 'https://connect.linux.do/api/user';
|
||||||
|
|
||||||
|
// 生成授权 URL
|
||||||
|
function getAuthUrl($clientId, $redirectUri) {
|
||||||
|
global $AUTH_URL;
|
||||||
|
return $AUTH_URL . '?' . http_build_query([
|
||||||
|
'client_id' => $clientId,
|
||||||
|
'redirect_uri' => $redirectUri,
|
||||||
|
'response_type' => 'code',
|
||||||
|
'scope' => 'user'
|
||||||
|
]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用 code 参数获取用户信息(合并获取令牌和获取用户信息的步骤)
|
||||||
|
function getUserInfoWithCode($code, $clientId, $clientSecret, $redirectUri) {
|
||||||
|
global $TOKEN_URL, $USER_INFO_URL;
|
||||||
|
|
||||||
|
// 1. 获取访问令牌
|
||||||
|
$ch = curl_init($TOKEN_URL);
|
||||||
|
curl_setopt($ch, CURLOPT_RETURNTRANSFER, true);
|
||||||
|
curl_setopt($ch, CURLOPT_POST, true);
|
||||||
|
curl_setopt($ch, CURLOPT_POSTFIELDS, http_build_query([
|
||||||
|
'client_id' => $clientId,
|
||||||
|
'client_secret' => $clientSecret,
|
||||||
|
'code' => $code,
|
||||||
|
'redirect_uri' => $redirectUri,
|
||||||
|
'grant_type' => 'authorization_code'
|
||||||
|
]));
|
||||||
|
curl_setopt($ch, CURLOPT_HTTPHEADER, [
|
||||||
|
'Content-Type: application/x-www-form-urlencoded',
|
||||||
|
'Accept: application/json'
|
||||||
|
]);
|
||||||
|
|
||||||
|
$tokenResponse = curl_exec($ch);
|
||||||
|
curl_close($ch);
|
||||||
|
|
||||||
|
$tokenData = json_decode($tokenResponse, true);
|
||||||
|
if (!isset($tokenData['access_token'])) {
|
||||||
|
return ['error' => '获取访问令牌失败', 'details' => $tokenData];
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 获取用户信息
|
||||||
|
$ch = curl_init($USER_INFO_URL);
|
||||||
|
curl_setopt($ch, CURLOPT_RETURNTRANSFER, true);
|
||||||
|
curl_setopt($ch, CURLOPT_HTTPHEADER, [
|
||||||
|
'Authorization: Bearer ' . $tokenData['access_token']
|
||||||
|
]);
|
||||||
|
|
||||||
|
$userResponse = curl_exec($ch);
|
||||||
|
curl_close($ch);
|
||||||
|
|
||||||
|
return json_decode($userResponse, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 主流程
|
||||||
|
// 1. 生成授权 URL
|
||||||
|
$authUrl = getAuthUrl($CLIENT_ID, $REDIRECT_URI);
|
||||||
|
echo "<a href='$authUrl'>使用 Linux Do 登录</a>";
|
||||||
|
|
||||||
|
// 2. 处理回调并获取用户信息
|
||||||
|
if (isset($_GET['code'])) {
|
||||||
|
$userInfo = getUserInfoWithCode(
|
||||||
|
$_GET['code'],
|
||||||
|
$CLIENT_ID,
|
||||||
|
$CLIENT_SECRET,
|
||||||
|
$REDIRECT_URI
|
||||||
|
);
|
||||||
|
|
||||||
|
if (isset($userInfo['error'])) {
|
||||||
|
echo '错误: ' . $userInfo['error'];
|
||||||
|
} else {
|
||||||
|
echo '欢迎, ' . $userInfo['name'] . '!';
|
||||||
|
// 处理用户登录逻辑...
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 使用说明
|
||||||
|
|
||||||
|
### 授权流程
|
||||||
|
|
||||||
|
1. 用户点击应用中的’使用 Linux Do 登录’按钮
|
||||||
|
2. 系统将用户重定向至 Linux Do 的授权页面
|
||||||
|
3. 用户完成授权后,系统自动重定向回应用并携带授权码
|
||||||
|
4. 应用使用授权码获取访问令牌
|
||||||
|
5. 使用访问令牌获取用户信息
|
||||||
|
|
||||||
|
### 安全建议
|
||||||
|
|
||||||
|
- 切勿在前端代码中暴露 Client Secret
|
||||||
|
- 对所有用户输入数据进行严格验证
|
||||||
|
- 确保使用 HTTPS 协议传输数据
|
||||||
|
- 定期更新并妥善保管 Client Secret
|
||||||
@@ -1 +1 @@
|
|||||||
0.1.1
|
0.1.46
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
emailQueueService := service.ProvideEmailQueueService(emailService)
|
emailQueueService := service.ProvideEmailQueueService(emailService)
|
||||||
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
|
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
|
||||||
userService := service.NewUserService(userRepository)
|
userService := service.NewUserService(userRepository)
|
||||||
authHandler := handler.NewAuthHandler(configConfig, authService, userService)
|
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService)
|
||||||
userHandler := handler.NewUserHandler(userService)
|
userHandler := handler.NewUserHandler(userService)
|
||||||
apiKeyRepository := repository.NewAPIKeyRepository(client)
|
apiKeyRepository := repository.NewAPIKeyRepository(client)
|
||||||
groupRepository := repository.NewGroupRepository(client, db)
|
groupRepository := repository.NewGroupRepository(client, db)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
package ent
|
package ent
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -35,6 +36,10 @@ type APIKey struct {
|
|||||||
GroupID *int64 `json:"group_id,omitempty"`
|
GroupID *int64 `json:"group_id,omitempty"`
|
||||||
// Status holds the value of the "status" field.
|
// Status holds the value of the "status" field.
|
||||||
Status string `json:"status,omitempty"`
|
Status string `json:"status,omitempty"`
|
||||||
|
// Allowed IPs/CIDRs, e.g. ["192.168.1.100", "10.0.0.0/8"]
|
||||||
|
IPWhitelist []string `json:"ip_whitelist,omitempty"`
|
||||||
|
// Blocked IPs/CIDRs
|
||||||
|
IPBlacklist []string `json:"ip_blacklist,omitempty"`
|
||||||
// Edges holds the relations/edges for other nodes in the graph.
|
// Edges holds the relations/edges for other nodes in the graph.
|
||||||
// The values are being populated by the APIKeyQuery when eager-loading is set.
|
// The values are being populated by the APIKeyQuery when eager-loading is set.
|
||||||
Edges APIKeyEdges `json:"edges"`
|
Edges APIKeyEdges `json:"edges"`
|
||||||
@@ -90,6 +95,8 @@ func (*APIKey) scanValues(columns []string) ([]any, error) {
|
|||||||
values := make([]any, len(columns))
|
values := make([]any, len(columns))
|
||||||
for i := range columns {
|
for i := range columns {
|
||||||
switch columns[i] {
|
switch columns[i] {
|
||||||
|
case apikey.FieldIPWhitelist, apikey.FieldIPBlacklist:
|
||||||
|
values[i] = new([]byte)
|
||||||
case apikey.FieldID, apikey.FieldUserID, apikey.FieldGroupID:
|
case apikey.FieldID, apikey.FieldUserID, apikey.FieldGroupID:
|
||||||
values[i] = new(sql.NullInt64)
|
values[i] = new(sql.NullInt64)
|
||||||
case apikey.FieldKey, apikey.FieldName, apikey.FieldStatus:
|
case apikey.FieldKey, apikey.FieldName, apikey.FieldStatus:
|
||||||
@@ -167,6 +174,22 @@ func (_m *APIKey) assignValues(columns []string, values []any) error {
|
|||||||
} else if value.Valid {
|
} else if value.Valid {
|
||||||
_m.Status = value.String
|
_m.Status = value.String
|
||||||
}
|
}
|
||||||
|
case apikey.FieldIPWhitelist:
|
||||||
|
if value, ok := values[i].(*[]byte); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field ip_whitelist", values[i])
|
||||||
|
} else if value != nil && len(*value) > 0 {
|
||||||
|
if err := json.Unmarshal(*value, &_m.IPWhitelist); err != nil {
|
||||||
|
return fmt.Errorf("unmarshal field ip_whitelist: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case apikey.FieldIPBlacklist:
|
||||||
|
if value, ok := values[i].(*[]byte); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field ip_blacklist", values[i])
|
||||||
|
} else if value != nil && len(*value) > 0 {
|
||||||
|
if err := json.Unmarshal(*value, &_m.IPBlacklist); err != nil {
|
||||||
|
return fmt.Errorf("unmarshal field ip_blacklist: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
_m.selectValues.Set(columns[i], values[i])
|
_m.selectValues.Set(columns[i], values[i])
|
||||||
}
|
}
|
||||||
@@ -245,6 +268,12 @@ func (_m *APIKey) String() string {
|
|||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
builder.WriteString("status=")
|
builder.WriteString("status=")
|
||||||
builder.WriteString(_m.Status)
|
builder.WriteString(_m.Status)
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("ip_whitelist=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", _m.IPWhitelist))
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("ip_blacklist=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", _m.IPBlacklist))
|
||||||
builder.WriteByte(')')
|
builder.WriteByte(')')
|
||||||
return builder.String()
|
return builder.String()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,6 +31,10 @@ const (
|
|||||||
FieldGroupID = "group_id"
|
FieldGroupID = "group_id"
|
||||||
// FieldStatus holds the string denoting the status field in the database.
|
// FieldStatus holds the string denoting the status field in the database.
|
||||||
FieldStatus = "status"
|
FieldStatus = "status"
|
||||||
|
// FieldIPWhitelist holds the string denoting the ip_whitelist field in the database.
|
||||||
|
FieldIPWhitelist = "ip_whitelist"
|
||||||
|
// FieldIPBlacklist holds the string denoting the ip_blacklist field in the database.
|
||||||
|
FieldIPBlacklist = "ip_blacklist"
|
||||||
// EdgeUser holds the string denoting the user edge name in mutations.
|
// EdgeUser holds the string denoting the user edge name in mutations.
|
||||||
EdgeUser = "user"
|
EdgeUser = "user"
|
||||||
// EdgeGroup holds the string denoting the group edge name in mutations.
|
// EdgeGroup holds the string denoting the group edge name in mutations.
|
||||||
@@ -73,6 +77,8 @@ var Columns = []string{
|
|||||||
FieldName,
|
FieldName,
|
||||||
FieldGroupID,
|
FieldGroupID,
|
||||||
FieldStatus,
|
FieldStatus,
|
||||||
|
FieldIPWhitelist,
|
||||||
|
FieldIPBlacklist,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidColumn reports if the column name is valid (part of the table columns).
|
// ValidColumn reports if the column name is valid (part of the table columns).
|
||||||
|
|||||||
@@ -470,6 +470,26 @@ func StatusContainsFold(v string) predicate.APIKey {
|
|||||||
return predicate.APIKey(sql.FieldContainsFold(FieldStatus, v))
|
return predicate.APIKey(sql.FieldContainsFold(FieldStatus, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IPWhitelistIsNil applies the IsNil predicate on the "ip_whitelist" field.
|
||||||
|
func IPWhitelistIsNil() predicate.APIKey {
|
||||||
|
return predicate.APIKey(sql.FieldIsNull(FieldIPWhitelist))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPWhitelistNotNil applies the NotNil predicate on the "ip_whitelist" field.
|
||||||
|
func IPWhitelistNotNil() predicate.APIKey {
|
||||||
|
return predicate.APIKey(sql.FieldNotNull(FieldIPWhitelist))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPBlacklistIsNil applies the IsNil predicate on the "ip_blacklist" field.
|
||||||
|
func IPBlacklistIsNil() predicate.APIKey {
|
||||||
|
return predicate.APIKey(sql.FieldIsNull(FieldIPBlacklist))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPBlacklistNotNil applies the NotNil predicate on the "ip_blacklist" field.
|
||||||
|
func IPBlacklistNotNil() predicate.APIKey {
|
||||||
|
return predicate.APIKey(sql.FieldNotNull(FieldIPBlacklist))
|
||||||
|
}
|
||||||
|
|
||||||
// HasUser applies the HasEdge predicate on the "user" edge.
|
// HasUser applies the HasEdge predicate on the "user" edge.
|
||||||
func HasUser() predicate.APIKey {
|
func HasUser() predicate.APIKey {
|
||||||
return predicate.APIKey(func(s *sql.Selector) {
|
return predicate.APIKey(func(s *sql.Selector) {
|
||||||
|
|||||||
@@ -113,6 +113,18 @@ func (_c *APIKeyCreate) SetNillableStatus(v *string) *APIKeyCreate {
|
|||||||
return _c
|
return _c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetIPWhitelist sets the "ip_whitelist" field.
|
||||||
|
func (_c *APIKeyCreate) SetIPWhitelist(v []string) *APIKeyCreate {
|
||||||
|
_c.mutation.SetIPWhitelist(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetIPBlacklist sets the "ip_blacklist" field.
|
||||||
|
func (_c *APIKeyCreate) SetIPBlacklist(v []string) *APIKeyCreate {
|
||||||
|
_c.mutation.SetIPBlacklist(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
// SetUser sets the "user" edge to the User entity.
|
// SetUser sets the "user" edge to the User entity.
|
||||||
func (_c *APIKeyCreate) SetUser(v *User) *APIKeyCreate {
|
func (_c *APIKeyCreate) SetUser(v *User) *APIKeyCreate {
|
||||||
return _c.SetUserID(v.ID)
|
return _c.SetUserID(v.ID)
|
||||||
@@ -285,6 +297,14 @@ func (_c *APIKeyCreate) createSpec() (*APIKey, *sqlgraph.CreateSpec) {
|
|||||||
_spec.SetField(apikey.FieldStatus, field.TypeString, value)
|
_spec.SetField(apikey.FieldStatus, field.TypeString, value)
|
||||||
_node.Status = value
|
_node.Status = value
|
||||||
}
|
}
|
||||||
|
if value, ok := _c.mutation.IPWhitelist(); ok {
|
||||||
|
_spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value)
|
||||||
|
_node.IPWhitelist = value
|
||||||
|
}
|
||||||
|
if value, ok := _c.mutation.IPBlacklist(); ok {
|
||||||
|
_spec.SetField(apikey.FieldIPBlacklist, field.TypeJSON, value)
|
||||||
|
_node.IPBlacklist = value
|
||||||
|
}
|
||||||
if nodes := _c.mutation.UserIDs(); len(nodes) > 0 {
|
if nodes := _c.mutation.UserIDs(); len(nodes) > 0 {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.M2O,
|
Rel: sqlgraph.M2O,
|
||||||
@@ -483,6 +503,42 @@ func (u *APIKeyUpsert) UpdateStatus() *APIKeyUpsert {
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetIPWhitelist sets the "ip_whitelist" field.
|
||||||
|
func (u *APIKeyUpsert) SetIPWhitelist(v []string) *APIKeyUpsert {
|
||||||
|
u.Set(apikey.FieldIPWhitelist, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateIPWhitelist sets the "ip_whitelist" field to the value that was provided on create.
|
||||||
|
func (u *APIKeyUpsert) UpdateIPWhitelist() *APIKeyUpsert {
|
||||||
|
u.SetExcluded(apikey.FieldIPWhitelist)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearIPWhitelist clears the value of the "ip_whitelist" field.
|
||||||
|
func (u *APIKeyUpsert) ClearIPWhitelist() *APIKeyUpsert {
|
||||||
|
u.SetNull(apikey.FieldIPWhitelist)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetIPBlacklist sets the "ip_blacklist" field.
|
||||||
|
func (u *APIKeyUpsert) SetIPBlacklist(v []string) *APIKeyUpsert {
|
||||||
|
u.Set(apikey.FieldIPBlacklist, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateIPBlacklist sets the "ip_blacklist" field to the value that was provided on create.
|
||||||
|
func (u *APIKeyUpsert) UpdateIPBlacklist() *APIKeyUpsert {
|
||||||
|
u.SetExcluded(apikey.FieldIPBlacklist)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearIPBlacklist clears the value of the "ip_blacklist" field.
|
||||||
|
func (u *APIKeyUpsert) ClearIPBlacklist() *APIKeyUpsert {
|
||||||
|
u.SetNull(apikey.FieldIPBlacklist)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||||
// Using this option is equivalent to using:
|
// Using this option is equivalent to using:
|
||||||
//
|
//
|
||||||
@@ -640,6 +696,48 @@ func (u *APIKeyUpsertOne) UpdateStatus() *APIKeyUpsertOne {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetIPWhitelist sets the "ip_whitelist" field.
|
||||||
|
func (u *APIKeyUpsertOne) SetIPWhitelist(v []string) *APIKeyUpsertOne {
|
||||||
|
return u.Update(func(s *APIKeyUpsert) {
|
||||||
|
s.SetIPWhitelist(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateIPWhitelist sets the "ip_whitelist" field to the value that was provided on create.
|
||||||
|
func (u *APIKeyUpsertOne) UpdateIPWhitelist() *APIKeyUpsertOne {
|
||||||
|
return u.Update(func(s *APIKeyUpsert) {
|
||||||
|
s.UpdateIPWhitelist()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearIPWhitelist clears the value of the "ip_whitelist" field.
|
||||||
|
func (u *APIKeyUpsertOne) ClearIPWhitelist() *APIKeyUpsertOne {
|
||||||
|
return u.Update(func(s *APIKeyUpsert) {
|
||||||
|
s.ClearIPWhitelist()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetIPBlacklist sets the "ip_blacklist" field.
|
||||||
|
func (u *APIKeyUpsertOne) SetIPBlacklist(v []string) *APIKeyUpsertOne {
|
||||||
|
return u.Update(func(s *APIKeyUpsert) {
|
||||||
|
s.SetIPBlacklist(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateIPBlacklist sets the "ip_blacklist" field to the value that was provided on create.
|
||||||
|
func (u *APIKeyUpsertOne) UpdateIPBlacklist() *APIKeyUpsertOne {
|
||||||
|
return u.Update(func(s *APIKeyUpsert) {
|
||||||
|
s.UpdateIPBlacklist()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearIPBlacklist clears the value of the "ip_blacklist" field.
|
||||||
|
func (u *APIKeyUpsertOne) ClearIPBlacklist() *APIKeyUpsertOne {
|
||||||
|
return u.Update(func(s *APIKeyUpsert) {
|
||||||
|
s.ClearIPBlacklist()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Exec executes the query.
|
// Exec executes the query.
|
||||||
func (u *APIKeyUpsertOne) Exec(ctx context.Context) error {
|
func (u *APIKeyUpsertOne) Exec(ctx context.Context) error {
|
||||||
if len(u.create.conflict) == 0 {
|
if len(u.create.conflict) == 0 {
|
||||||
@@ -963,6 +1061,48 @@ func (u *APIKeyUpsertBulk) UpdateStatus() *APIKeyUpsertBulk {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetIPWhitelist sets the "ip_whitelist" field.
|
||||||
|
func (u *APIKeyUpsertBulk) SetIPWhitelist(v []string) *APIKeyUpsertBulk {
|
||||||
|
return u.Update(func(s *APIKeyUpsert) {
|
||||||
|
s.SetIPWhitelist(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateIPWhitelist sets the "ip_whitelist" field to the value that was provided on create.
|
||||||
|
func (u *APIKeyUpsertBulk) UpdateIPWhitelist() *APIKeyUpsertBulk {
|
||||||
|
return u.Update(func(s *APIKeyUpsert) {
|
||||||
|
s.UpdateIPWhitelist()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearIPWhitelist clears the value of the "ip_whitelist" field.
|
||||||
|
func (u *APIKeyUpsertBulk) ClearIPWhitelist() *APIKeyUpsertBulk {
|
||||||
|
return u.Update(func(s *APIKeyUpsert) {
|
||||||
|
s.ClearIPWhitelist()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetIPBlacklist sets the "ip_blacklist" field.
|
||||||
|
func (u *APIKeyUpsertBulk) SetIPBlacklist(v []string) *APIKeyUpsertBulk {
|
||||||
|
return u.Update(func(s *APIKeyUpsert) {
|
||||||
|
s.SetIPBlacklist(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateIPBlacklist sets the "ip_blacklist" field to the value that was provided on create.
|
||||||
|
func (u *APIKeyUpsertBulk) UpdateIPBlacklist() *APIKeyUpsertBulk {
|
||||||
|
return u.Update(func(s *APIKeyUpsert) {
|
||||||
|
s.UpdateIPBlacklist()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearIPBlacklist clears the value of the "ip_blacklist" field.
|
||||||
|
func (u *APIKeyUpsertBulk) ClearIPBlacklist() *APIKeyUpsertBulk {
|
||||||
|
return u.Update(func(s *APIKeyUpsert) {
|
||||||
|
s.ClearIPBlacklist()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Exec executes the query.
|
// Exec executes the query.
|
||||||
func (u *APIKeyUpsertBulk) Exec(ctx context.Context) error {
|
func (u *APIKeyUpsertBulk) Exec(ctx context.Context) error {
|
||||||
if u.create.err != nil {
|
if u.create.err != nil {
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
|
|
||||||
"entgo.io/ent/dialect/sql"
|
"entgo.io/ent/dialect/sql"
|
||||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||||
|
"entgo.io/ent/dialect/sql/sqljson"
|
||||||
"entgo.io/ent/schema/field"
|
"entgo.io/ent/schema/field"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||||
@@ -133,6 +134,42 @@ func (_u *APIKeyUpdate) SetNillableStatus(v *string) *APIKeyUpdate {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetIPWhitelist sets the "ip_whitelist" field.
|
||||||
|
func (_u *APIKeyUpdate) SetIPWhitelist(v []string) *APIKeyUpdate {
|
||||||
|
_u.mutation.SetIPWhitelist(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AppendIPWhitelist appends value to the "ip_whitelist" field.
|
||||||
|
func (_u *APIKeyUpdate) AppendIPWhitelist(v []string) *APIKeyUpdate {
|
||||||
|
_u.mutation.AppendIPWhitelist(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearIPWhitelist clears the value of the "ip_whitelist" field.
|
||||||
|
func (_u *APIKeyUpdate) ClearIPWhitelist() *APIKeyUpdate {
|
||||||
|
_u.mutation.ClearIPWhitelist()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetIPBlacklist sets the "ip_blacklist" field.
|
||||||
|
func (_u *APIKeyUpdate) SetIPBlacklist(v []string) *APIKeyUpdate {
|
||||||
|
_u.mutation.SetIPBlacklist(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AppendIPBlacklist appends value to the "ip_blacklist" field.
|
||||||
|
func (_u *APIKeyUpdate) AppendIPBlacklist(v []string) *APIKeyUpdate {
|
||||||
|
_u.mutation.AppendIPBlacklist(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearIPBlacklist clears the value of the "ip_blacklist" field.
|
||||||
|
func (_u *APIKeyUpdate) ClearIPBlacklist() *APIKeyUpdate {
|
||||||
|
_u.mutation.ClearIPBlacklist()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// SetUser sets the "user" edge to the User entity.
|
// SetUser sets the "user" edge to the User entity.
|
||||||
func (_u *APIKeyUpdate) SetUser(v *User) *APIKeyUpdate {
|
func (_u *APIKeyUpdate) SetUser(v *User) *APIKeyUpdate {
|
||||||
return _u.SetUserID(v.ID)
|
return _u.SetUserID(v.ID)
|
||||||
@@ -291,6 +328,28 @@ func (_u *APIKeyUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
|||||||
if value, ok := _u.mutation.Status(); ok {
|
if value, ok := _u.mutation.Status(); ok {
|
||||||
_spec.SetField(apikey.FieldStatus, field.TypeString, value)
|
_spec.SetField(apikey.FieldStatus, field.TypeString, value)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.IPWhitelist(); ok {
|
||||||
|
_spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AppendedIPWhitelist(); ok {
|
||||||
|
_spec.AddModifier(func(u *sql.UpdateBuilder) {
|
||||||
|
sqljson.Append(u, apikey.FieldIPWhitelist, value)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if _u.mutation.IPWhitelistCleared() {
|
||||||
|
_spec.ClearField(apikey.FieldIPWhitelist, field.TypeJSON)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.IPBlacklist(); ok {
|
||||||
|
_spec.SetField(apikey.FieldIPBlacklist, field.TypeJSON, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AppendedIPBlacklist(); ok {
|
||||||
|
_spec.AddModifier(func(u *sql.UpdateBuilder) {
|
||||||
|
sqljson.Append(u, apikey.FieldIPBlacklist, value)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if _u.mutation.IPBlacklistCleared() {
|
||||||
|
_spec.ClearField(apikey.FieldIPBlacklist, field.TypeJSON)
|
||||||
|
}
|
||||||
if _u.mutation.UserCleared() {
|
if _u.mutation.UserCleared() {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.M2O,
|
Rel: sqlgraph.M2O,
|
||||||
@@ -516,6 +575,42 @@ func (_u *APIKeyUpdateOne) SetNillableStatus(v *string) *APIKeyUpdateOne {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetIPWhitelist sets the "ip_whitelist" field.
|
||||||
|
func (_u *APIKeyUpdateOne) SetIPWhitelist(v []string) *APIKeyUpdateOne {
|
||||||
|
_u.mutation.SetIPWhitelist(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AppendIPWhitelist appends value to the "ip_whitelist" field.
|
||||||
|
func (_u *APIKeyUpdateOne) AppendIPWhitelist(v []string) *APIKeyUpdateOne {
|
||||||
|
_u.mutation.AppendIPWhitelist(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearIPWhitelist clears the value of the "ip_whitelist" field.
|
||||||
|
func (_u *APIKeyUpdateOne) ClearIPWhitelist() *APIKeyUpdateOne {
|
||||||
|
_u.mutation.ClearIPWhitelist()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetIPBlacklist sets the "ip_blacklist" field.
|
||||||
|
func (_u *APIKeyUpdateOne) SetIPBlacklist(v []string) *APIKeyUpdateOne {
|
||||||
|
_u.mutation.SetIPBlacklist(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AppendIPBlacklist appends value to the "ip_blacklist" field.
|
||||||
|
func (_u *APIKeyUpdateOne) AppendIPBlacklist(v []string) *APIKeyUpdateOne {
|
||||||
|
_u.mutation.AppendIPBlacklist(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearIPBlacklist clears the value of the "ip_blacklist" field.
|
||||||
|
func (_u *APIKeyUpdateOne) ClearIPBlacklist() *APIKeyUpdateOne {
|
||||||
|
_u.mutation.ClearIPBlacklist()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// SetUser sets the "user" edge to the User entity.
|
// SetUser sets the "user" edge to the User entity.
|
||||||
func (_u *APIKeyUpdateOne) SetUser(v *User) *APIKeyUpdateOne {
|
func (_u *APIKeyUpdateOne) SetUser(v *User) *APIKeyUpdateOne {
|
||||||
return _u.SetUserID(v.ID)
|
return _u.SetUserID(v.ID)
|
||||||
@@ -704,6 +799,28 @@ func (_u *APIKeyUpdateOne) sqlSave(ctx context.Context) (_node *APIKey, err erro
|
|||||||
if value, ok := _u.mutation.Status(); ok {
|
if value, ok := _u.mutation.Status(); ok {
|
||||||
_spec.SetField(apikey.FieldStatus, field.TypeString, value)
|
_spec.SetField(apikey.FieldStatus, field.TypeString, value)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.IPWhitelist(); ok {
|
||||||
|
_spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AppendedIPWhitelist(); ok {
|
||||||
|
_spec.AddModifier(func(u *sql.UpdateBuilder) {
|
||||||
|
sqljson.Append(u, apikey.FieldIPWhitelist, value)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if _u.mutation.IPWhitelistCleared() {
|
||||||
|
_spec.ClearField(apikey.FieldIPWhitelist, field.TypeJSON)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.IPBlacklist(); ok {
|
||||||
|
_spec.SetField(apikey.FieldIPBlacklist, field.TypeJSON, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AppendedIPBlacklist(); ok {
|
||||||
|
_spec.AddModifier(func(u *sql.UpdateBuilder) {
|
||||||
|
sqljson.Append(u, apikey.FieldIPBlacklist, value)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if _u.mutation.IPBlacklistCleared() {
|
||||||
|
_spec.ClearField(apikey.FieldIPBlacklist, field.TypeJSON)
|
||||||
|
}
|
||||||
if _u.mutation.UserCleared() {
|
if _u.mutation.UserCleared() {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.M2O,
|
Rel: sqlgraph.M2O,
|
||||||
|
|||||||
@@ -18,6 +18,8 @@ var (
|
|||||||
{Name: "key", Type: field.TypeString, Unique: true, Size: 128},
|
{Name: "key", Type: field.TypeString, Unique: true, Size: 128},
|
||||||
{Name: "name", Type: field.TypeString, Size: 100},
|
{Name: "name", Type: field.TypeString, Size: 100},
|
||||||
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
|
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
|
||||||
|
{Name: "ip_whitelist", Type: field.TypeJSON, Nullable: true},
|
||||||
|
{Name: "ip_blacklist", Type: field.TypeJSON, Nullable: true},
|
||||||
{Name: "group_id", Type: field.TypeInt64, Nullable: true},
|
{Name: "group_id", Type: field.TypeInt64, Nullable: true},
|
||||||
{Name: "user_id", Type: field.TypeInt64},
|
{Name: "user_id", Type: field.TypeInt64},
|
||||||
}
|
}
|
||||||
@@ -29,13 +31,13 @@ var (
|
|||||||
ForeignKeys: []*schema.ForeignKey{
|
ForeignKeys: []*schema.ForeignKey{
|
||||||
{
|
{
|
||||||
Symbol: "api_keys_groups_api_keys",
|
Symbol: "api_keys_groups_api_keys",
|
||||||
Columns: []*schema.Column{APIKeysColumns[7]},
|
Columns: []*schema.Column{APIKeysColumns[9]},
|
||||||
RefColumns: []*schema.Column{GroupsColumns[0]},
|
RefColumns: []*schema.Column{GroupsColumns[0]},
|
||||||
OnDelete: schema.SetNull,
|
OnDelete: schema.SetNull,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "api_keys_users_api_keys",
|
Symbol: "api_keys_users_api_keys",
|
||||||
Columns: []*schema.Column{APIKeysColumns[8]},
|
Columns: []*schema.Column{APIKeysColumns[10]},
|
||||||
RefColumns: []*schema.Column{UsersColumns[0]},
|
RefColumns: []*schema.Column{UsersColumns[0]},
|
||||||
OnDelete: schema.NoAction,
|
OnDelete: schema.NoAction,
|
||||||
},
|
},
|
||||||
@@ -44,12 +46,12 @@ var (
|
|||||||
{
|
{
|
||||||
Name: "apikey_user_id",
|
Name: "apikey_user_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{APIKeysColumns[8]},
|
Columns: []*schema.Column{APIKeysColumns[10]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "apikey_group_id",
|
Name: "apikey_group_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{APIKeysColumns[7]},
|
Columns: []*schema.Column{APIKeysColumns[9]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "apikey_status",
|
Name: "apikey_status",
|
||||||
@@ -376,6 +378,7 @@ var (
|
|||||||
{Name: "duration_ms", Type: field.TypeInt, Nullable: true},
|
{Name: "duration_ms", Type: field.TypeInt, Nullable: true},
|
||||||
{Name: "first_token_ms", Type: field.TypeInt, Nullable: true},
|
{Name: "first_token_ms", Type: field.TypeInt, Nullable: true},
|
||||||
{Name: "user_agent", Type: field.TypeString, Nullable: true, Size: 512},
|
{Name: "user_agent", Type: field.TypeString, Nullable: true, Size: 512},
|
||||||
|
{Name: "ip_address", Type: field.TypeString, Nullable: true, Size: 45},
|
||||||
{Name: "image_count", Type: field.TypeInt, Default: 0},
|
{Name: "image_count", Type: field.TypeInt, Default: 0},
|
||||||
{Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10},
|
{Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10},
|
||||||
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||||
@@ -393,31 +396,31 @@ var (
|
|||||||
ForeignKeys: []*schema.ForeignKey{
|
ForeignKeys: []*schema.ForeignKey{
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_api_keys_usage_logs",
|
Symbol: "usage_logs_api_keys_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[24]},
|
Columns: []*schema.Column{UsageLogsColumns[25]},
|
||||||
RefColumns: []*schema.Column{APIKeysColumns[0]},
|
RefColumns: []*schema.Column{APIKeysColumns[0]},
|
||||||
OnDelete: schema.NoAction,
|
OnDelete: schema.NoAction,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_accounts_usage_logs",
|
Symbol: "usage_logs_accounts_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[25]},
|
Columns: []*schema.Column{UsageLogsColumns[26]},
|
||||||
RefColumns: []*schema.Column{AccountsColumns[0]},
|
RefColumns: []*schema.Column{AccountsColumns[0]},
|
||||||
OnDelete: schema.NoAction,
|
OnDelete: schema.NoAction,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_groups_usage_logs",
|
Symbol: "usage_logs_groups_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[26]},
|
Columns: []*schema.Column{UsageLogsColumns[27]},
|
||||||
RefColumns: []*schema.Column{GroupsColumns[0]},
|
RefColumns: []*schema.Column{GroupsColumns[0]},
|
||||||
OnDelete: schema.SetNull,
|
OnDelete: schema.SetNull,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_users_usage_logs",
|
Symbol: "usage_logs_users_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[27]},
|
Columns: []*schema.Column{UsageLogsColumns[28]},
|
||||||
RefColumns: []*schema.Column{UsersColumns[0]},
|
RefColumns: []*schema.Column{UsersColumns[0]},
|
||||||
OnDelete: schema.NoAction,
|
OnDelete: schema.NoAction,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_user_subscriptions_usage_logs",
|
Symbol: "usage_logs_user_subscriptions_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||||
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
|
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
|
||||||
OnDelete: schema.SetNull,
|
OnDelete: schema.SetNull,
|
||||||
},
|
},
|
||||||
@@ -426,32 +429,32 @@ var (
|
|||||||
{
|
{
|
||||||
Name: "usagelog_user_id",
|
Name: "usagelog_user_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[27]},
|
Columns: []*schema.Column{UsageLogsColumns[28]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_api_key_id",
|
Name: "usagelog_api_key_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[24]},
|
Columns: []*schema.Column{UsageLogsColumns[25]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_account_id",
|
Name: "usagelog_account_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[25]},
|
Columns: []*schema.Column{UsageLogsColumns[26]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_group_id",
|
Name: "usagelog_group_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[26]},
|
Columns: []*schema.Column{UsageLogsColumns[27]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_subscription_id",
|
Name: "usagelog_subscription_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_created_at",
|
Name: "usagelog_created_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[23]},
|
Columns: []*schema.Column{UsageLogsColumns[24]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_model",
|
Name: "usagelog_model",
|
||||||
@@ -466,12 +469,12 @@ var (
|
|||||||
{
|
{
|
||||||
Name: "usagelog_user_id_created_at",
|
Name: "usagelog_user_id_created_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[27], UsageLogsColumns[23]},
|
Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[24]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_api_key_id_created_at",
|
Name: "usagelog_api_key_id_created_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[24], UsageLogsColumns[23]},
|
Columns: []*schema.Column{UsageLogsColumns[25], UsageLogsColumns[24]},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -54,26 +54,30 @@ const (
|
|||||||
// APIKeyMutation represents an operation that mutates the APIKey nodes in the graph.
|
// APIKeyMutation represents an operation that mutates the APIKey nodes in the graph.
|
||||||
type APIKeyMutation struct {
|
type APIKeyMutation struct {
|
||||||
config
|
config
|
||||||
op Op
|
op Op
|
||||||
typ string
|
typ string
|
||||||
id *int64
|
id *int64
|
||||||
created_at *time.Time
|
created_at *time.Time
|
||||||
updated_at *time.Time
|
updated_at *time.Time
|
||||||
deleted_at *time.Time
|
deleted_at *time.Time
|
||||||
key *string
|
key *string
|
||||||
name *string
|
name *string
|
||||||
status *string
|
status *string
|
||||||
clearedFields map[string]struct{}
|
ip_whitelist *[]string
|
||||||
user *int64
|
appendip_whitelist []string
|
||||||
cleareduser bool
|
ip_blacklist *[]string
|
||||||
group *int64
|
appendip_blacklist []string
|
||||||
clearedgroup bool
|
clearedFields map[string]struct{}
|
||||||
usage_logs map[int64]struct{}
|
user *int64
|
||||||
removedusage_logs map[int64]struct{}
|
cleareduser bool
|
||||||
clearedusage_logs bool
|
group *int64
|
||||||
done bool
|
clearedgroup bool
|
||||||
oldValue func(context.Context) (*APIKey, error)
|
usage_logs map[int64]struct{}
|
||||||
predicates []predicate.APIKey
|
removedusage_logs map[int64]struct{}
|
||||||
|
clearedusage_logs bool
|
||||||
|
done bool
|
||||||
|
oldValue func(context.Context) (*APIKey, error)
|
||||||
|
predicates []predicate.APIKey
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ ent.Mutation = (*APIKeyMutation)(nil)
|
var _ ent.Mutation = (*APIKeyMutation)(nil)
|
||||||
@@ -488,6 +492,136 @@ func (m *APIKeyMutation) ResetStatus() {
|
|||||||
m.status = nil
|
m.status = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetIPWhitelist sets the "ip_whitelist" field.
|
||||||
|
func (m *APIKeyMutation) SetIPWhitelist(s []string) {
|
||||||
|
m.ip_whitelist = &s
|
||||||
|
m.appendip_whitelist = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPWhitelist returns the value of the "ip_whitelist" field in the mutation.
|
||||||
|
func (m *APIKeyMutation) IPWhitelist() (r []string, exists bool) {
|
||||||
|
v := m.ip_whitelist
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldIPWhitelist returns the old "ip_whitelist" field's value of the APIKey entity.
|
||||||
|
// If the APIKey object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *APIKeyMutation) OldIPWhitelist(ctx context.Context) (v []string, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldIPWhitelist is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldIPWhitelist requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldIPWhitelist: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.IPWhitelist, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AppendIPWhitelist adds s to the "ip_whitelist" field.
|
||||||
|
func (m *APIKeyMutation) AppendIPWhitelist(s []string) {
|
||||||
|
m.appendip_whitelist = append(m.appendip_whitelist, s...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AppendedIPWhitelist returns the list of values that were appended to the "ip_whitelist" field in this mutation.
|
||||||
|
func (m *APIKeyMutation) AppendedIPWhitelist() ([]string, bool) {
|
||||||
|
if len(m.appendip_whitelist) == 0 {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return m.appendip_whitelist, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearIPWhitelist clears the value of the "ip_whitelist" field.
|
||||||
|
func (m *APIKeyMutation) ClearIPWhitelist() {
|
||||||
|
m.ip_whitelist = nil
|
||||||
|
m.appendip_whitelist = nil
|
||||||
|
m.clearedFields[apikey.FieldIPWhitelist] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPWhitelistCleared returns if the "ip_whitelist" field was cleared in this mutation.
|
||||||
|
func (m *APIKeyMutation) IPWhitelistCleared() bool {
|
||||||
|
_, ok := m.clearedFields[apikey.FieldIPWhitelist]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetIPWhitelist resets all changes to the "ip_whitelist" field.
|
||||||
|
func (m *APIKeyMutation) ResetIPWhitelist() {
|
||||||
|
m.ip_whitelist = nil
|
||||||
|
m.appendip_whitelist = nil
|
||||||
|
delete(m.clearedFields, apikey.FieldIPWhitelist)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetIPBlacklist sets the "ip_blacklist" field.
|
||||||
|
func (m *APIKeyMutation) SetIPBlacklist(s []string) {
|
||||||
|
m.ip_blacklist = &s
|
||||||
|
m.appendip_blacklist = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPBlacklist returns the value of the "ip_blacklist" field in the mutation.
|
||||||
|
func (m *APIKeyMutation) IPBlacklist() (r []string, exists bool) {
|
||||||
|
v := m.ip_blacklist
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldIPBlacklist returns the old "ip_blacklist" field's value of the APIKey entity.
|
||||||
|
// If the APIKey object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *APIKeyMutation) OldIPBlacklist(ctx context.Context) (v []string, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldIPBlacklist is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldIPBlacklist requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldIPBlacklist: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.IPBlacklist, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AppendIPBlacklist adds s to the "ip_blacklist" field.
|
||||||
|
func (m *APIKeyMutation) AppendIPBlacklist(s []string) {
|
||||||
|
m.appendip_blacklist = append(m.appendip_blacklist, s...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AppendedIPBlacklist returns the list of values that were appended to the "ip_blacklist" field in this mutation.
|
||||||
|
func (m *APIKeyMutation) AppendedIPBlacklist() ([]string, bool) {
|
||||||
|
if len(m.appendip_blacklist) == 0 {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return m.appendip_blacklist, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearIPBlacklist clears the value of the "ip_blacklist" field.
|
||||||
|
func (m *APIKeyMutation) ClearIPBlacklist() {
|
||||||
|
m.ip_blacklist = nil
|
||||||
|
m.appendip_blacklist = nil
|
||||||
|
m.clearedFields[apikey.FieldIPBlacklist] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPBlacklistCleared returns if the "ip_blacklist" field was cleared in this mutation.
|
||||||
|
func (m *APIKeyMutation) IPBlacklistCleared() bool {
|
||||||
|
_, ok := m.clearedFields[apikey.FieldIPBlacklist]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetIPBlacklist resets all changes to the "ip_blacklist" field.
|
||||||
|
func (m *APIKeyMutation) ResetIPBlacklist() {
|
||||||
|
m.ip_blacklist = nil
|
||||||
|
m.appendip_blacklist = nil
|
||||||
|
delete(m.clearedFields, apikey.FieldIPBlacklist)
|
||||||
|
}
|
||||||
|
|
||||||
// ClearUser clears the "user" edge to the User entity.
|
// ClearUser clears the "user" edge to the User entity.
|
||||||
func (m *APIKeyMutation) ClearUser() {
|
func (m *APIKeyMutation) ClearUser() {
|
||||||
m.cleareduser = true
|
m.cleareduser = true
|
||||||
@@ -630,7 +764,7 @@ func (m *APIKeyMutation) Type() string {
|
|||||||
// order to get all numeric fields that were incremented/decremented, call
|
// order to get all numeric fields that were incremented/decremented, call
|
||||||
// AddedFields().
|
// AddedFields().
|
||||||
func (m *APIKeyMutation) Fields() []string {
|
func (m *APIKeyMutation) Fields() []string {
|
||||||
fields := make([]string, 0, 8)
|
fields := make([]string, 0, 10)
|
||||||
if m.created_at != nil {
|
if m.created_at != nil {
|
||||||
fields = append(fields, apikey.FieldCreatedAt)
|
fields = append(fields, apikey.FieldCreatedAt)
|
||||||
}
|
}
|
||||||
@@ -655,6 +789,12 @@ func (m *APIKeyMutation) Fields() []string {
|
|||||||
if m.status != nil {
|
if m.status != nil {
|
||||||
fields = append(fields, apikey.FieldStatus)
|
fields = append(fields, apikey.FieldStatus)
|
||||||
}
|
}
|
||||||
|
if m.ip_whitelist != nil {
|
||||||
|
fields = append(fields, apikey.FieldIPWhitelist)
|
||||||
|
}
|
||||||
|
if m.ip_blacklist != nil {
|
||||||
|
fields = append(fields, apikey.FieldIPBlacklist)
|
||||||
|
}
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -679,6 +819,10 @@ func (m *APIKeyMutation) Field(name string) (ent.Value, bool) {
|
|||||||
return m.GroupID()
|
return m.GroupID()
|
||||||
case apikey.FieldStatus:
|
case apikey.FieldStatus:
|
||||||
return m.Status()
|
return m.Status()
|
||||||
|
case apikey.FieldIPWhitelist:
|
||||||
|
return m.IPWhitelist()
|
||||||
|
case apikey.FieldIPBlacklist:
|
||||||
|
return m.IPBlacklist()
|
||||||
}
|
}
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@@ -704,6 +848,10 @@ func (m *APIKeyMutation) OldField(ctx context.Context, name string) (ent.Value,
|
|||||||
return m.OldGroupID(ctx)
|
return m.OldGroupID(ctx)
|
||||||
case apikey.FieldStatus:
|
case apikey.FieldStatus:
|
||||||
return m.OldStatus(ctx)
|
return m.OldStatus(ctx)
|
||||||
|
case apikey.FieldIPWhitelist:
|
||||||
|
return m.OldIPWhitelist(ctx)
|
||||||
|
case apikey.FieldIPBlacklist:
|
||||||
|
return m.OldIPBlacklist(ctx)
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("unknown APIKey field %s", name)
|
return nil, fmt.Errorf("unknown APIKey field %s", name)
|
||||||
}
|
}
|
||||||
@@ -769,6 +917,20 @@ func (m *APIKeyMutation) SetField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.SetStatus(v)
|
m.SetStatus(v)
|
||||||
return nil
|
return nil
|
||||||
|
case apikey.FieldIPWhitelist:
|
||||||
|
v, ok := value.([]string)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.SetIPWhitelist(v)
|
||||||
|
return nil
|
||||||
|
case apikey.FieldIPBlacklist:
|
||||||
|
v, ok := value.([]string)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.SetIPBlacklist(v)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown APIKey field %s", name)
|
return fmt.Errorf("unknown APIKey field %s", name)
|
||||||
}
|
}
|
||||||
@@ -808,6 +970,12 @@ func (m *APIKeyMutation) ClearedFields() []string {
|
|||||||
if m.FieldCleared(apikey.FieldGroupID) {
|
if m.FieldCleared(apikey.FieldGroupID) {
|
||||||
fields = append(fields, apikey.FieldGroupID)
|
fields = append(fields, apikey.FieldGroupID)
|
||||||
}
|
}
|
||||||
|
if m.FieldCleared(apikey.FieldIPWhitelist) {
|
||||||
|
fields = append(fields, apikey.FieldIPWhitelist)
|
||||||
|
}
|
||||||
|
if m.FieldCleared(apikey.FieldIPBlacklist) {
|
||||||
|
fields = append(fields, apikey.FieldIPBlacklist)
|
||||||
|
}
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -828,6 +996,12 @@ func (m *APIKeyMutation) ClearField(name string) error {
|
|||||||
case apikey.FieldGroupID:
|
case apikey.FieldGroupID:
|
||||||
m.ClearGroupID()
|
m.ClearGroupID()
|
||||||
return nil
|
return nil
|
||||||
|
case apikey.FieldIPWhitelist:
|
||||||
|
m.ClearIPWhitelist()
|
||||||
|
return nil
|
||||||
|
case apikey.FieldIPBlacklist:
|
||||||
|
m.ClearIPBlacklist()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown APIKey nullable field %s", name)
|
return fmt.Errorf("unknown APIKey nullable field %s", name)
|
||||||
}
|
}
|
||||||
@@ -860,6 +1034,12 @@ func (m *APIKeyMutation) ResetField(name string) error {
|
|||||||
case apikey.FieldStatus:
|
case apikey.FieldStatus:
|
||||||
m.ResetStatus()
|
m.ResetStatus()
|
||||||
return nil
|
return nil
|
||||||
|
case apikey.FieldIPWhitelist:
|
||||||
|
m.ResetIPWhitelist()
|
||||||
|
return nil
|
||||||
|
case apikey.FieldIPBlacklist:
|
||||||
|
m.ResetIPBlacklist()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown APIKey field %s", name)
|
return fmt.Errorf("unknown APIKey field %s", name)
|
||||||
}
|
}
|
||||||
@@ -8396,6 +8576,7 @@ type UsageLogMutation struct {
|
|||||||
first_token_ms *int
|
first_token_ms *int
|
||||||
addfirst_token_ms *int
|
addfirst_token_ms *int
|
||||||
user_agent *string
|
user_agent *string
|
||||||
|
ip_address *string
|
||||||
image_count *int
|
image_count *int
|
||||||
addimage_count *int
|
addimage_count *int
|
||||||
image_size *string
|
image_size *string
|
||||||
@@ -9801,6 +9982,55 @@ func (m *UsageLogMutation) ResetUserAgent() {
|
|||||||
delete(m.clearedFields, usagelog.FieldUserAgent)
|
delete(m.clearedFields, usagelog.FieldUserAgent)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetIPAddress sets the "ip_address" field.
|
||||||
|
func (m *UsageLogMutation) SetIPAddress(s string) {
|
||||||
|
m.ip_address = &s
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPAddress returns the value of the "ip_address" field in the mutation.
|
||||||
|
func (m *UsageLogMutation) IPAddress() (r string, exists bool) {
|
||||||
|
v := m.ip_address
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldIPAddress returns the old "ip_address" field's value of the UsageLog entity.
|
||||||
|
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *UsageLogMutation) OldIPAddress(ctx context.Context) (v *string, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldIPAddress is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldIPAddress requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldIPAddress: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.IPAddress, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearIPAddress clears the value of the "ip_address" field.
|
||||||
|
func (m *UsageLogMutation) ClearIPAddress() {
|
||||||
|
m.ip_address = nil
|
||||||
|
m.clearedFields[usagelog.FieldIPAddress] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPAddressCleared returns if the "ip_address" field was cleared in this mutation.
|
||||||
|
func (m *UsageLogMutation) IPAddressCleared() bool {
|
||||||
|
_, ok := m.clearedFields[usagelog.FieldIPAddress]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetIPAddress resets all changes to the "ip_address" field.
|
||||||
|
func (m *UsageLogMutation) ResetIPAddress() {
|
||||||
|
m.ip_address = nil
|
||||||
|
delete(m.clearedFields, usagelog.FieldIPAddress)
|
||||||
|
}
|
||||||
|
|
||||||
// SetImageCount sets the "image_count" field.
|
// SetImageCount sets the "image_count" field.
|
||||||
func (m *UsageLogMutation) SetImageCount(i int) {
|
func (m *UsageLogMutation) SetImageCount(i int) {
|
||||||
m.image_count = &i
|
m.image_count = &i
|
||||||
@@ -10111,7 +10341,7 @@ func (m *UsageLogMutation) Type() string {
|
|||||||
// order to get all numeric fields that were incremented/decremented, call
|
// order to get all numeric fields that were incremented/decremented, call
|
||||||
// AddedFields().
|
// AddedFields().
|
||||||
func (m *UsageLogMutation) Fields() []string {
|
func (m *UsageLogMutation) Fields() []string {
|
||||||
fields := make([]string, 0, 28)
|
fields := make([]string, 0, 29)
|
||||||
if m.user != nil {
|
if m.user != nil {
|
||||||
fields = append(fields, usagelog.FieldUserID)
|
fields = append(fields, usagelog.FieldUserID)
|
||||||
}
|
}
|
||||||
@@ -10187,6 +10417,9 @@ func (m *UsageLogMutation) Fields() []string {
|
|||||||
if m.user_agent != nil {
|
if m.user_agent != nil {
|
||||||
fields = append(fields, usagelog.FieldUserAgent)
|
fields = append(fields, usagelog.FieldUserAgent)
|
||||||
}
|
}
|
||||||
|
if m.ip_address != nil {
|
||||||
|
fields = append(fields, usagelog.FieldIPAddress)
|
||||||
|
}
|
||||||
if m.image_count != nil {
|
if m.image_count != nil {
|
||||||
fields = append(fields, usagelog.FieldImageCount)
|
fields = append(fields, usagelog.FieldImageCount)
|
||||||
}
|
}
|
||||||
@@ -10254,6 +10487,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
|
|||||||
return m.FirstTokenMs()
|
return m.FirstTokenMs()
|
||||||
case usagelog.FieldUserAgent:
|
case usagelog.FieldUserAgent:
|
||||||
return m.UserAgent()
|
return m.UserAgent()
|
||||||
|
case usagelog.FieldIPAddress:
|
||||||
|
return m.IPAddress()
|
||||||
case usagelog.FieldImageCount:
|
case usagelog.FieldImageCount:
|
||||||
return m.ImageCount()
|
return m.ImageCount()
|
||||||
case usagelog.FieldImageSize:
|
case usagelog.FieldImageSize:
|
||||||
@@ -10319,6 +10554,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
|
|||||||
return m.OldFirstTokenMs(ctx)
|
return m.OldFirstTokenMs(ctx)
|
||||||
case usagelog.FieldUserAgent:
|
case usagelog.FieldUserAgent:
|
||||||
return m.OldUserAgent(ctx)
|
return m.OldUserAgent(ctx)
|
||||||
|
case usagelog.FieldIPAddress:
|
||||||
|
return m.OldIPAddress(ctx)
|
||||||
case usagelog.FieldImageCount:
|
case usagelog.FieldImageCount:
|
||||||
return m.OldImageCount(ctx)
|
return m.OldImageCount(ctx)
|
||||||
case usagelog.FieldImageSize:
|
case usagelog.FieldImageSize:
|
||||||
@@ -10509,6 +10746,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.SetUserAgent(v)
|
m.SetUserAgent(v)
|
||||||
return nil
|
return nil
|
||||||
|
case usagelog.FieldIPAddress:
|
||||||
|
v, ok := value.(string)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.SetIPAddress(v)
|
||||||
|
return nil
|
||||||
case usagelog.FieldImageCount:
|
case usagelog.FieldImageCount:
|
||||||
v, ok := value.(int)
|
v, ok := value.(int)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -10782,6 +11026,9 @@ func (m *UsageLogMutation) ClearedFields() []string {
|
|||||||
if m.FieldCleared(usagelog.FieldUserAgent) {
|
if m.FieldCleared(usagelog.FieldUserAgent) {
|
||||||
fields = append(fields, usagelog.FieldUserAgent)
|
fields = append(fields, usagelog.FieldUserAgent)
|
||||||
}
|
}
|
||||||
|
if m.FieldCleared(usagelog.FieldIPAddress) {
|
||||||
|
fields = append(fields, usagelog.FieldIPAddress)
|
||||||
|
}
|
||||||
if m.FieldCleared(usagelog.FieldImageSize) {
|
if m.FieldCleared(usagelog.FieldImageSize) {
|
||||||
fields = append(fields, usagelog.FieldImageSize)
|
fields = append(fields, usagelog.FieldImageSize)
|
||||||
}
|
}
|
||||||
@@ -10814,6 +11061,9 @@ func (m *UsageLogMutation) ClearField(name string) error {
|
|||||||
case usagelog.FieldUserAgent:
|
case usagelog.FieldUserAgent:
|
||||||
m.ClearUserAgent()
|
m.ClearUserAgent()
|
||||||
return nil
|
return nil
|
||||||
|
case usagelog.FieldIPAddress:
|
||||||
|
m.ClearIPAddress()
|
||||||
|
return nil
|
||||||
case usagelog.FieldImageSize:
|
case usagelog.FieldImageSize:
|
||||||
m.ClearImageSize()
|
m.ClearImageSize()
|
||||||
return nil
|
return nil
|
||||||
@@ -10900,6 +11150,9 @@ func (m *UsageLogMutation) ResetField(name string) error {
|
|||||||
case usagelog.FieldUserAgent:
|
case usagelog.FieldUserAgent:
|
||||||
m.ResetUserAgent()
|
m.ResetUserAgent()
|
||||||
return nil
|
return nil
|
||||||
|
case usagelog.FieldIPAddress:
|
||||||
|
m.ResetIPAddress()
|
||||||
|
return nil
|
||||||
case usagelog.FieldImageCount:
|
case usagelog.FieldImageCount:
|
||||||
m.ResetImageCount()
|
m.ResetImageCount()
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -533,16 +533,20 @@ func init() {
|
|||||||
usagelogDescUserAgent := usagelogFields[24].Descriptor()
|
usagelogDescUserAgent := usagelogFields[24].Descriptor()
|
||||||
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
|
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
|
||||||
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
|
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
|
||||||
|
// usagelogDescIPAddress is the schema descriptor for ip_address field.
|
||||||
|
usagelogDescIPAddress := usagelogFields[25].Descriptor()
|
||||||
|
// usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
|
||||||
|
usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error)
|
||||||
// usagelogDescImageCount is the schema descriptor for image_count field.
|
// usagelogDescImageCount is the schema descriptor for image_count field.
|
||||||
usagelogDescImageCount := usagelogFields[25].Descriptor()
|
usagelogDescImageCount := usagelogFields[26].Descriptor()
|
||||||
// usagelog.DefaultImageCount holds the default value on creation for the image_count field.
|
// usagelog.DefaultImageCount holds the default value on creation for the image_count field.
|
||||||
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
|
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
|
||||||
// usagelogDescImageSize is the schema descriptor for image_size field.
|
// usagelogDescImageSize is the schema descriptor for image_size field.
|
||||||
usagelogDescImageSize := usagelogFields[26].Descriptor()
|
usagelogDescImageSize := usagelogFields[27].Descriptor()
|
||||||
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
|
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
|
||||||
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
|
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
|
||||||
// usagelogDescCreatedAt is the schema descriptor for created_at field.
|
// usagelogDescCreatedAt is the schema descriptor for created_at field.
|
||||||
usagelogDescCreatedAt := usagelogFields[27].Descriptor()
|
usagelogDescCreatedAt := usagelogFields[28].Descriptor()
|
||||||
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
|
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||||
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
|
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
|
||||||
userMixin := schema.User{}.Mixin()
|
userMixin := schema.User{}.Mixin()
|
||||||
|
|||||||
@@ -46,6 +46,12 @@ func (APIKey) Fields() []ent.Field {
|
|||||||
field.String("status").
|
field.String("status").
|
||||||
MaxLen(20).
|
MaxLen(20).
|
||||||
Default(service.StatusActive),
|
Default(service.StatusActive),
|
||||||
|
field.JSON("ip_whitelist", []string{}).
|
||||||
|
Optional().
|
||||||
|
Comment("Allowed IPs/CIDRs, e.g. [\"192.168.1.100\", \"10.0.0.0/8\"]"),
|
||||||
|
field.JSON("ip_blacklist", []string{}).
|
||||||
|
Optional().
|
||||||
|
Comment("Blocked IPs/CIDRs"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -100,6 +100,10 @@ func (UsageLog) Fields() []ent.Field {
|
|||||||
MaxLen(512).
|
MaxLen(512).
|
||||||
Optional().
|
Optional().
|
||||||
Nillable(),
|
Nillable(),
|
||||||
|
field.String("ip_address").
|
||||||
|
MaxLen(45). // 支持 IPv6
|
||||||
|
Optional().
|
||||||
|
Nillable(),
|
||||||
|
|
||||||
// 图片生成字段(仅 gemini-3-pro-image 等图片模型使用)
|
// 图片生成字段(仅 gemini-3-pro-image 等图片模型使用)
|
||||||
field.Int("image_count").
|
field.Int("image_count").
|
||||||
|
|||||||
@@ -72,6 +72,8 @@ type UsageLog struct {
|
|||||||
FirstTokenMs *int `json:"first_token_ms,omitempty"`
|
FirstTokenMs *int `json:"first_token_ms,omitempty"`
|
||||||
// UserAgent holds the value of the "user_agent" field.
|
// UserAgent holds the value of the "user_agent" field.
|
||||||
UserAgent *string `json:"user_agent,omitempty"`
|
UserAgent *string `json:"user_agent,omitempty"`
|
||||||
|
// IPAddress holds the value of the "ip_address" field.
|
||||||
|
IPAddress *string `json:"ip_address,omitempty"`
|
||||||
// ImageCount holds the value of the "image_count" field.
|
// ImageCount holds the value of the "image_count" field.
|
||||||
ImageCount int `json:"image_count,omitempty"`
|
ImageCount int `json:"image_count,omitempty"`
|
||||||
// ImageSize holds the value of the "image_size" field.
|
// ImageSize holds the value of the "image_size" field.
|
||||||
@@ -167,7 +169,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
|
|||||||
values[i] = new(sql.NullFloat64)
|
values[i] = new(sql.NullFloat64)
|
||||||
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
|
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
|
||||||
values[i] = new(sql.NullInt64)
|
values[i] = new(sql.NullInt64)
|
||||||
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldImageSize:
|
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize:
|
||||||
values[i] = new(sql.NullString)
|
values[i] = new(sql.NullString)
|
||||||
case usagelog.FieldCreatedAt:
|
case usagelog.FieldCreatedAt:
|
||||||
values[i] = new(sql.NullTime)
|
values[i] = new(sql.NullTime)
|
||||||
@@ -347,6 +349,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
|
|||||||
_m.UserAgent = new(string)
|
_m.UserAgent = new(string)
|
||||||
*_m.UserAgent = value.String
|
*_m.UserAgent = value.String
|
||||||
}
|
}
|
||||||
|
case usagelog.FieldIPAddress:
|
||||||
|
if value, ok := values[i].(*sql.NullString); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field ip_address", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.IPAddress = new(string)
|
||||||
|
*_m.IPAddress = value.String
|
||||||
|
}
|
||||||
case usagelog.FieldImageCount:
|
case usagelog.FieldImageCount:
|
||||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||||
return fmt.Errorf("unexpected type %T for field image_count", values[i])
|
return fmt.Errorf("unexpected type %T for field image_count", values[i])
|
||||||
@@ -512,6 +521,11 @@ func (_m *UsageLog) String() string {
|
|||||||
builder.WriteString(*v)
|
builder.WriteString(*v)
|
||||||
}
|
}
|
||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
|
if v := _m.IPAddress; v != nil {
|
||||||
|
builder.WriteString("ip_address=")
|
||||||
|
builder.WriteString(*v)
|
||||||
|
}
|
||||||
|
builder.WriteString(", ")
|
||||||
builder.WriteString("image_count=")
|
builder.WriteString("image_count=")
|
||||||
builder.WriteString(fmt.Sprintf("%v", _m.ImageCount))
|
builder.WriteString(fmt.Sprintf("%v", _m.ImageCount))
|
||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
|
|||||||
@@ -64,6 +64,8 @@ const (
|
|||||||
FieldFirstTokenMs = "first_token_ms"
|
FieldFirstTokenMs = "first_token_ms"
|
||||||
// FieldUserAgent holds the string denoting the user_agent field in the database.
|
// FieldUserAgent holds the string denoting the user_agent field in the database.
|
||||||
FieldUserAgent = "user_agent"
|
FieldUserAgent = "user_agent"
|
||||||
|
// FieldIPAddress holds the string denoting the ip_address field in the database.
|
||||||
|
FieldIPAddress = "ip_address"
|
||||||
// FieldImageCount holds the string denoting the image_count field in the database.
|
// FieldImageCount holds the string denoting the image_count field in the database.
|
||||||
FieldImageCount = "image_count"
|
FieldImageCount = "image_count"
|
||||||
// FieldImageSize holds the string denoting the image_size field in the database.
|
// FieldImageSize holds the string denoting the image_size field in the database.
|
||||||
@@ -147,6 +149,7 @@ var Columns = []string{
|
|||||||
FieldDurationMs,
|
FieldDurationMs,
|
||||||
FieldFirstTokenMs,
|
FieldFirstTokenMs,
|
||||||
FieldUserAgent,
|
FieldUserAgent,
|
||||||
|
FieldIPAddress,
|
||||||
FieldImageCount,
|
FieldImageCount,
|
||||||
FieldImageSize,
|
FieldImageSize,
|
||||||
FieldCreatedAt,
|
FieldCreatedAt,
|
||||||
@@ -199,6 +202,8 @@ var (
|
|||||||
DefaultStream bool
|
DefaultStream bool
|
||||||
// UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
|
// UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
|
||||||
UserAgentValidator func(string) error
|
UserAgentValidator func(string) error
|
||||||
|
// IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
|
||||||
|
IPAddressValidator func(string) error
|
||||||
// DefaultImageCount holds the default value on creation for the "image_count" field.
|
// DefaultImageCount holds the default value on creation for the "image_count" field.
|
||||||
DefaultImageCount int
|
DefaultImageCount int
|
||||||
// ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
|
// ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
|
||||||
@@ -340,6 +345,11 @@ func ByUserAgent(opts ...sql.OrderTermOption) OrderOption {
|
|||||||
return sql.OrderByField(FieldUserAgent, opts...).ToFunc()
|
return sql.OrderByField(FieldUserAgent, opts...).ToFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ByIPAddress orders the results by the ip_address field.
|
||||||
|
func ByIPAddress(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldIPAddress, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
// ByImageCount orders the results by the image_count field.
|
// ByImageCount orders the results by the image_count field.
|
||||||
func ByImageCount(opts ...sql.OrderTermOption) OrderOption {
|
func ByImageCount(opts ...sql.OrderTermOption) OrderOption {
|
||||||
return sql.OrderByField(FieldImageCount, opts...).ToFunc()
|
return sql.OrderByField(FieldImageCount, opts...).ToFunc()
|
||||||
|
|||||||
@@ -180,6 +180,11 @@ func UserAgent(v string) predicate.UsageLog {
|
|||||||
return predicate.UsageLog(sql.FieldEQ(FieldUserAgent, v))
|
return predicate.UsageLog(sql.FieldEQ(FieldUserAgent, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IPAddress applies equality check predicate on the "ip_address" field. It's identical to IPAddressEQ.
|
||||||
|
func IPAddress(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldEQ(FieldIPAddress, v))
|
||||||
|
}
|
||||||
|
|
||||||
// ImageCount applies equality check predicate on the "image_count" field. It's identical to ImageCountEQ.
|
// ImageCount applies equality check predicate on the "image_count" field. It's identical to ImageCountEQ.
|
||||||
func ImageCount(v int) predicate.UsageLog {
|
func ImageCount(v int) predicate.UsageLog {
|
||||||
return predicate.UsageLog(sql.FieldEQ(FieldImageCount, v))
|
return predicate.UsageLog(sql.FieldEQ(FieldImageCount, v))
|
||||||
@@ -1190,6 +1195,81 @@ func UserAgentContainsFold(v string) predicate.UsageLog {
|
|||||||
return predicate.UsageLog(sql.FieldContainsFold(FieldUserAgent, v))
|
return predicate.UsageLog(sql.FieldContainsFold(FieldUserAgent, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IPAddressEQ applies the EQ predicate on the "ip_address" field.
|
||||||
|
func IPAddressEQ(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldEQ(FieldIPAddress, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPAddressNEQ applies the NEQ predicate on the "ip_address" field.
|
||||||
|
func IPAddressNEQ(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldNEQ(FieldIPAddress, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPAddressIn applies the In predicate on the "ip_address" field.
|
||||||
|
func IPAddressIn(vs ...string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldIn(FieldIPAddress, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPAddressNotIn applies the NotIn predicate on the "ip_address" field.
|
||||||
|
func IPAddressNotIn(vs ...string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldNotIn(FieldIPAddress, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPAddressGT applies the GT predicate on the "ip_address" field.
|
||||||
|
func IPAddressGT(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldGT(FieldIPAddress, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPAddressGTE applies the GTE predicate on the "ip_address" field.
|
||||||
|
func IPAddressGTE(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldGTE(FieldIPAddress, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPAddressLT applies the LT predicate on the "ip_address" field.
|
||||||
|
func IPAddressLT(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldLT(FieldIPAddress, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPAddressLTE applies the LTE predicate on the "ip_address" field.
|
||||||
|
func IPAddressLTE(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldLTE(FieldIPAddress, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPAddressContains applies the Contains predicate on the "ip_address" field.
|
||||||
|
func IPAddressContains(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldContains(FieldIPAddress, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPAddressHasPrefix applies the HasPrefix predicate on the "ip_address" field.
|
||||||
|
func IPAddressHasPrefix(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldHasPrefix(FieldIPAddress, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPAddressHasSuffix applies the HasSuffix predicate on the "ip_address" field.
|
||||||
|
func IPAddressHasSuffix(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldHasSuffix(FieldIPAddress, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPAddressIsNil applies the IsNil predicate on the "ip_address" field.
|
||||||
|
func IPAddressIsNil() predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldIsNull(FieldIPAddress))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPAddressNotNil applies the NotNil predicate on the "ip_address" field.
|
||||||
|
func IPAddressNotNil() predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldNotNull(FieldIPAddress))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPAddressEqualFold applies the EqualFold predicate on the "ip_address" field.
|
||||||
|
func IPAddressEqualFold(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldEqualFold(FieldIPAddress, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPAddressContainsFold applies the ContainsFold predicate on the "ip_address" field.
|
||||||
|
func IPAddressContainsFold(v string) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldContainsFold(FieldIPAddress, v))
|
||||||
|
}
|
||||||
|
|
||||||
// ImageCountEQ applies the EQ predicate on the "image_count" field.
|
// ImageCountEQ applies the EQ predicate on the "image_count" field.
|
||||||
func ImageCountEQ(v int) predicate.UsageLog {
|
func ImageCountEQ(v int) predicate.UsageLog {
|
||||||
return predicate.UsageLog(sql.FieldEQ(FieldImageCount, v))
|
return predicate.UsageLog(sql.FieldEQ(FieldImageCount, v))
|
||||||
|
|||||||
@@ -337,6 +337,20 @@ func (_c *UsageLogCreate) SetNillableUserAgent(v *string) *UsageLogCreate {
|
|||||||
return _c
|
return _c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetIPAddress sets the "ip_address" field.
|
||||||
|
func (_c *UsageLogCreate) SetIPAddress(v string) *UsageLogCreate {
|
||||||
|
_c.mutation.SetIPAddress(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableIPAddress sets the "ip_address" field if the given value is not nil.
|
||||||
|
func (_c *UsageLogCreate) SetNillableIPAddress(v *string) *UsageLogCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetIPAddress(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
// SetImageCount sets the "image_count" field.
|
// SetImageCount sets the "image_count" field.
|
||||||
func (_c *UsageLogCreate) SetImageCount(v int) *UsageLogCreate {
|
func (_c *UsageLogCreate) SetImageCount(v int) *UsageLogCreate {
|
||||||
_c.mutation.SetImageCount(v)
|
_c.mutation.SetImageCount(v)
|
||||||
@@ -586,6 +600,11 @@ func (_c *UsageLogCreate) check() error {
|
|||||||
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
|
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if v, ok := _c.mutation.IPAddress(); ok {
|
||||||
|
if err := usagelog.IPAddressValidator(v); err != nil {
|
||||||
|
return &ValidationError{Name: "ip_address", err: fmt.Errorf(`ent: validator failed for field "UsageLog.ip_address": %w`, err)}
|
||||||
|
}
|
||||||
|
}
|
||||||
if _, ok := _c.mutation.ImageCount(); !ok {
|
if _, ok := _c.mutation.ImageCount(); !ok {
|
||||||
return &ValidationError{Name: "image_count", err: errors.New(`ent: missing required field "UsageLog.image_count"`)}
|
return &ValidationError{Name: "image_count", err: errors.New(`ent: missing required field "UsageLog.image_count"`)}
|
||||||
}
|
}
|
||||||
@@ -713,6 +732,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
|
|||||||
_spec.SetField(usagelog.FieldUserAgent, field.TypeString, value)
|
_spec.SetField(usagelog.FieldUserAgent, field.TypeString, value)
|
||||||
_node.UserAgent = &value
|
_node.UserAgent = &value
|
||||||
}
|
}
|
||||||
|
if value, ok := _c.mutation.IPAddress(); ok {
|
||||||
|
_spec.SetField(usagelog.FieldIPAddress, field.TypeString, value)
|
||||||
|
_node.IPAddress = &value
|
||||||
|
}
|
||||||
if value, ok := _c.mutation.ImageCount(); ok {
|
if value, ok := _c.mutation.ImageCount(); ok {
|
||||||
_spec.SetField(usagelog.FieldImageCount, field.TypeInt, value)
|
_spec.SetField(usagelog.FieldImageCount, field.TypeInt, value)
|
||||||
_node.ImageCount = value
|
_node.ImageCount = value
|
||||||
@@ -1288,6 +1311,24 @@ func (u *UsageLogUpsert) ClearUserAgent() *UsageLogUpsert {
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetIPAddress sets the "ip_address" field.
|
||||||
|
func (u *UsageLogUpsert) SetIPAddress(v string) *UsageLogUpsert {
|
||||||
|
u.Set(usagelog.FieldIPAddress, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateIPAddress sets the "ip_address" field to the value that was provided on create.
|
||||||
|
func (u *UsageLogUpsert) UpdateIPAddress() *UsageLogUpsert {
|
||||||
|
u.SetExcluded(usagelog.FieldIPAddress)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearIPAddress clears the value of the "ip_address" field.
|
||||||
|
func (u *UsageLogUpsert) ClearIPAddress() *UsageLogUpsert {
|
||||||
|
u.SetNull(usagelog.FieldIPAddress)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
// SetImageCount sets the "image_count" field.
|
// SetImageCount sets the "image_count" field.
|
||||||
func (u *UsageLogUpsert) SetImageCount(v int) *UsageLogUpsert {
|
func (u *UsageLogUpsert) SetImageCount(v int) *UsageLogUpsert {
|
||||||
u.Set(usagelog.FieldImageCount, v)
|
u.Set(usagelog.FieldImageCount, v)
|
||||||
@@ -1866,6 +1907,27 @@ func (u *UsageLogUpsertOne) ClearUserAgent() *UsageLogUpsertOne {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetIPAddress sets the "ip_address" field.
|
||||||
|
func (u *UsageLogUpsertOne) SetIPAddress(v string) *UsageLogUpsertOne {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.SetIPAddress(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateIPAddress sets the "ip_address" field to the value that was provided on create.
|
||||||
|
func (u *UsageLogUpsertOne) UpdateIPAddress() *UsageLogUpsertOne {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.UpdateIPAddress()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearIPAddress clears the value of the "ip_address" field.
|
||||||
|
func (u *UsageLogUpsertOne) ClearIPAddress() *UsageLogUpsertOne {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.ClearIPAddress()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// SetImageCount sets the "image_count" field.
|
// SetImageCount sets the "image_count" field.
|
||||||
func (u *UsageLogUpsertOne) SetImageCount(v int) *UsageLogUpsertOne {
|
func (u *UsageLogUpsertOne) SetImageCount(v int) *UsageLogUpsertOne {
|
||||||
return u.Update(func(s *UsageLogUpsert) {
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
@@ -2616,6 +2678,27 @@ func (u *UsageLogUpsertBulk) ClearUserAgent() *UsageLogUpsertBulk {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetIPAddress sets the "ip_address" field.
|
||||||
|
func (u *UsageLogUpsertBulk) SetIPAddress(v string) *UsageLogUpsertBulk {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.SetIPAddress(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateIPAddress sets the "ip_address" field to the value that was provided on create.
|
||||||
|
func (u *UsageLogUpsertBulk) UpdateIPAddress() *UsageLogUpsertBulk {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.UpdateIPAddress()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearIPAddress clears the value of the "ip_address" field.
|
||||||
|
func (u *UsageLogUpsertBulk) ClearIPAddress() *UsageLogUpsertBulk {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.ClearIPAddress()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// SetImageCount sets the "image_count" field.
|
// SetImageCount sets the "image_count" field.
|
||||||
func (u *UsageLogUpsertBulk) SetImageCount(v int) *UsageLogUpsertBulk {
|
func (u *UsageLogUpsertBulk) SetImageCount(v int) *UsageLogUpsertBulk {
|
||||||
return u.Update(func(s *UsageLogUpsert) {
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
|||||||
@@ -524,6 +524,26 @@ func (_u *UsageLogUpdate) ClearUserAgent() *UsageLogUpdate {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetIPAddress sets the "ip_address" field.
|
||||||
|
func (_u *UsageLogUpdate) SetIPAddress(v string) *UsageLogUpdate {
|
||||||
|
_u.mutation.SetIPAddress(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableIPAddress sets the "ip_address" field if the given value is not nil.
|
||||||
|
func (_u *UsageLogUpdate) SetNillableIPAddress(v *string) *UsageLogUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetIPAddress(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearIPAddress clears the value of the "ip_address" field.
|
||||||
|
func (_u *UsageLogUpdate) ClearIPAddress() *UsageLogUpdate {
|
||||||
|
_u.mutation.ClearIPAddress()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// SetImageCount sets the "image_count" field.
|
// SetImageCount sets the "image_count" field.
|
||||||
func (_u *UsageLogUpdate) SetImageCount(v int) *UsageLogUpdate {
|
func (_u *UsageLogUpdate) SetImageCount(v int) *UsageLogUpdate {
|
||||||
_u.mutation.ResetImageCount()
|
_u.mutation.ResetImageCount()
|
||||||
@@ -669,6 +689,11 @@ func (_u *UsageLogUpdate) check() error {
|
|||||||
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
|
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if v, ok := _u.mutation.IPAddress(); ok {
|
||||||
|
if err := usagelog.IPAddressValidator(v); err != nil {
|
||||||
|
return &ValidationError{Name: "ip_address", err: fmt.Errorf(`ent: validator failed for field "UsageLog.ip_address": %w`, err)}
|
||||||
|
}
|
||||||
|
}
|
||||||
if v, ok := _u.mutation.ImageSize(); ok {
|
if v, ok := _u.mutation.ImageSize(); ok {
|
||||||
if err := usagelog.ImageSizeValidator(v); err != nil {
|
if err := usagelog.ImageSizeValidator(v); err != nil {
|
||||||
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
|
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
|
||||||
@@ -815,6 +840,12 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
|||||||
if _u.mutation.UserAgentCleared() {
|
if _u.mutation.UserAgentCleared() {
|
||||||
_spec.ClearField(usagelog.FieldUserAgent, field.TypeString)
|
_spec.ClearField(usagelog.FieldUserAgent, field.TypeString)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.IPAddress(); ok {
|
||||||
|
_spec.SetField(usagelog.FieldIPAddress, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.IPAddressCleared() {
|
||||||
|
_spec.ClearField(usagelog.FieldIPAddress, field.TypeString)
|
||||||
|
}
|
||||||
if value, ok := _u.mutation.ImageCount(); ok {
|
if value, ok := _u.mutation.ImageCount(); ok {
|
||||||
_spec.SetField(usagelog.FieldImageCount, field.TypeInt, value)
|
_spec.SetField(usagelog.FieldImageCount, field.TypeInt, value)
|
||||||
}
|
}
|
||||||
@@ -1484,6 +1515,26 @@ func (_u *UsageLogUpdateOne) ClearUserAgent() *UsageLogUpdateOne {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetIPAddress sets the "ip_address" field.
|
||||||
|
func (_u *UsageLogUpdateOne) SetIPAddress(v string) *UsageLogUpdateOne {
|
||||||
|
_u.mutation.SetIPAddress(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableIPAddress sets the "ip_address" field if the given value is not nil.
|
||||||
|
func (_u *UsageLogUpdateOne) SetNillableIPAddress(v *string) *UsageLogUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetIPAddress(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearIPAddress clears the value of the "ip_address" field.
|
||||||
|
func (_u *UsageLogUpdateOne) ClearIPAddress() *UsageLogUpdateOne {
|
||||||
|
_u.mutation.ClearIPAddress()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// SetImageCount sets the "image_count" field.
|
// SetImageCount sets the "image_count" field.
|
||||||
func (_u *UsageLogUpdateOne) SetImageCount(v int) *UsageLogUpdateOne {
|
func (_u *UsageLogUpdateOne) SetImageCount(v int) *UsageLogUpdateOne {
|
||||||
_u.mutation.ResetImageCount()
|
_u.mutation.ResetImageCount()
|
||||||
@@ -1642,6 +1693,11 @@ func (_u *UsageLogUpdateOne) check() error {
|
|||||||
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
|
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if v, ok := _u.mutation.IPAddress(); ok {
|
||||||
|
if err := usagelog.IPAddressValidator(v); err != nil {
|
||||||
|
return &ValidationError{Name: "ip_address", err: fmt.Errorf(`ent: validator failed for field "UsageLog.ip_address": %w`, err)}
|
||||||
|
}
|
||||||
|
}
|
||||||
if v, ok := _u.mutation.ImageSize(); ok {
|
if v, ok := _u.mutation.ImageSize(); ok {
|
||||||
if err := usagelog.ImageSizeValidator(v); err != nil {
|
if err := usagelog.ImageSizeValidator(v); err != nil {
|
||||||
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
|
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
|
||||||
@@ -1805,6 +1861,12 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
|
|||||||
if _u.mutation.UserAgentCleared() {
|
if _u.mutation.UserAgentCleared() {
|
||||||
_spec.ClearField(usagelog.FieldUserAgent, field.TypeString)
|
_spec.ClearField(usagelog.FieldUserAgent, field.TypeString)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.IPAddress(); ok {
|
||||||
|
_spec.SetField(usagelog.FieldIPAddress, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.IPAddressCleared() {
|
||||||
|
_spec.ClearField(usagelog.FieldIPAddress, field.TypeString)
|
||||||
|
}
|
||||||
if value, ok := _u.mutation.ImageCount(); ok {
|
if value, ok := _u.mutation.ImageCount(); ok {
|
||||||
_spec.SetField(usagelog.FieldImageCount, field.TypeInt, value)
|
_spec.SetField(usagelog.FieldImageCount, field.TypeInt, value)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -35,24 +36,25 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Server ServerConfig `mapstructure:"server"`
|
Server ServerConfig `mapstructure:"server"`
|
||||||
CORS CORSConfig `mapstructure:"cors"`
|
CORS CORSConfig `mapstructure:"cors"`
|
||||||
Security SecurityConfig `mapstructure:"security"`
|
Security SecurityConfig `mapstructure:"security"`
|
||||||
Billing BillingConfig `mapstructure:"billing"`
|
Billing BillingConfig `mapstructure:"billing"`
|
||||||
Turnstile TurnstileConfig `mapstructure:"turnstile"`
|
Turnstile TurnstileConfig `mapstructure:"turnstile"`
|
||||||
Database DatabaseConfig `mapstructure:"database"`
|
Database DatabaseConfig `mapstructure:"database"`
|
||||||
Redis RedisConfig `mapstructure:"redis"`
|
Redis RedisConfig `mapstructure:"redis"`
|
||||||
JWT JWTConfig `mapstructure:"jwt"`
|
JWT JWTConfig `mapstructure:"jwt"`
|
||||||
Default DefaultConfig `mapstructure:"default"`
|
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
|
||||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
Default DefaultConfig `mapstructure:"default"`
|
||||||
Pricing PricingConfig `mapstructure:"pricing"`
|
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||||
Gateway GatewayConfig `mapstructure:"gateway"`
|
Pricing PricingConfig `mapstructure:"pricing"`
|
||||||
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
|
Gateway GatewayConfig `mapstructure:"gateway"`
|
||||||
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
|
||||||
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
|
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
||||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
|
||||||
Gemini GeminiConfig `mapstructure:"gemini"`
|
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||||
Update UpdateConfig `mapstructure:"update"`
|
Gemini GeminiConfig `mapstructure:"gemini"`
|
||||||
|
Update UpdateConfig `mapstructure:"update"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateConfig 在线更新相关配置
|
// UpdateConfig 在线更新相关配置
|
||||||
@@ -322,6 +324,30 @@ type TurnstileConfig struct {
|
|||||||
Required bool `mapstructure:"required"`
|
Required bool `mapstructure:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LinuxDoConnectConfig 用于 LinuxDo Connect OAuth 登录(终端用户 SSO)。
|
||||||
|
//
|
||||||
|
// 注意:这与上游账号的 OAuth(例如 OpenAI/Gemini 账号接入)不是一回事。
|
||||||
|
// 这里是用于登录 Sub2API 本身的用户体系。
|
||||||
|
type LinuxDoConnectConfig struct {
|
||||||
|
Enabled bool `mapstructure:"enabled"`
|
||||||
|
ClientID string `mapstructure:"client_id"`
|
||||||
|
ClientSecret string `mapstructure:"client_secret"`
|
||||||
|
AuthorizeURL string `mapstructure:"authorize_url"`
|
||||||
|
TokenURL string `mapstructure:"token_url"`
|
||||||
|
UserInfoURL string `mapstructure:"userinfo_url"`
|
||||||
|
Scopes string `mapstructure:"scopes"`
|
||||||
|
RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记)
|
||||||
|
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/linuxdo/callback)
|
||||||
|
TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none
|
||||||
|
UsePKCE bool `mapstructure:"use_pkce"`
|
||||||
|
|
||||||
|
// 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。
|
||||||
|
// 为空时,服务端会尝试一组常见字段名。
|
||||||
|
UserInfoEmailPath string `mapstructure:"userinfo_email_path"`
|
||||||
|
UserInfoIDPath string `mapstructure:"userinfo_id_path"`
|
||||||
|
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
|
||||||
|
}
|
||||||
|
|
||||||
type DefaultConfig struct {
|
type DefaultConfig struct {
|
||||||
AdminEmail string `mapstructure:"admin_email"`
|
AdminEmail string `mapstructure:"admin_email"`
|
||||||
AdminPassword string `mapstructure:"admin_password"`
|
AdminPassword string `mapstructure:"admin_password"`
|
||||||
@@ -388,6 +414,18 @@ func Load() (*Config, error) {
|
|||||||
cfg.Server.Mode = "debug"
|
cfg.Server.Mode = "debug"
|
||||||
}
|
}
|
||||||
cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
|
cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
|
||||||
|
cfg.LinuxDo.ClientID = strings.TrimSpace(cfg.LinuxDo.ClientID)
|
||||||
|
cfg.LinuxDo.ClientSecret = strings.TrimSpace(cfg.LinuxDo.ClientSecret)
|
||||||
|
cfg.LinuxDo.AuthorizeURL = strings.TrimSpace(cfg.LinuxDo.AuthorizeURL)
|
||||||
|
cfg.LinuxDo.TokenURL = strings.TrimSpace(cfg.LinuxDo.TokenURL)
|
||||||
|
cfg.LinuxDo.UserInfoURL = strings.TrimSpace(cfg.LinuxDo.UserInfoURL)
|
||||||
|
cfg.LinuxDo.Scopes = strings.TrimSpace(cfg.LinuxDo.Scopes)
|
||||||
|
cfg.LinuxDo.RedirectURL = strings.TrimSpace(cfg.LinuxDo.RedirectURL)
|
||||||
|
cfg.LinuxDo.FrontendRedirectURL = strings.TrimSpace(cfg.LinuxDo.FrontendRedirectURL)
|
||||||
|
cfg.LinuxDo.TokenAuthMethod = strings.ToLower(strings.TrimSpace(cfg.LinuxDo.TokenAuthMethod))
|
||||||
|
cfg.LinuxDo.UserInfoEmailPath = strings.TrimSpace(cfg.LinuxDo.UserInfoEmailPath)
|
||||||
|
cfg.LinuxDo.UserInfoIDPath = strings.TrimSpace(cfg.LinuxDo.UserInfoIDPath)
|
||||||
|
cfg.LinuxDo.UserInfoUsernamePath = strings.TrimSpace(cfg.LinuxDo.UserInfoUsernamePath)
|
||||||
cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
|
cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
|
||||||
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
|
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
|
||||||
cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
|
cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
|
||||||
@@ -426,6 +464,81 @@ func Load() (*Config, error) {
|
|||||||
return &cfg, nil
|
return &cfg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ValidateAbsoluteHTTPURL 校验一个绝对 http(s) URL(禁止 fragment)。
|
||||||
|
func ValidateAbsoluteHTTPURL(raw string) error {
|
||||||
|
raw = strings.TrimSpace(raw)
|
||||||
|
if raw == "" {
|
||||||
|
return fmt.Errorf("empty url")
|
||||||
|
}
|
||||||
|
u, err := url.Parse(raw)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !u.IsAbs() {
|
||||||
|
return fmt.Errorf("must be absolute")
|
||||||
|
}
|
||||||
|
if !isHTTPScheme(u.Scheme) {
|
||||||
|
return fmt.Errorf("unsupported scheme: %s", u.Scheme)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(u.Host) == "" {
|
||||||
|
return fmt.Errorf("missing host")
|
||||||
|
}
|
||||||
|
if u.Fragment != "" {
|
||||||
|
return fmt.Errorf("must not include fragment")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateFrontendRedirectURL 校验前端回调地址:
|
||||||
|
// - 允许同源相对路径(以 / 开头)
|
||||||
|
// - 或绝对 http(s) URL(禁止 fragment)
|
||||||
|
func ValidateFrontendRedirectURL(raw string) error {
|
||||||
|
raw = strings.TrimSpace(raw)
|
||||||
|
if raw == "" {
|
||||||
|
return fmt.Errorf("empty url")
|
||||||
|
}
|
||||||
|
if strings.ContainsAny(raw, "\r\n") {
|
||||||
|
return fmt.Errorf("contains invalid characters")
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(raw, "/") {
|
||||||
|
if strings.HasPrefix(raw, "//") {
|
||||||
|
return fmt.Errorf("must not start with //")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
u, err := url.Parse(raw)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !u.IsAbs() {
|
||||||
|
return fmt.Errorf("must be absolute http(s) url or relative path")
|
||||||
|
}
|
||||||
|
if !isHTTPScheme(u.Scheme) {
|
||||||
|
return fmt.Errorf("unsupported scheme: %s", u.Scheme)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(u.Host) == "" {
|
||||||
|
return fmt.Errorf("missing host")
|
||||||
|
}
|
||||||
|
if u.Fragment != "" {
|
||||||
|
return fmt.Errorf("must not include fragment")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isHTTPScheme(scheme string) bool {
|
||||||
|
return strings.EqualFold(scheme, "http") || strings.EqualFold(scheme, "https")
|
||||||
|
}
|
||||||
|
|
||||||
|
func warnIfInsecureURL(field, raw string) {
|
||||||
|
u, err := url.Parse(strings.TrimSpace(raw))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.EqualFold(u.Scheme, "http") {
|
||||||
|
log.Printf("Warning: %s uses http scheme; use https in production to avoid token leakage.", field)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func setDefaults() {
|
func setDefaults() {
|
||||||
viper.SetDefault("run_mode", RunModeStandard)
|
viper.SetDefault("run_mode", RunModeStandard)
|
||||||
|
|
||||||
@@ -475,6 +588,22 @@ func setDefaults() {
|
|||||||
// Turnstile
|
// Turnstile
|
||||||
viper.SetDefault("turnstile.required", false)
|
viper.SetDefault("turnstile.required", false)
|
||||||
|
|
||||||
|
// LinuxDo Connect OAuth 登录(终端用户 SSO)
|
||||||
|
viper.SetDefault("linuxdo_connect.enabled", false)
|
||||||
|
viper.SetDefault("linuxdo_connect.client_id", "")
|
||||||
|
viper.SetDefault("linuxdo_connect.client_secret", "")
|
||||||
|
viper.SetDefault("linuxdo_connect.authorize_url", "https://connect.linux.do/oauth2/authorize")
|
||||||
|
viper.SetDefault("linuxdo_connect.token_url", "https://connect.linux.do/oauth2/token")
|
||||||
|
viper.SetDefault("linuxdo_connect.userinfo_url", "https://connect.linux.do/api/user")
|
||||||
|
viper.SetDefault("linuxdo_connect.scopes", "user")
|
||||||
|
viper.SetDefault("linuxdo_connect.redirect_url", "")
|
||||||
|
viper.SetDefault("linuxdo_connect.frontend_redirect_url", "/auth/linuxdo/callback")
|
||||||
|
viper.SetDefault("linuxdo_connect.token_auth_method", "client_secret_post")
|
||||||
|
viper.SetDefault("linuxdo_connect.use_pkce", false)
|
||||||
|
viper.SetDefault("linuxdo_connect.userinfo_email_path", "")
|
||||||
|
viper.SetDefault("linuxdo_connect.userinfo_id_path", "")
|
||||||
|
viper.SetDefault("linuxdo_connect.userinfo_username_path", "")
|
||||||
|
|
||||||
// Database
|
// Database
|
||||||
viper.SetDefault("database.host", "localhost")
|
viper.SetDefault("database.host", "localhost")
|
||||||
viper.SetDefault("database.port", 5432)
|
viper.SetDefault("database.port", 5432)
|
||||||
@@ -544,7 +673,7 @@ func setDefaults() {
|
|||||||
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
|
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
|
||||||
viper.SetDefault("gateway.stream_data_interval_timeout", 180)
|
viper.SetDefault("gateway.stream_data_interval_timeout", 180)
|
||||||
viper.SetDefault("gateway.stream_keepalive_interval", 10)
|
viper.SetDefault("gateway.stream_keepalive_interval", 10)
|
||||||
viper.SetDefault("gateway.max_line_size", 10*1024*1024)
|
viper.SetDefault("gateway.max_line_size", 40*1024*1024)
|
||||||
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
|
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
|
||||||
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 45*time.Second)
|
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 45*time.Second)
|
||||||
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
|
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
|
||||||
@@ -586,6 +715,60 @@ func (c *Config) Validate() error {
|
|||||||
if c.Security.CSP.Enabled && strings.TrimSpace(c.Security.CSP.Policy) == "" {
|
if c.Security.CSP.Enabled && strings.TrimSpace(c.Security.CSP.Policy) == "" {
|
||||||
return fmt.Errorf("security.csp.policy is required when CSP is enabled")
|
return fmt.Errorf("security.csp.policy is required when CSP is enabled")
|
||||||
}
|
}
|
||||||
|
if c.LinuxDo.Enabled {
|
||||||
|
if strings.TrimSpace(c.LinuxDo.ClientID) == "" {
|
||||||
|
return fmt.Errorf("linuxdo_connect.client_id is required when linuxdo_connect.enabled=true")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(c.LinuxDo.AuthorizeURL) == "" {
|
||||||
|
return fmt.Errorf("linuxdo_connect.authorize_url is required when linuxdo_connect.enabled=true")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(c.LinuxDo.TokenURL) == "" {
|
||||||
|
return fmt.Errorf("linuxdo_connect.token_url is required when linuxdo_connect.enabled=true")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(c.LinuxDo.UserInfoURL) == "" {
|
||||||
|
return fmt.Errorf("linuxdo_connect.userinfo_url is required when linuxdo_connect.enabled=true")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(c.LinuxDo.RedirectURL) == "" {
|
||||||
|
return fmt.Errorf("linuxdo_connect.redirect_url is required when linuxdo_connect.enabled=true")
|
||||||
|
}
|
||||||
|
method := strings.ToLower(strings.TrimSpace(c.LinuxDo.TokenAuthMethod))
|
||||||
|
switch method {
|
||||||
|
case "", "client_secret_post", "client_secret_basic", "none":
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("linuxdo_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none")
|
||||||
|
}
|
||||||
|
if method == "none" && !c.LinuxDo.UsePKCE {
|
||||||
|
return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.token_auth_method=none")
|
||||||
|
}
|
||||||
|
if (method == "" || method == "client_secret_post" || method == "client_secret_basic") && strings.TrimSpace(c.LinuxDo.ClientSecret) == "" {
|
||||||
|
return fmt.Errorf("linuxdo_connect.client_secret is required when linuxdo_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(c.LinuxDo.FrontendRedirectURL) == "" {
|
||||||
|
return fmt.Errorf("linuxdo_connect.frontend_redirect_url is required when linuxdo_connect.enabled=true")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ValidateAbsoluteHTTPURL(c.LinuxDo.AuthorizeURL); err != nil {
|
||||||
|
return fmt.Errorf("linuxdo_connect.authorize_url invalid: %w", err)
|
||||||
|
}
|
||||||
|
if err := ValidateAbsoluteHTTPURL(c.LinuxDo.TokenURL); err != nil {
|
||||||
|
return fmt.Errorf("linuxdo_connect.token_url invalid: %w", err)
|
||||||
|
}
|
||||||
|
if err := ValidateAbsoluteHTTPURL(c.LinuxDo.UserInfoURL); err != nil {
|
||||||
|
return fmt.Errorf("linuxdo_connect.userinfo_url invalid: %w", err)
|
||||||
|
}
|
||||||
|
if err := ValidateAbsoluteHTTPURL(c.LinuxDo.RedirectURL); err != nil {
|
||||||
|
return fmt.Errorf("linuxdo_connect.redirect_url invalid: %w", err)
|
||||||
|
}
|
||||||
|
if err := ValidateFrontendRedirectURL(c.LinuxDo.FrontendRedirectURL); err != nil {
|
||||||
|
return fmt.Errorf("linuxdo_connect.frontend_redirect_url invalid: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
warnIfInsecureURL("linuxdo_connect.authorize_url", c.LinuxDo.AuthorizeURL)
|
||||||
|
warnIfInsecureURL("linuxdo_connect.token_url", c.LinuxDo.TokenURL)
|
||||||
|
warnIfInsecureURL("linuxdo_connect.userinfo_url", c.LinuxDo.UserInfoURL)
|
||||||
|
warnIfInsecureURL("linuxdo_connect.redirect_url", c.LinuxDo.RedirectURL)
|
||||||
|
warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL)
|
||||||
|
}
|
||||||
if c.Billing.CircuitBreaker.Enabled {
|
if c.Billing.CircuitBreaker.Enabled {
|
||||||
if c.Billing.CircuitBreaker.FailureThreshold <= 0 {
|
if c.Billing.CircuitBreaker.FailureThreshold <= 0 {
|
||||||
return fmt.Errorf("billing.circuit_breaker.failure_threshold must be positive")
|
return fmt.Errorf("billing.circuit_breaker.failure_threshold must be positive")
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -90,3 +91,53 @@ func TestLoadDefaultSecurityToggles(t *testing.T) {
|
|||||||
t.Fatalf("ResponseHeaders.Enabled = true, want false")
|
t.Fatalf("ResponseHeaders.Enabled = true, want false")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) {
|
||||||
|
viper.Reset()
|
||||||
|
|
||||||
|
cfg, err := Load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.LinuxDo.Enabled = true
|
||||||
|
cfg.LinuxDo.ClientID = "test-client"
|
||||||
|
cfg.LinuxDo.ClientSecret = "test-secret"
|
||||||
|
cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback"
|
||||||
|
cfg.LinuxDo.TokenAuthMethod = "client_secret_post"
|
||||||
|
cfg.LinuxDo.UsePKCE = false
|
||||||
|
|
||||||
|
cfg.LinuxDo.FrontendRedirectURL = "javascript:alert(1)"
|
||||||
|
err = cfg.Validate()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("Validate() expected error for javascript scheme, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "linuxdo_connect.frontend_redirect_url") {
|
||||||
|
t.Fatalf("Validate() expected frontend_redirect_url error, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
|
||||||
|
viper.Reset()
|
||||||
|
|
||||||
|
cfg, err := Load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.LinuxDo.Enabled = true
|
||||||
|
cfg.LinuxDo.ClientID = "test-client"
|
||||||
|
cfg.LinuxDo.ClientSecret = ""
|
||||||
|
cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback"
|
||||||
|
cfg.LinuxDo.FrontendRedirectURL = "/auth/linuxdo/callback"
|
||||||
|
cfg.LinuxDo.TokenAuthMethod = "none"
|
||||||
|
cfg.LinuxDo.UsePKCE = false
|
||||||
|
|
||||||
|
err = cfg.Validate()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("Validate() expected error when token_auth_method=none and use_pkce=false, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "linuxdo_connect.use_pkce") {
|
||||||
|
t.Fatalf("Validate() expected use_pkce error, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -116,6 +116,7 @@ type BulkUpdateAccountsRequest struct {
|
|||||||
Concurrency *int `json:"concurrency"`
|
Concurrency *int `json:"concurrency"`
|
||||||
Priority *int `json:"priority"`
|
Priority *int `json:"priority"`
|
||||||
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
|
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
|
||||||
|
Schedulable *bool `json:"schedulable"`
|
||||||
GroupIDs *[]int64 `json:"group_ids"`
|
GroupIDs *[]int64 `json:"group_ids"`
|
||||||
Credentials map[string]any `json:"credentials"`
|
Credentials map[string]any `json:"credentials"`
|
||||||
Extra map[string]any `json:"extra"`
|
Extra map[string]any `json:"extra"`
|
||||||
@@ -136,6 +137,11 @@ func (h *AccountHandler) List(c *gin.Context) {
|
|||||||
accountType := c.Query("type")
|
accountType := c.Query("type")
|
||||||
status := c.Query("status")
|
status := c.Query("status")
|
||||||
search := c.Query("search")
|
search := c.Query("search")
|
||||||
|
// 标准化和验证 search 参数
|
||||||
|
search = strings.TrimSpace(search)
|
||||||
|
if len(search) > 100 {
|
||||||
|
search = search[:100]
|
||||||
|
}
|
||||||
|
|
||||||
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search)
|
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -655,6 +661,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
|||||||
req.Concurrency != nil ||
|
req.Concurrency != nil ||
|
||||||
req.Priority != nil ||
|
req.Priority != nil ||
|
||||||
req.Status != "" ||
|
req.Status != "" ||
|
||||||
|
req.Schedulable != nil ||
|
||||||
req.GroupIDs != nil ||
|
req.GroupIDs != nil ||
|
||||||
len(req.Credentials) > 0 ||
|
len(req.Credentials) > 0 ||
|
||||||
len(req.Extra) > 0
|
len(req.Extra) > 0
|
||||||
@@ -671,6 +678,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
|||||||
Concurrency: req.Concurrency,
|
Concurrency: req.Concurrency,
|
||||||
Priority: req.Priority,
|
Priority: req.Priority,
|
||||||
Status: req.Status,
|
Status: req.Status,
|
||||||
|
Schedulable: req.Schedulable,
|
||||||
GroupIDs: req.GroupIDs,
|
GroupIDs: req.GroupIDs,
|
||||||
Credentials: req.Credentials,
|
Credentials: req.Credentials,
|
||||||
Extra: req.Extra,
|
Extra: req.Extra,
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package admin
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
@@ -67,6 +68,12 @@ func (h *GroupHandler) List(c *gin.Context) {
|
|||||||
page, pageSize := response.ParsePagination(c)
|
page, pageSize := response.ParsePagination(c)
|
||||||
platform := c.Query("platform")
|
platform := c.Query("platform")
|
||||||
status := c.Query("status")
|
status := c.Query("status")
|
||||||
|
search := c.Query("search")
|
||||||
|
// 标准化和验证 search 参数
|
||||||
|
search = strings.TrimSpace(search)
|
||||||
|
if len(search) > 100 {
|
||||||
|
search = search[:100]
|
||||||
|
}
|
||||||
isExclusiveStr := c.Query("is_exclusive")
|
isExclusiveStr := c.Query("is_exclusive")
|
||||||
|
|
||||||
var isExclusive *bool
|
var isExclusive *bool
|
||||||
@@ -75,7 +82,7 @@ func (h *GroupHandler) List(c *gin.Context) {
|
|||||||
isExclusive = &val
|
isExclusive = &val
|
||||||
}
|
}
|
||||||
|
|
||||||
groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, isExclusive)
|
groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, search, isExclusive)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -51,6 +51,11 @@ func (h *ProxyHandler) List(c *gin.Context) {
|
|||||||
protocol := c.Query("protocol")
|
protocol := c.Query("protocol")
|
||||||
status := c.Query("status")
|
status := c.Query("status")
|
||||||
search := c.Query("search")
|
search := c.Query("search")
|
||||||
|
// 标准化和验证 search 参数
|
||||||
|
search = strings.TrimSpace(search)
|
||||||
|
if len(search) > 100 {
|
||||||
|
search = search[:100]
|
||||||
|
}
|
||||||
|
|
||||||
proxies, total, err := h.adminService.ListProxiesWithAccountCount(c.Request.Context(), page, pageSize, protocol, status, search)
|
proxies, total, err := h.adminService.ListProxiesWithAccountCount(c.Request.Context(), page, pageSize, protocol, status, search)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/csv"
|
"encoding/csv"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
@@ -41,6 +42,11 @@ func (h *RedeemHandler) List(c *gin.Context) {
|
|||||||
codeType := c.Query("type")
|
codeType := c.Query("type")
|
||||||
status := c.Query("status")
|
status := c.Query("status")
|
||||||
search := c.Query("search")
|
search := c.Query("search")
|
||||||
|
// 标准化和验证 search 参数
|
||||||
|
search = strings.TrimSpace(search)
|
||||||
|
if len(search) > 100 {
|
||||||
|
search = search[:100]
|
||||||
|
}
|
||||||
|
|
||||||
codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search)
|
codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -2,8 +2,10 @@ package admin
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"log"
|
"log"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
@@ -38,33 +40,37 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
response.Success(c, dto.SystemSettings{
|
response.Success(c, dto.SystemSettings{
|
||||||
RegistrationEnabled: settings.RegistrationEnabled,
|
RegistrationEnabled: settings.RegistrationEnabled,
|
||||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||||
SMTPHost: settings.SMTPHost,
|
SMTPHost: settings.SMTPHost,
|
||||||
SMTPPort: settings.SMTPPort,
|
SMTPPort: settings.SMTPPort,
|
||||||
SMTPUsername: settings.SMTPUsername,
|
SMTPUsername: settings.SMTPUsername,
|
||||||
SMTPPasswordConfigured: settings.SMTPPasswordConfigured,
|
SMTPPasswordConfigured: settings.SMTPPasswordConfigured,
|
||||||
SMTPFrom: settings.SMTPFrom,
|
SMTPFrom: settings.SMTPFrom,
|
||||||
SMTPFromName: settings.SMTPFromName,
|
SMTPFromName: settings.SMTPFromName,
|
||||||
SMTPUseTLS: settings.SMTPUseTLS,
|
SMTPUseTLS: settings.SMTPUseTLS,
|
||||||
TurnstileEnabled: settings.TurnstileEnabled,
|
TurnstileEnabled: settings.TurnstileEnabled,
|
||||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||||
TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured,
|
TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured,
|
||||||
SiteName: settings.SiteName,
|
LinuxDoConnectEnabled: settings.LinuxDoConnectEnabled,
|
||||||
SiteLogo: settings.SiteLogo,
|
LinuxDoConnectClientID: settings.LinuxDoConnectClientID,
|
||||||
SiteSubtitle: settings.SiteSubtitle,
|
LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured,
|
||||||
APIBaseURL: settings.APIBaseURL,
|
LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL,
|
||||||
ContactInfo: settings.ContactInfo,
|
SiteName: settings.SiteName,
|
||||||
DocURL: settings.DocURL,
|
SiteLogo: settings.SiteLogo,
|
||||||
DefaultConcurrency: settings.DefaultConcurrency,
|
SiteSubtitle: settings.SiteSubtitle,
|
||||||
DefaultBalance: settings.DefaultBalance,
|
APIBaseURL: settings.APIBaseURL,
|
||||||
EnableModelFallback: settings.EnableModelFallback,
|
ContactInfo: settings.ContactInfo,
|
||||||
FallbackModelAnthropic: settings.FallbackModelAnthropic,
|
DocURL: settings.DocURL,
|
||||||
FallbackModelOpenAI: settings.FallbackModelOpenAI,
|
DefaultConcurrency: settings.DefaultConcurrency,
|
||||||
FallbackModelGemini: settings.FallbackModelGemini,
|
DefaultBalance: settings.DefaultBalance,
|
||||||
FallbackModelAntigravity: settings.FallbackModelAntigravity,
|
EnableModelFallback: settings.EnableModelFallback,
|
||||||
EnableIdentityPatch: settings.EnableIdentityPatch,
|
FallbackModelAnthropic: settings.FallbackModelAnthropic,
|
||||||
IdentityPatchPrompt: settings.IdentityPatchPrompt,
|
FallbackModelOpenAI: settings.FallbackModelOpenAI,
|
||||||
|
FallbackModelGemini: settings.FallbackModelGemini,
|
||||||
|
FallbackModelAntigravity: settings.FallbackModelAntigravity,
|
||||||
|
EnableIdentityPatch: settings.EnableIdentityPatch,
|
||||||
|
IdentityPatchPrompt: settings.IdentityPatchPrompt,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -88,6 +94,12 @@ type UpdateSettingsRequest struct {
|
|||||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||||
TurnstileSecretKey string `json:"turnstile_secret_key"`
|
TurnstileSecretKey string `json:"turnstile_secret_key"`
|
||||||
|
|
||||||
|
// LinuxDo Connect OAuth 登录(终端用户 SSO)
|
||||||
|
LinuxDoConnectEnabled bool `json:"linuxdo_connect_enabled"`
|
||||||
|
LinuxDoConnectClientID string `json:"linuxdo_connect_client_id"`
|
||||||
|
LinuxDoConnectClientSecret string `json:"linuxdo_connect_client_secret"`
|
||||||
|
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
|
||||||
|
|
||||||
// OEM设置
|
// OEM设置
|
||||||
SiteName string `json:"site_name"`
|
SiteName string `json:"site_name"`
|
||||||
SiteLogo string `json:"site_logo"`
|
SiteLogo string `json:"site_logo"`
|
||||||
@@ -165,34 +177,67 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LinuxDo Connect 参数验证
|
||||||
|
if req.LinuxDoConnectEnabled {
|
||||||
|
req.LinuxDoConnectClientID = strings.TrimSpace(req.LinuxDoConnectClientID)
|
||||||
|
req.LinuxDoConnectClientSecret = strings.TrimSpace(req.LinuxDoConnectClientSecret)
|
||||||
|
req.LinuxDoConnectRedirectURL = strings.TrimSpace(req.LinuxDoConnectRedirectURL)
|
||||||
|
|
||||||
|
if req.LinuxDoConnectClientID == "" {
|
||||||
|
response.BadRequest(c, "LinuxDo Client ID is required when enabled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.LinuxDoConnectRedirectURL == "" {
|
||||||
|
response.BadRequest(c, "LinuxDo Redirect URL is required when enabled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := config.ValidateAbsoluteHTTPURL(req.LinuxDoConnectRedirectURL); err != nil {
|
||||||
|
response.BadRequest(c, "LinuxDo Redirect URL must be an absolute http(s) URL")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果未提供 client_secret,则保留现有值(如有)。
|
||||||
|
if req.LinuxDoConnectClientSecret == "" {
|
||||||
|
if previousSettings.LinuxDoConnectClientSecret == "" {
|
||||||
|
response.BadRequest(c, "LinuxDo Client Secret is required when enabled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
req.LinuxDoConnectClientSecret = previousSettings.LinuxDoConnectClientSecret
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
settings := &service.SystemSettings{
|
settings := &service.SystemSettings{
|
||||||
RegistrationEnabled: req.RegistrationEnabled,
|
RegistrationEnabled: req.RegistrationEnabled,
|
||||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||||
SMTPHost: req.SMTPHost,
|
SMTPHost: req.SMTPHost,
|
||||||
SMTPPort: req.SMTPPort,
|
SMTPPort: req.SMTPPort,
|
||||||
SMTPUsername: req.SMTPUsername,
|
SMTPUsername: req.SMTPUsername,
|
||||||
SMTPPassword: req.SMTPPassword,
|
SMTPPassword: req.SMTPPassword,
|
||||||
SMTPFrom: req.SMTPFrom,
|
SMTPFrom: req.SMTPFrom,
|
||||||
SMTPFromName: req.SMTPFromName,
|
SMTPFromName: req.SMTPFromName,
|
||||||
SMTPUseTLS: req.SMTPUseTLS,
|
SMTPUseTLS: req.SMTPUseTLS,
|
||||||
TurnstileEnabled: req.TurnstileEnabled,
|
TurnstileEnabled: req.TurnstileEnabled,
|
||||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||||
SiteName: req.SiteName,
|
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
|
||||||
SiteLogo: req.SiteLogo,
|
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
|
||||||
SiteSubtitle: req.SiteSubtitle,
|
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
|
||||||
APIBaseURL: req.APIBaseURL,
|
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
|
||||||
ContactInfo: req.ContactInfo,
|
SiteName: req.SiteName,
|
||||||
DocURL: req.DocURL,
|
SiteLogo: req.SiteLogo,
|
||||||
DefaultConcurrency: req.DefaultConcurrency,
|
SiteSubtitle: req.SiteSubtitle,
|
||||||
DefaultBalance: req.DefaultBalance,
|
APIBaseURL: req.APIBaseURL,
|
||||||
EnableModelFallback: req.EnableModelFallback,
|
ContactInfo: req.ContactInfo,
|
||||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
DocURL: req.DocURL,
|
||||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
DefaultConcurrency: req.DefaultConcurrency,
|
||||||
FallbackModelGemini: req.FallbackModelGemini,
|
DefaultBalance: req.DefaultBalance,
|
||||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
EnableModelFallback: req.EnableModelFallback,
|
||||||
EnableIdentityPatch: req.EnableIdentityPatch,
|
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||||
|
FallbackModelGemini: req.FallbackModelGemini,
|
||||||
|
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||||
|
EnableIdentityPatch: req.EnableIdentityPatch,
|
||||||
|
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
|
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
|
||||||
@@ -210,33 +255,37 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
response.Success(c, dto.SystemSettings{
|
response.Success(c, dto.SystemSettings{
|
||||||
RegistrationEnabled: updatedSettings.RegistrationEnabled,
|
RegistrationEnabled: updatedSettings.RegistrationEnabled,
|
||||||
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
||||||
SMTPHost: updatedSettings.SMTPHost,
|
SMTPHost: updatedSettings.SMTPHost,
|
||||||
SMTPPort: updatedSettings.SMTPPort,
|
SMTPPort: updatedSettings.SMTPPort,
|
||||||
SMTPUsername: updatedSettings.SMTPUsername,
|
SMTPUsername: updatedSettings.SMTPUsername,
|
||||||
SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured,
|
SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured,
|
||||||
SMTPFrom: updatedSettings.SMTPFrom,
|
SMTPFrom: updatedSettings.SMTPFrom,
|
||||||
SMTPFromName: updatedSettings.SMTPFromName,
|
SMTPFromName: updatedSettings.SMTPFromName,
|
||||||
SMTPUseTLS: updatedSettings.SMTPUseTLS,
|
SMTPUseTLS: updatedSettings.SMTPUseTLS,
|
||||||
TurnstileEnabled: updatedSettings.TurnstileEnabled,
|
TurnstileEnabled: updatedSettings.TurnstileEnabled,
|
||||||
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
|
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
|
||||||
TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured,
|
TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured,
|
||||||
SiteName: updatedSettings.SiteName,
|
LinuxDoConnectEnabled: updatedSettings.LinuxDoConnectEnabled,
|
||||||
SiteLogo: updatedSettings.SiteLogo,
|
LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID,
|
||||||
SiteSubtitle: updatedSettings.SiteSubtitle,
|
LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured,
|
||||||
APIBaseURL: updatedSettings.APIBaseURL,
|
LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL,
|
||||||
ContactInfo: updatedSettings.ContactInfo,
|
SiteName: updatedSettings.SiteName,
|
||||||
DocURL: updatedSettings.DocURL,
|
SiteLogo: updatedSettings.SiteLogo,
|
||||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
SiteSubtitle: updatedSettings.SiteSubtitle,
|
||||||
DefaultBalance: updatedSettings.DefaultBalance,
|
APIBaseURL: updatedSettings.APIBaseURL,
|
||||||
EnableModelFallback: updatedSettings.EnableModelFallback,
|
ContactInfo: updatedSettings.ContactInfo,
|
||||||
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
|
DocURL: updatedSettings.DocURL,
|
||||||
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
|
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||||
FallbackModelGemini: updatedSettings.FallbackModelGemini,
|
DefaultBalance: updatedSettings.DefaultBalance,
|
||||||
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
|
EnableModelFallback: updatedSettings.EnableModelFallback,
|
||||||
EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
|
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
|
||||||
IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
|
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
|
||||||
|
FallbackModelGemini: updatedSettings.FallbackModelGemini,
|
||||||
|
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
|
||||||
|
EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
|
||||||
|
IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -298,6 +347,18 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
|||||||
if req.TurnstileSecretKey != "" {
|
if req.TurnstileSecretKey != "" {
|
||||||
changed = append(changed, "turnstile_secret_key")
|
changed = append(changed, "turnstile_secret_key")
|
||||||
}
|
}
|
||||||
|
if before.LinuxDoConnectEnabled != after.LinuxDoConnectEnabled {
|
||||||
|
changed = append(changed, "linuxdo_connect_enabled")
|
||||||
|
}
|
||||||
|
if before.LinuxDoConnectClientID != after.LinuxDoConnectClientID {
|
||||||
|
changed = append(changed, "linuxdo_connect_client_id")
|
||||||
|
}
|
||||||
|
if req.LinuxDoConnectClientSecret != "" {
|
||||||
|
changed = append(changed, "linuxdo_connect_client_secret")
|
||||||
|
}
|
||||||
|
if before.LinuxDoConnectRedirectURL != after.LinuxDoConnectRedirectURL {
|
||||||
|
changed = append(changed, "linuxdo_connect_redirect_url")
|
||||||
|
}
|
||||||
if before.SiteName != after.SiteName {
|
if before.SiteName != after.SiteName {
|
||||||
changed = append(changed, "site_name")
|
changed = append(changed, "site_name")
|
||||||
}
|
}
|
||||||
@@ -337,6 +398,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
|||||||
if before.FallbackModelAntigravity != after.FallbackModelAntigravity {
|
if before.FallbackModelAntigravity != after.FallbackModelAntigravity {
|
||||||
changed = append(changed, "fallback_model_antigravity")
|
changed = append(changed, "fallback_model_antigravity")
|
||||||
}
|
}
|
||||||
|
if before.EnableIdentityPatch != after.EnableIdentityPatch {
|
||||||
|
changed = append(changed, "enable_identity_patch")
|
||||||
|
}
|
||||||
|
if before.IdentityPatchPrompt != after.IdentityPatchPrompt {
|
||||||
|
changed = append(changed, "identity_patch_prompt")
|
||||||
|
}
|
||||||
return changed
|
return changed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package admin
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
@@ -63,10 +64,17 @@ type UpdateBalanceRequest struct {
|
|||||||
func (h *UserHandler) List(c *gin.Context) {
|
func (h *UserHandler) List(c *gin.Context) {
|
||||||
page, pageSize := response.ParsePagination(c)
|
page, pageSize := response.ParsePagination(c)
|
||||||
|
|
||||||
|
search := c.Query("search")
|
||||||
|
// 标准化和验证 search 参数
|
||||||
|
search = strings.TrimSpace(search)
|
||||||
|
if len(search) > 100 {
|
||||||
|
search = search[:100]
|
||||||
|
}
|
||||||
|
|
||||||
filters := service.UserListFilters{
|
filters := service.UserListFilters{
|
||||||
Status: c.Query("status"),
|
Status: c.Query("status"),
|
||||||
Role: c.Query("role"),
|
Role: c.Query("role"),
|
||||||
Search: c.Query("search"),
|
Search: search,
|
||||||
Attributes: parseAttributeFilters(c),
|
Attributes: parseAttributeFilters(c),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -27,16 +27,20 @@ func NewAPIKeyHandler(apiKeyService *service.APIKeyService) *APIKeyHandler {
|
|||||||
|
|
||||||
// CreateAPIKeyRequest represents the create API key request payload
|
// CreateAPIKeyRequest represents the create API key request payload
|
||||||
type CreateAPIKeyRequest struct {
|
type CreateAPIKeyRequest struct {
|
||||||
Name string `json:"name" binding:"required"`
|
Name string `json:"name" binding:"required"`
|
||||||
GroupID *int64 `json:"group_id"` // nullable
|
GroupID *int64 `json:"group_id"` // nullable
|
||||||
CustomKey *string `json:"custom_key"` // 可选的自定义key
|
CustomKey *string `json:"custom_key"` // 可选的自定义key
|
||||||
|
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
|
||||||
|
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateAPIKeyRequest represents the update API key request payload
|
// UpdateAPIKeyRequest represents the update API key request payload
|
||||||
type UpdateAPIKeyRequest struct {
|
type UpdateAPIKeyRequest struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
GroupID *int64 `json:"group_id"`
|
GroupID *int64 `json:"group_id"`
|
||||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||||
|
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
|
||||||
|
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
|
||||||
}
|
}
|
||||||
|
|
||||||
// List handles listing user's API keys with pagination
|
// List handles listing user's API keys with pagination
|
||||||
@@ -110,9 +114,11 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
svcReq := service.CreateAPIKeyRequest{
|
svcReq := service.CreateAPIKeyRequest{
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
GroupID: req.GroupID,
|
GroupID: req.GroupID,
|
||||||
CustomKey: req.CustomKey,
|
CustomKey: req.CustomKey,
|
||||||
|
IPWhitelist: req.IPWhitelist,
|
||||||
|
IPBlacklist: req.IPBlacklist,
|
||||||
}
|
}
|
||||||
key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq)
|
key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -144,7 +150,10 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
svcReq := service.UpdateAPIKeyRequest{}
|
svcReq := service.UpdateAPIKeyRequest{
|
||||||
|
IPWhitelist: req.IPWhitelist,
|
||||||
|
IPBlacklist: req.IPBlacklist,
|
||||||
|
}
|
||||||
if req.Name != "" {
|
if req.Name != "" {
|
||||||
svcReq.Name = &req.Name
|
svcReq.Name = &req.Name
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,14 +15,16 @@ type AuthHandler struct {
|
|||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
authService *service.AuthService
|
authService *service.AuthService
|
||||||
userService *service.UserService
|
userService *service.UserService
|
||||||
|
settingSvc *service.SettingService
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAuthHandler creates a new AuthHandler
|
// NewAuthHandler creates a new AuthHandler
|
||||||
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService) *AuthHandler {
|
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService) *AuthHandler {
|
||||||
return &AuthHandler{
|
return &AuthHandler{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
authService: authService,
|
authService: authService,
|
||||||
userService: userService,
|
userService: userService,
|
||||||
|
settingSvc: settingService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
679
backend/internal/handler/auth_linuxdo_oauth.go
Normal file
679
backend/internal/handler/auth_linuxdo_oauth.go
Normal file
@@ -0,0 +1,679 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/imroc/req/v3"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
linuxDoOAuthCookiePath = "/api/v1/auth/oauth/linuxdo"
|
||||||
|
linuxDoOAuthStateCookieName = "linuxdo_oauth_state"
|
||||||
|
linuxDoOAuthVerifierCookie = "linuxdo_oauth_verifier"
|
||||||
|
linuxDoOAuthRedirectCookie = "linuxdo_oauth_redirect"
|
||||||
|
linuxDoOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes
|
||||||
|
linuxDoOAuthDefaultRedirectTo = "/dashboard"
|
||||||
|
linuxDoOAuthDefaultFrontendCB = "/auth/linuxdo/callback"
|
||||||
|
|
||||||
|
linuxDoOAuthMaxRedirectLen = 2048
|
||||||
|
linuxDoOAuthMaxFragmentValueLen = 512
|
||||||
|
linuxDoOAuthMaxSubjectLen = 64 - len("linuxdo-")
|
||||||
|
)
|
||||||
|
|
||||||
|
type linuxDoTokenResponse struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
|
RefreshToken string `json:"refresh_token,omitempty"`
|
||||||
|
Scope string `json:"scope,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type linuxDoTokenExchangeError struct {
|
||||||
|
StatusCode int
|
||||||
|
ProviderError string
|
||||||
|
ProviderDescription string
|
||||||
|
Body string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *linuxDoTokenExchangeError) Error() string {
|
||||||
|
if e == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
parts := []string{fmt.Sprintf("token exchange status=%d", e.StatusCode)}
|
||||||
|
if strings.TrimSpace(e.ProviderError) != "" {
|
||||||
|
parts = append(parts, "error="+strings.TrimSpace(e.ProviderError))
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(e.ProviderDescription) != "" {
|
||||||
|
parts = append(parts, "error_description="+strings.TrimSpace(e.ProviderDescription))
|
||||||
|
}
|
||||||
|
return strings.Join(parts, " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// LinuxDoOAuthStart 启动 LinuxDo Connect OAuth 登录流程。
|
||||||
|
// GET /api/v1/auth/oauth/linuxdo/start?redirect=/dashboard
|
||||||
|
func (h *AuthHandler) LinuxDoOAuthStart(c *gin.Context) {
|
||||||
|
cfg, err := h.getLinuxDoOAuthConfig(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
state, err := oauth.GenerateState()
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect"))
|
||||||
|
if redirectTo == "" {
|
||||||
|
redirectTo = linuxDoOAuthDefaultRedirectTo
|
||||||
|
}
|
||||||
|
|
||||||
|
secureCookie := isRequestHTTPS(c)
|
||||||
|
setCookie(c, linuxDoOAuthStateCookieName, encodeCookieValue(state), linuxDoOAuthCookieMaxAgeSec, secureCookie)
|
||||||
|
setCookie(c, linuxDoOAuthRedirectCookie, encodeCookieValue(redirectTo), linuxDoOAuthCookieMaxAgeSec, secureCookie)
|
||||||
|
|
||||||
|
codeChallenge := ""
|
||||||
|
if cfg.UsePKCE {
|
||||||
|
verifier, err := oauth.GenerateCodeVerifier()
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
codeChallenge = oauth.GenerateCodeChallenge(verifier)
|
||||||
|
setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie)
|
||||||
|
}
|
||||||
|
|
||||||
|
redirectURI := strings.TrimSpace(cfg.RedirectURL)
|
||||||
|
if redirectURI == "" {
|
||||||
|
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url not configured"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
authURL, err := buildLinuxDoAuthorizeURL(cfg, state, codeChallenge, redirectURI)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Redirect(http.StatusFound, authURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LinuxDoOAuthCallback 处理 OAuth 回调:创建/登录用户,然后重定向到前端。
|
||||||
|
// GET /api/v1/auth/oauth/linuxdo/callback?code=...&state=...
|
||||||
|
func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
|
||||||
|
cfg, cfgErr := h.getLinuxDoOAuthConfig(c.Request.Context())
|
||||||
|
if cfgErr != nil {
|
||||||
|
response.ErrorFrom(c, cfgErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
frontendCallback := strings.TrimSpace(cfg.FrontendRedirectURL)
|
||||||
|
if frontendCallback == "" {
|
||||||
|
frontendCallback = linuxDoOAuthDefaultFrontendCB
|
||||||
|
}
|
||||||
|
|
||||||
|
if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" {
|
||||||
|
redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
code := strings.TrimSpace(c.Query("code"))
|
||||||
|
state := strings.TrimSpace(c.Query("state"))
|
||||||
|
if code == "" || state == "" {
|
||||||
|
redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
secureCookie := isRequestHTTPS(c)
|
||||||
|
defer func() {
|
||||||
|
clearCookie(c, linuxDoOAuthStateCookieName, secureCookie)
|
||||||
|
clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie)
|
||||||
|
clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie)
|
||||||
|
}()
|
||||||
|
|
||||||
|
expectedState, err := readCookieDecoded(c, linuxDoOAuthStateCookieName)
|
||||||
|
if err != nil || expectedState == "" || state != expectedState {
|
||||||
|
redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
redirectTo, _ := readCookieDecoded(c, linuxDoOAuthRedirectCookie)
|
||||||
|
redirectTo = sanitizeFrontendRedirectPath(redirectTo)
|
||||||
|
if redirectTo == "" {
|
||||||
|
redirectTo = linuxDoOAuthDefaultRedirectTo
|
||||||
|
}
|
||||||
|
|
||||||
|
codeVerifier := ""
|
||||||
|
if cfg.UsePKCE {
|
||||||
|
codeVerifier, _ = readCookieDecoded(c, linuxDoOAuthVerifierCookie)
|
||||||
|
if codeVerifier == "" {
|
||||||
|
redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
redirectURI := strings.TrimSpace(cfg.RedirectURL)
|
||||||
|
if redirectURI == "" {
|
||||||
|
redirectOAuthError(c, frontendCallback, "config_error", "oauth redirect url not configured", "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenResp, err := linuxDoExchangeCode(c.Request.Context(), cfg, code, redirectURI, codeVerifier)
|
||||||
|
if err != nil {
|
||||||
|
description := ""
|
||||||
|
var exchangeErr *linuxDoTokenExchangeError
|
||||||
|
if errors.As(err, &exchangeErr) && exchangeErr != nil {
|
||||||
|
log.Printf(
|
||||||
|
"[LinuxDo OAuth] token exchange failed: status=%d provider_error=%q provider_description=%q body=%s",
|
||||||
|
exchangeErr.StatusCode,
|
||||||
|
exchangeErr.ProviderError,
|
||||||
|
exchangeErr.ProviderDescription,
|
||||||
|
truncateLogValue(exchangeErr.Body, 2048),
|
||||||
|
)
|
||||||
|
description = exchangeErr.Error()
|
||||||
|
} else {
|
||||||
|
log.Printf("[LinuxDo OAuth] token exchange failed: %v", err)
|
||||||
|
description = err.Error()
|
||||||
|
}
|
||||||
|
redirectOAuthError(c, frontendCallback, "token_exchange_failed", "failed to exchange oauth code", singleLine(description))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
email, username, subject, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[LinuxDo OAuth] userinfo fetch failed: %v", err)
|
||||||
|
redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch user info", "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 安全考虑:不要把第三方返回的 email 直接映射到本地账号(可能与本地邮箱用户冲突导致账号被接管)。
|
||||||
|
// 统一使用基于 subject 的稳定合成邮箱来做账号绑定。
|
||||||
|
if subject != "" {
|
||||||
|
email = linuxDoSyntheticEmail(subject)
|
||||||
|
}
|
||||||
|
|
||||||
|
jwtToken, _, err := h.authService.LoginOrRegisterOAuth(c.Request.Context(), email, username)
|
||||||
|
if err != nil {
|
||||||
|
// 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。
|
||||||
|
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fragment := url.Values{}
|
||||||
|
fragment.Set("access_token", jwtToken)
|
||||||
|
fragment.Set("token_type", "Bearer")
|
||||||
|
fragment.Set("redirect", redirectTo)
|
||||||
|
redirectWithFragment(c, frontendCallback, fragment)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AuthHandler) getLinuxDoOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) {
|
||||||
|
if h != nil && h.settingSvc != nil {
|
||||||
|
return h.settingSvc.GetLinuxDoConnectOAuthConfig(ctx)
|
||||||
|
}
|
||||||
|
if h == nil || h.cfg == nil {
|
||||||
|
return config.LinuxDoConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded")
|
||||||
|
}
|
||||||
|
if !h.cfg.LinuxDo.Enabled {
|
||||||
|
return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
|
||||||
|
}
|
||||||
|
return h.cfg.LinuxDo, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func linuxDoExchangeCode(
|
||||||
|
ctx context.Context,
|
||||||
|
cfg config.LinuxDoConnectConfig,
|
||||||
|
code string,
|
||||||
|
redirectURI string,
|
||||||
|
codeVerifier string,
|
||||||
|
) (*linuxDoTokenResponse, error) {
|
||||||
|
client := req.C().SetTimeout(30 * time.Second)
|
||||||
|
|
||||||
|
form := url.Values{}
|
||||||
|
form.Set("grant_type", "authorization_code")
|
||||||
|
form.Set("client_id", cfg.ClientID)
|
||||||
|
form.Set("code", code)
|
||||||
|
form.Set("redirect_uri", redirectURI)
|
||||||
|
if cfg.UsePKCE {
|
||||||
|
form.Set("code_verifier", codeVerifier)
|
||||||
|
}
|
||||||
|
|
||||||
|
r := client.R().
|
||||||
|
SetContext(ctx).
|
||||||
|
SetHeader("Accept", "application/json")
|
||||||
|
|
||||||
|
switch strings.ToLower(strings.TrimSpace(cfg.TokenAuthMethod)) {
|
||||||
|
case "", "client_secret_post":
|
||||||
|
form.Set("client_secret", cfg.ClientSecret)
|
||||||
|
case "client_secret_basic":
|
||||||
|
r.SetBasicAuth(cfg.ClientID, cfg.ClientSecret)
|
||||||
|
case "none":
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported token_auth_method: %s", cfg.TokenAuthMethod)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := r.SetFormDataFromValues(form).Post(cfg.TokenURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("request token: %w", err)
|
||||||
|
}
|
||||||
|
body := strings.TrimSpace(resp.String())
|
||||||
|
if !resp.IsSuccessState() {
|
||||||
|
providerErr, providerDesc := parseOAuthProviderError(body)
|
||||||
|
return nil, &linuxDoTokenExchangeError{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
ProviderError: providerErr,
|
||||||
|
ProviderDescription: providerDesc,
|
||||||
|
Body: body,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenResp, ok := parseLinuxDoTokenResponse(body)
|
||||||
|
if !ok || strings.TrimSpace(tokenResp.AccessToken) == "" {
|
||||||
|
return nil, &linuxDoTokenExchangeError{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
Body: body,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(tokenResp.TokenType) == "" {
|
||||||
|
tokenResp.TokenType = "Bearer"
|
||||||
|
}
|
||||||
|
return tokenResp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func linuxDoFetchUserInfo(
|
||||||
|
ctx context.Context,
|
||||||
|
cfg config.LinuxDoConnectConfig,
|
||||||
|
token *linuxDoTokenResponse,
|
||||||
|
) (email string, username string, subject string, err error) {
|
||||||
|
client := req.C().SetTimeout(30 * time.Second)
|
||||||
|
authorization, err := buildBearerAuthorization(token.TokenType, token.AccessToken)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.R().
|
||||||
|
SetContext(ctx).
|
||||||
|
SetHeader("Accept", "application/json").
|
||||||
|
SetHeader("Authorization", authorization).
|
||||||
|
Get(cfg.UserInfoURL)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", "", fmt.Errorf("request userinfo: %w", err)
|
||||||
|
}
|
||||||
|
if !resp.IsSuccessState() {
|
||||||
|
return "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return linuxDoParseUserInfo(resp.String(), cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, err error) {
|
||||||
|
email = firstNonEmpty(
|
||||||
|
getGJSON(body, cfg.UserInfoEmailPath),
|
||||||
|
getGJSON(body, "email"),
|
||||||
|
getGJSON(body, "user.email"),
|
||||||
|
getGJSON(body, "data.email"),
|
||||||
|
getGJSON(body, "attributes.email"),
|
||||||
|
)
|
||||||
|
username = firstNonEmpty(
|
||||||
|
getGJSON(body, cfg.UserInfoUsernamePath),
|
||||||
|
getGJSON(body, "username"),
|
||||||
|
getGJSON(body, "preferred_username"),
|
||||||
|
getGJSON(body, "name"),
|
||||||
|
getGJSON(body, "user.username"),
|
||||||
|
getGJSON(body, "user.name"),
|
||||||
|
)
|
||||||
|
subject = firstNonEmpty(
|
||||||
|
getGJSON(body, cfg.UserInfoIDPath),
|
||||||
|
getGJSON(body, "sub"),
|
||||||
|
getGJSON(body, "id"),
|
||||||
|
getGJSON(body, "user_id"),
|
||||||
|
getGJSON(body, "uid"),
|
||||||
|
getGJSON(body, "user.id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
subject = strings.TrimSpace(subject)
|
||||||
|
if subject == "" {
|
||||||
|
return "", "", "", errors.New("userinfo missing id field")
|
||||||
|
}
|
||||||
|
if !isSafeLinuxDoSubject(subject) {
|
||||||
|
return "", "", "", errors.New("userinfo returned invalid id field")
|
||||||
|
}
|
||||||
|
|
||||||
|
email = strings.TrimSpace(email)
|
||||||
|
if email == "" {
|
||||||
|
// LinuxDo Connect 的 userinfo 可能不提供 email。为兼容现有用户模型(email 必填且唯一),使用稳定的合成邮箱。
|
||||||
|
email = linuxDoSyntheticEmail(subject)
|
||||||
|
}
|
||||||
|
|
||||||
|
username = strings.TrimSpace(username)
|
||||||
|
if username == "" {
|
||||||
|
username = "linuxdo_" + subject
|
||||||
|
}
|
||||||
|
|
||||||
|
return email, username, subject, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, codeChallenge string, redirectURI string) (string, error) {
|
||||||
|
u, err := url.Parse(cfg.AuthorizeURL)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("parse authorize_url: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
q := u.Query()
|
||||||
|
q.Set("response_type", "code")
|
||||||
|
q.Set("client_id", cfg.ClientID)
|
||||||
|
q.Set("redirect_uri", redirectURI)
|
||||||
|
if strings.TrimSpace(cfg.Scopes) != "" {
|
||||||
|
q.Set("scope", cfg.Scopes)
|
||||||
|
}
|
||||||
|
q.Set("state", state)
|
||||||
|
if cfg.UsePKCE {
|
||||||
|
q.Set("code_challenge", codeChallenge)
|
||||||
|
q.Set("code_challenge_method", "S256")
|
||||||
|
}
|
||||||
|
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
return u.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func redirectOAuthError(c *gin.Context, frontendCallback string, code string, message string, description string) {
|
||||||
|
fragment := url.Values{}
|
||||||
|
fragment.Set("error", truncateFragmentValue(code))
|
||||||
|
if strings.TrimSpace(message) != "" {
|
||||||
|
fragment.Set("error_message", truncateFragmentValue(message))
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(description) != "" {
|
||||||
|
fragment.Set("error_description", truncateFragmentValue(description))
|
||||||
|
}
|
||||||
|
redirectWithFragment(c, frontendCallback, fragment)
|
||||||
|
}
|
||||||
|
|
||||||
|
func redirectWithFragment(c *gin.Context, frontendCallback string, fragment url.Values) {
|
||||||
|
u, err := url.Parse(frontendCallback)
|
||||||
|
if err != nil {
|
||||||
|
// 兜底:尽力跳转到默认页面,避免卡死在回调页。
|
||||||
|
c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if u.Scheme != "" && !strings.EqualFold(u.Scheme, "http") && !strings.EqualFold(u.Scheme, "https") {
|
||||||
|
c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
u.Fragment = fragment.Encode()
|
||||||
|
c.Header("Cache-Control", "no-store")
|
||||||
|
c.Header("Pragma", "no-cache")
|
||||||
|
c.Redirect(http.StatusFound, u.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func firstNonEmpty(values ...string) string {
|
||||||
|
for _, v := range values {
|
||||||
|
v = strings.TrimSpace(v)
|
||||||
|
if v != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseOAuthProviderError(body string) (providerErr string, providerDesc string) {
|
||||||
|
body = strings.TrimSpace(body)
|
||||||
|
if body == "" {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
|
||||||
|
providerErr = firstNonEmpty(
|
||||||
|
getGJSON(body, "error"),
|
||||||
|
getGJSON(body, "code"),
|
||||||
|
getGJSON(body, "error.code"),
|
||||||
|
)
|
||||||
|
providerDesc = firstNonEmpty(
|
||||||
|
getGJSON(body, "error_description"),
|
||||||
|
getGJSON(body, "error.message"),
|
||||||
|
getGJSON(body, "message"),
|
||||||
|
getGJSON(body, "detail"),
|
||||||
|
)
|
||||||
|
|
||||||
|
if providerErr != "" || providerDesc != "" {
|
||||||
|
return providerErr, providerDesc
|
||||||
|
}
|
||||||
|
|
||||||
|
values, err := url.ParseQuery(body)
|
||||||
|
if err != nil {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
providerErr = firstNonEmpty(values.Get("error"), values.Get("code"))
|
||||||
|
providerDesc = firstNonEmpty(values.Get("error_description"), values.Get("error_message"), values.Get("message"))
|
||||||
|
return providerErr, providerDesc
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseLinuxDoTokenResponse(body string) (*linuxDoTokenResponse, bool) {
|
||||||
|
body = strings.TrimSpace(body)
|
||||||
|
if body == "" {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
accessToken := strings.TrimSpace(getGJSON(body, "access_token"))
|
||||||
|
if accessToken != "" {
|
||||||
|
tokenType := strings.TrimSpace(getGJSON(body, "token_type"))
|
||||||
|
refreshToken := strings.TrimSpace(getGJSON(body, "refresh_token"))
|
||||||
|
scope := strings.TrimSpace(getGJSON(body, "scope"))
|
||||||
|
expiresIn := gjson.Get(body, "expires_in").Int()
|
||||||
|
return &linuxDoTokenResponse{
|
||||||
|
AccessToken: accessToken,
|
||||||
|
TokenType: tokenType,
|
||||||
|
ExpiresIn: expiresIn,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
Scope: scope,
|
||||||
|
}, true
|
||||||
|
}
|
||||||
|
|
||||||
|
values, err := url.ParseQuery(body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
accessToken = strings.TrimSpace(values.Get("access_token"))
|
||||||
|
if accessToken == "" {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
expiresIn := int64(0)
|
||||||
|
if raw := strings.TrimSpace(values.Get("expires_in")); raw != "" {
|
||||||
|
if v, err := strconv.ParseInt(raw, 10, 64); err == nil {
|
||||||
|
expiresIn = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &linuxDoTokenResponse{
|
||||||
|
AccessToken: accessToken,
|
||||||
|
TokenType: strings.TrimSpace(values.Get("token_type")),
|
||||||
|
ExpiresIn: expiresIn,
|
||||||
|
RefreshToken: strings.TrimSpace(values.Get("refresh_token")),
|
||||||
|
Scope: strings.TrimSpace(values.Get("scope")),
|
||||||
|
}, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func getGJSON(body string, path string) string {
|
||||||
|
path = strings.TrimSpace(path)
|
||||||
|
if path == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
res := gjson.Get(body, path)
|
||||||
|
if !res.Exists() {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return res.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateLogValue(value string, maxLen int) string {
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
if value == "" || maxLen <= 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if len(value) <= maxLen {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
value = value[:maxLen]
|
||||||
|
for !utf8.ValidString(value) {
|
||||||
|
value = value[:len(value)-1]
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
func singleLine(value string) string {
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
if value == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return strings.Join(strings.Fields(value), " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func sanitizeFrontendRedirectPath(path string) string {
|
||||||
|
path = strings.TrimSpace(path)
|
||||||
|
if path == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if len(path) > linuxDoOAuthMaxRedirectLen {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
// 只允许同源相对路径(避免开放重定向)。
|
||||||
|
if !strings.HasPrefix(path, "/") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(path, "//") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if strings.Contains(path, "://") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if strings.ContainsAny(path, "\r\n") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
|
||||||
|
func isRequestHTTPS(c *gin.Context) bool {
|
||||||
|
if c.Request.TLS != nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
proto := strings.ToLower(strings.TrimSpace(c.GetHeader("X-Forwarded-Proto")))
|
||||||
|
return proto == "https"
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeCookieValue(value string) string {
|
||||||
|
return base64.RawURLEncoding.EncodeToString([]byte(value))
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeCookieValue(value string) (string, error) {
|
||||||
|
raw, err := base64.RawURLEncoding.DecodeString(value)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(raw), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func readCookieDecoded(c *gin.Context, name string) (string, error) {
|
||||||
|
ck, err := c.Request.Cookie(name)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return decodeCookieValue(ck.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) {
|
||||||
|
http.SetCookie(c.Writer, &http.Cookie{
|
||||||
|
Name: name,
|
||||||
|
Value: value,
|
||||||
|
Path: linuxDoOAuthCookiePath,
|
||||||
|
MaxAge: maxAgeSec,
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: secure,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func clearCookie(c *gin.Context, name string, secure bool) {
|
||||||
|
http.SetCookie(c.Writer, &http.Cookie{
|
||||||
|
Name: name,
|
||||||
|
Value: "",
|
||||||
|
Path: linuxDoOAuthCookiePath,
|
||||||
|
MaxAge: -1,
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: secure,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateFragmentValue(value string) string {
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
if value == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if len(value) > linuxDoOAuthMaxFragmentValueLen {
|
||||||
|
value = value[:linuxDoOAuthMaxFragmentValueLen]
|
||||||
|
for !utf8.ValidString(value) {
|
||||||
|
value = value[:len(value)-1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildBearerAuthorization(tokenType, accessToken string) (string, error) {
|
||||||
|
tokenType = strings.TrimSpace(tokenType)
|
||||||
|
if tokenType == "" {
|
||||||
|
tokenType = "Bearer"
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(tokenType, "Bearer") {
|
||||||
|
return "", fmt.Errorf("unsupported token_type: %s", tokenType)
|
||||||
|
}
|
||||||
|
|
||||||
|
accessToken = strings.TrimSpace(accessToken)
|
||||||
|
if accessToken == "" {
|
||||||
|
return "", errors.New("missing access_token")
|
||||||
|
}
|
||||||
|
if strings.ContainsAny(accessToken, " \t\r\n") {
|
||||||
|
return "", errors.New("access_token contains whitespace")
|
||||||
|
}
|
||||||
|
return "Bearer " + accessToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSafeLinuxDoSubject(subject string) bool {
|
||||||
|
subject = strings.TrimSpace(subject)
|
||||||
|
if subject == "" || len(subject) > linuxDoOAuthMaxSubjectLen {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, r := range subject {
|
||||||
|
switch {
|
||||||
|
case r >= '0' && r <= '9':
|
||||||
|
case r >= 'a' && r <= 'z':
|
||||||
|
case r >= 'A' && r <= 'Z':
|
||||||
|
case r == '_' || r == '-':
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func linuxDoSyntheticEmail(subject string) string {
|
||||||
|
subject = strings.TrimSpace(subject)
|
||||||
|
if subject == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return "linuxdo-" + subject + service.LinuxDoConnectSyntheticEmailDomain
|
||||||
|
}
|
||||||
108
backend/internal/handler/auth_linuxdo_oauth_test.go
Normal file
108
backend/internal/handler/auth_linuxdo_oauth_test.go
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSanitizeFrontendRedirectPath(t *testing.T) {
|
||||||
|
require.Equal(t, "/dashboard", sanitizeFrontendRedirectPath("/dashboard"))
|
||||||
|
require.Equal(t, "/dashboard", sanitizeFrontendRedirectPath(" /dashboard "))
|
||||||
|
require.Equal(t, "", sanitizeFrontendRedirectPath("dashboard"))
|
||||||
|
require.Equal(t, "", sanitizeFrontendRedirectPath("//evil.com"))
|
||||||
|
require.Equal(t, "", sanitizeFrontendRedirectPath("https://evil.com"))
|
||||||
|
require.Equal(t, "", sanitizeFrontendRedirectPath("/\nfoo"))
|
||||||
|
|
||||||
|
long := "/" + strings.Repeat("a", linuxDoOAuthMaxRedirectLen)
|
||||||
|
require.Equal(t, "", sanitizeFrontendRedirectPath(long))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildBearerAuthorization(t *testing.T) {
|
||||||
|
auth, err := buildBearerAuthorization("", "token123")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "Bearer token123", auth)
|
||||||
|
|
||||||
|
auth, err = buildBearerAuthorization("bearer", "token123")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "Bearer token123", auth)
|
||||||
|
|
||||||
|
_, err = buildBearerAuthorization("MAC", "token123")
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
_, err = buildBearerAuthorization("Bearer", "token 123")
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLinuxDoParseUserInfoParsesIDAndUsername(t *testing.T) {
|
||||||
|
cfg := config.LinuxDoConnectConfig{
|
||||||
|
UserInfoURL: "https://connect.linux.do/api/user",
|
||||||
|
}
|
||||||
|
|
||||||
|
email, username, subject, err := linuxDoParseUserInfo(`{"id":123,"username":"alice"}`, cfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "123", subject)
|
||||||
|
require.Equal(t, "alice", username)
|
||||||
|
require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) {
|
||||||
|
cfg := config.LinuxDoConnectConfig{
|
||||||
|
UserInfoURL: "https://connect.linux.do/api/user",
|
||||||
|
}
|
||||||
|
|
||||||
|
email, username, subject, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "123", subject)
|
||||||
|
require.Equal(t, "linuxdo_123", username)
|
||||||
|
require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) {
|
||||||
|
cfg := config.LinuxDoConnectConfig{
|
||||||
|
UserInfoURL: "https://connect.linux.do/api/user",
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
tooLong := strings.Repeat("a", linuxDoOAuthMaxSubjectLen+1)
|
||||||
|
_, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg)
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseOAuthProviderErrorJSON(t *testing.T) {
|
||||||
|
code, desc := parseOAuthProviderError(`{"error":"invalid_client","error_description":"bad secret"}`)
|
||||||
|
require.Equal(t, "invalid_client", code)
|
||||||
|
require.Equal(t, "bad secret", desc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseOAuthProviderErrorForm(t *testing.T) {
|
||||||
|
code, desc := parseOAuthProviderError("error=invalid_request&error_description=Missing+code_verifier")
|
||||||
|
require.Equal(t, "invalid_request", code)
|
||||||
|
require.Equal(t, "Missing code_verifier", desc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseLinuxDoTokenResponseJSON(t *testing.T) {
|
||||||
|
token, ok := parseLinuxDoTokenResponse(`{"access_token":"t1","token_type":"Bearer","expires_in":3600,"scope":"user"}`)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "t1", token.AccessToken)
|
||||||
|
require.Equal(t, "Bearer", token.TokenType)
|
||||||
|
require.Equal(t, int64(3600), token.ExpiresIn)
|
||||||
|
require.Equal(t, "user", token.Scope)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseLinuxDoTokenResponseForm(t *testing.T) {
|
||||||
|
token, ok := parseLinuxDoTokenResponse("access_token=t2&token_type=bearer&expires_in=60")
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "t2", token.AccessToken)
|
||||||
|
require.Equal(t, "bearer", token.TokenType)
|
||||||
|
require.Equal(t, int64(60), token.ExpiresIn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSingleLineStripsWhitespace(t *testing.T) {
|
||||||
|
require.Equal(t, "hello world", singleLine("hello\r\nworld"))
|
||||||
|
require.Equal(t, "", singleLine("\n\t\r"))
|
||||||
|
}
|
||||||
@@ -53,16 +53,18 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return &APIKey{
|
return &APIKey{
|
||||||
ID: k.ID,
|
ID: k.ID,
|
||||||
UserID: k.UserID,
|
UserID: k.UserID,
|
||||||
Key: k.Key,
|
Key: k.Key,
|
||||||
Name: k.Name,
|
Name: k.Name,
|
||||||
GroupID: k.GroupID,
|
GroupID: k.GroupID,
|
||||||
Status: k.Status,
|
Status: k.Status,
|
||||||
CreatedAt: k.CreatedAt,
|
IPWhitelist: k.IPWhitelist,
|
||||||
UpdatedAt: k.UpdatedAt,
|
IPBlacklist: k.IPBlacklist,
|
||||||
User: UserFromServiceShallow(k.User),
|
CreatedAt: k.CreatedAt,
|
||||||
Group: GroupFromServiceShallow(k.Group),
|
UpdatedAt: k.UpdatedAt,
|
||||||
|
User: UserFromServiceShallow(k.User),
|
||||||
|
Group: GroupFromServiceShallow(k.Group),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -250,11 +252,12 @@ func AccountSummaryFromService(a *service.Account) *AccountSummary {
|
|||||||
|
|
||||||
// usageLogFromServiceBase is a helper that converts service UsageLog to DTO.
|
// usageLogFromServiceBase is a helper that converts service UsageLog to DTO.
|
||||||
// The account parameter allows caller to control what Account info is included.
|
// The account parameter allows caller to control what Account info is included.
|
||||||
func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary) *UsageLog {
|
// The includeIPAddress parameter controls whether to include the IP address (admin-only).
|
||||||
|
func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, includeIPAddress bool) *UsageLog {
|
||||||
if l == nil {
|
if l == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return &UsageLog{
|
result := &UsageLog{
|
||||||
ID: l.ID,
|
ID: l.ID,
|
||||||
UserID: l.UserID,
|
UserID: l.UserID,
|
||||||
APIKeyID: l.APIKeyID,
|
APIKeyID: l.APIKeyID,
|
||||||
@@ -290,21 +293,26 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary) *Usag
|
|||||||
Group: GroupFromServiceShallow(l.Group),
|
Group: GroupFromServiceShallow(l.Group),
|
||||||
Subscription: UserSubscriptionFromService(l.Subscription),
|
Subscription: UserSubscriptionFromService(l.Subscription),
|
||||||
}
|
}
|
||||||
|
// IP 地址仅对管理员可见
|
||||||
|
if includeIPAddress {
|
||||||
|
result.IPAddress = l.IPAddress
|
||||||
|
}
|
||||||
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// UsageLogFromService converts a service UsageLog to DTO for regular users.
|
// UsageLogFromService converts a service UsageLog to DTO for regular users.
|
||||||
// It excludes Account details - users should not see account information.
|
// It excludes Account details and IP address - users should not see these.
|
||||||
func UsageLogFromService(l *service.UsageLog) *UsageLog {
|
func UsageLogFromService(l *service.UsageLog) *UsageLog {
|
||||||
return usageLogFromServiceBase(l, nil)
|
return usageLogFromServiceBase(l, nil, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UsageLogFromServiceAdmin converts a service UsageLog to DTO for admin users.
|
// UsageLogFromServiceAdmin converts a service UsageLog to DTO for admin users.
|
||||||
// It includes minimal Account info (ID, Name only).
|
// It includes minimal Account info (ID, Name only) and IP address.
|
||||||
func UsageLogFromServiceAdmin(l *service.UsageLog) *UsageLog {
|
func UsageLogFromServiceAdmin(l *service.UsageLog) *UsageLog {
|
||||||
if l == nil {
|
if l == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account))
|
return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account), true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func SettingFromService(s *service.Setting) *Setting {
|
func SettingFromService(s *service.Setting) *Setting {
|
||||||
|
|||||||
@@ -17,6 +17,11 @@ type SystemSettings struct {
|
|||||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||||
TurnstileSecretKeyConfigured bool `json:"turnstile_secret_key_configured"`
|
TurnstileSecretKeyConfigured bool `json:"turnstile_secret_key_configured"`
|
||||||
|
|
||||||
|
LinuxDoConnectEnabled bool `json:"linuxdo_connect_enabled"`
|
||||||
|
LinuxDoConnectClientID string `json:"linuxdo_connect_client_id"`
|
||||||
|
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
|
||||||
|
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
|
||||||
|
|
||||||
SiteName string `json:"site_name"`
|
SiteName string `json:"site_name"`
|
||||||
SiteLogo string `json:"site_logo"`
|
SiteLogo string `json:"site_logo"`
|
||||||
SiteSubtitle string `json:"site_subtitle"`
|
SiteSubtitle string `json:"site_subtitle"`
|
||||||
@@ -50,5 +55,6 @@ type PublicSettings struct {
|
|||||||
APIBaseURL string `json:"api_base_url"`
|
APIBaseURL string `json:"api_base_url"`
|
||||||
ContactInfo string `json:"contact_info"`
|
ContactInfo string `json:"contact_info"`
|
||||||
DocURL string `json:"doc_url"`
|
DocURL string `json:"doc_url"`
|
||||||
|
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||||
Version string `json:"version"`
|
Version string `json:"version"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,14 +20,16 @@ type User struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type APIKey struct {
|
type APIKey struct {
|
||||||
ID int64 `json:"id"`
|
ID int64 `json:"id"`
|
||||||
UserID int64 `json:"user_id"`
|
UserID int64 `json:"user_id"`
|
||||||
Key string `json:"key"`
|
Key string `json:"key"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
GroupID *int64 `json:"group_id"`
|
GroupID *int64 `json:"group_id"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
IPWhitelist []string `json:"ip_whitelist"`
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
IPBlacklist []string `json:"ip_blacklist"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
|
||||||
User *User `json:"user,omitempty"`
|
User *User `json:"user,omitempty"`
|
||||||
Group *Group `json:"group,omitempty"`
|
Group *Group `json:"group,omitempty"`
|
||||||
@@ -187,6 +189,9 @@ type UsageLog struct {
|
|||||||
// User-Agent
|
// User-Agent
|
||||||
UserAgent *string `json:"user_agent"`
|
UserAgent *string `json:"user_agent"`
|
||||||
|
|
||||||
|
// IP 地址(仅管理员可见)
|
||||||
|
IPAddress *string `json:"ip_address,omitempty"`
|
||||||
|
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
|
||||||
User *User `json:"user,omitempty"`
|
User *User `json:"user,omitempty"`
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||||
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
@@ -114,6 +115,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
// 获取 User-Agent
|
// 获取 User-Agent
|
||||||
userAgent := c.Request.UserAgent()
|
userAgent := c.Request.UserAgent()
|
||||||
|
|
||||||
|
// 获取客户端 IP
|
||||||
|
clientIP := ip.GetClientIP(c)
|
||||||
|
|
||||||
// 0. 检查wait队列是否已满
|
// 0. 检查wait队列是否已满
|
||||||
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
||||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||||
@@ -273,7 +277,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 异步记录使用量(subscription已在函数开头获取)
|
// 异步记录使用量(subscription已在函数开头获取)
|
||||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua string) {
|
go func(result *service.ForwardResult, usedAccount *service.Account, ua string, cip string) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
@@ -283,10 +287,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
Account: usedAccount,
|
Account: usedAccount,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
UserAgent: ua,
|
UserAgent: ua,
|
||||||
|
IPAddress: cip,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
log.Printf("Record usage failed: %v", err)
|
log.Printf("Record usage failed: %v", err)
|
||||||
}
|
}
|
||||||
}(result, account, userAgent)
|
}(result, account, userAgent, clientIP)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -401,7 +406,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 异步记录使用量(subscription已在函数开头获取)
|
// 异步记录使用量(subscription已在函数开头获取)
|
||||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua string) {
|
go func(result *service.ForwardResult, usedAccount *service.Account, ua string, cip string) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
@@ -411,10 +416,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
Account: usedAccount,
|
Account: usedAccount,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
UserAgent: ua,
|
UserAgent: ua,
|
||||||
|
IPAddress: cip,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
log.Printf("Record usage failed: %v", err)
|
log.Printf("Record usage failed: %v", err)
|
||||||
}
|
}
|
||||||
}(result, account, userAgent)
|
}(result, account, userAgent, clientIP)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
@@ -167,6 +168,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
// 获取 User-Agent
|
// 获取 User-Agent
|
||||||
userAgent := c.Request.UserAgent()
|
userAgent := c.Request.UserAgent()
|
||||||
|
|
||||||
|
// 获取客户端 IP
|
||||||
|
clientIP := ip.GetClientIP(c)
|
||||||
|
|
||||||
// For Gemini native API, do not send Claude-style ping frames.
|
// For Gemini native API, do not send Claude-style ping frames.
|
||||||
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone, 0)
|
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone, 0)
|
||||||
|
|
||||||
@@ -307,7 +311,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 6) record usage async
|
// 6) record usage async
|
||||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua string) {
|
go func(result *service.ForwardResult, usedAccount *service.Account, ua string, cip string) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
@@ -317,10 +321,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
Account: usedAccount,
|
Account: usedAccount,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
UserAgent: ua,
|
UserAgent: ua,
|
||||||
|
IPAddress: cip,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
log.Printf("Record usage failed: %v", err)
|
log.Printf("Record usage failed: %v", err)
|
||||||
}
|
}
|
||||||
}(result, account, userAgent)
|
}(result, account, userAgent, clientIP)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
@@ -94,6 +95,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
|
|
||||||
// For non-Codex CLI requests, set default instructions
|
// For non-Codex CLI requests, set default instructions
|
||||||
userAgent := c.GetHeader("User-Agent")
|
userAgent := c.GetHeader("User-Agent")
|
||||||
|
|
||||||
|
// 获取客户端 IP
|
||||||
|
clientIP := ip.GetClientIP(c)
|
||||||
|
|
||||||
if !openai.IsCodexCLIRequest(userAgent) {
|
if !openai.IsCodexCLIRequest(userAgent) {
|
||||||
reqBody["instructions"] = openai.DefaultInstructions
|
reqBody["instructions"] = openai.DefaultInstructions
|
||||||
// Re-serialize body
|
// Re-serialize body
|
||||||
@@ -242,7 +247,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Async record usage
|
// Async record usage
|
||||||
go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua string) {
|
go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua string, cip string) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||||
@@ -252,10 +257,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
Account: usedAccount,
|
Account: usedAccount,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
UserAgent: ua,
|
UserAgent: ua,
|
||||||
|
IPAddress: cip,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
log.Printf("Record usage failed: %v", err)
|
log.Printf("Record usage failed: %v", err)
|
||||||
}
|
}
|
||||||
}(result, account, userAgent)
|
}(result, account, userAgent, clientIP)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
|||||||
APIBaseURL: settings.APIBaseURL,
|
APIBaseURL: settings.APIBaseURL,
|
||||||
ContactInfo: settings.ContactInfo,
|
ContactInfo: settings.ContactInfo,
|
||||||
DocURL: settings.DocURL,
|
DocURL: settings.DocURL,
|
||||||
|
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||||
Version: h.version,
|
Version: h.version,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
168
backend/internal/pkg/ip/ip.go
Normal file
168
backend/internal/pkg/ip/ip.go
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
// Package ip 提供客户端 IP 地址提取工具。
|
||||||
|
package ip
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetClientIP 从 Gin Context 中提取客户端真实 IP 地址。
|
||||||
|
// 按以下优先级检查 Header:
|
||||||
|
// 1. CF-Connecting-IP (Cloudflare)
|
||||||
|
// 2. X-Real-IP (Nginx)
|
||||||
|
// 3. X-Forwarded-For (取第一个非私有 IP)
|
||||||
|
// 4. c.ClientIP() (Gin 内置方法)
|
||||||
|
func GetClientIP(c *gin.Context) string {
|
||||||
|
// 1. Cloudflare
|
||||||
|
if ip := c.GetHeader("CF-Connecting-IP"); ip != "" {
|
||||||
|
return normalizeIP(ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Nginx X-Real-IP
|
||||||
|
if ip := c.GetHeader("X-Real-IP"); ip != "" {
|
||||||
|
return normalizeIP(ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. X-Forwarded-For (多个 IP 时取第一个公网 IP)
|
||||||
|
if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
|
||||||
|
ips := strings.Split(xff, ",")
|
||||||
|
for _, ip := range ips {
|
||||||
|
ip = strings.TrimSpace(ip)
|
||||||
|
if ip != "" && !isPrivateIP(ip) {
|
||||||
|
return normalizeIP(ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 如果都是私有 IP,返回第一个
|
||||||
|
if len(ips) > 0 {
|
||||||
|
return normalizeIP(strings.TrimSpace(ips[0]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Gin 内置方法
|
||||||
|
return normalizeIP(c.ClientIP())
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeIP 规范化 IP 地址,去除端口号和空格。
|
||||||
|
func normalizeIP(ip string) string {
|
||||||
|
ip = strings.TrimSpace(ip)
|
||||||
|
// 移除端口号(如 "192.168.1.1:8080" -> "192.168.1.1")
|
||||||
|
if host, _, err := net.SplitHostPort(ip); err == nil {
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
return ip
|
||||||
|
}
|
||||||
|
|
||||||
|
// isPrivateIP 检查 IP 是否为私有地址。
|
||||||
|
func isPrivateIP(ipStr string) bool {
|
||||||
|
ip := net.ParseIP(ipStr)
|
||||||
|
if ip == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 私有 IP 范围
|
||||||
|
privateBlocks := []string{
|
||||||
|
"10.0.0.0/8",
|
||||||
|
"172.16.0.0/12",
|
||||||
|
"192.168.0.0/16",
|
||||||
|
"127.0.0.0/8",
|
||||||
|
"::1/128",
|
||||||
|
"fc00::/7",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, block := range privateBlocks {
|
||||||
|
_, cidr, err := net.ParseCIDR(block)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if cidr.Contains(ip) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// MatchesPattern 检查 IP 是否匹配指定的模式(支持单个 IP 或 CIDR)。
|
||||||
|
// pattern 可以是:
|
||||||
|
// - 单个 IP: "192.168.1.100"
|
||||||
|
// - CIDR 范围: "192.168.1.0/24"
|
||||||
|
func MatchesPattern(clientIP, pattern string) bool {
|
||||||
|
ip := net.ParseIP(clientIP)
|
||||||
|
if ip == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尝试解析为 CIDR
|
||||||
|
if strings.Contains(pattern, "/") {
|
||||||
|
_, cidr, err := net.ParseCIDR(pattern)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return cidr.Contains(ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 作为单个 IP 处理
|
||||||
|
patternIP := net.ParseIP(pattern)
|
||||||
|
if patternIP == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return ip.Equal(patternIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MatchesAnyPattern 检查 IP 是否匹配任意一个模式。
|
||||||
|
func MatchesAnyPattern(clientIP string, patterns []string) bool {
|
||||||
|
for _, pattern := range patterns {
|
||||||
|
if MatchesPattern(clientIP, pattern) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckIPRestriction 检查 IP 是否被 API Key 的 IP 限制允许。
|
||||||
|
// 返回值:(是否允许, 拒绝原因)
|
||||||
|
// 逻辑:
|
||||||
|
// 1. 先检查黑名单,如果在黑名单中则直接拒绝
|
||||||
|
// 2. 如果白名单不为空,IP 必须在白名单中
|
||||||
|
// 3. 如果白名单为空,允许访问(除非被黑名单拒绝)
|
||||||
|
func CheckIPRestriction(clientIP string, whitelist, blacklist []string) (bool, string) {
|
||||||
|
// 规范化 IP
|
||||||
|
clientIP = normalizeIP(clientIP)
|
||||||
|
if clientIP == "" {
|
||||||
|
return false, "access denied"
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1. 检查黑名单
|
||||||
|
if len(blacklist) > 0 && MatchesAnyPattern(clientIP, blacklist) {
|
||||||
|
return false, "access denied"
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 检查白名单(如果设置了白名单,IP 必须在其中)
|
||||||
|
if len(whitelist) > 0 && !MatchesAnyPattern(clientIP, whitelist) {
|
||||||
|
return false, "access denied"
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateIPPattern 验证 IP 或 CIDR 格式是否有效。
|
||||||
|
func ValidateIPPattern(pattern string) bool {
|
||||||
|
if strings.Contains(pattern, "/") {
|
||||||
|
_, _, err := net.ParseCIDR(pattern)
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
return net.ParseIP(pattern) != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateIPPatterns 验证多个 IP 或 CIDR 格式。
|
||||||
|
// 返回无效的模式列表。
|
||||||
|
func ValidateIPPatterns(patterns []string) []string {
|
||||||
|
var invalid []string
|
||||||
|
for _, p := range patterns {
|
||||||
|
if !ValidateIPPattern(p) {
|
||||||
|
invalid = append(invalid, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return invalid
|
||||||
|
}
|
||||||
@@ -675,6 +675,40 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
payload := map[string]string{
|
||||||
|
"rate_limited_at": now.Format(time.RFC3339),
|
||||||
|
"rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339),
|
||||||
|
}
|
||||||
|
raw, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
path := "{antigravity_quota_scopes," + string(scope) + "}"
|
||||||
|
client := clientFromContext(ctx, r.client)
|
||||||
|
result, err := client.ExecContext(
|
||||||
|
ctx,
|
||||||
|
"UPDATE accounts SET extra = jsonb_set(COALESCE(extra, '{}'::jsonb), $1::text[], $2::jsonb, true), updated_at = NOW() WHERE id = $3 AND deleted_at IS NULL",
|
||||||
|
path,
|
||||||
|
raw,
|
||||||
|
id,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
affected, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if affected == 0 {
|
||||||
|
return service.ErrAccountNotFound
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||||
_, err := r.client.Account.Update().
|
_, err := r.client.Account.Update().
|
||||||
Where(dbaccount.IDEQ(id)).
|
Where(dbaccount.IDEQ(id)).
|
||||||
@@ -718,6 +752,27 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
|
||||||
|
client := clientFromContext(ctx, r.client)
|
||||||
|
result, err := client.ExecContext(
|
||||||
|
ctx,
|
||||||
|
"UPDATE accounts SET extra = COALESCE(extra, '{}'::jsonb) - 'antigravity_quota_scopes', updated_at = NOW() WHERE id = $1 AND deleted_at IS NULL",
|
||||||
|
id,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
affected, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if affected == 0 {
|
||||||
|
return service.ErrAccountNotFound
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||||
builder := r.client.Account.Update().
|
builder := r.client.Account.Update().
|
||||||
Where(dbaccount.IDEQ(id)).
|
Where(dbaccount.IDEQ(id)).
|
||||||
@@ -831,6 +886,11 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
|
|||||||
args = append(args, *updates.Status)
|
args = append(args, *updates.Status)
|
||||||
idx++
|
idx++
|
||||||
}
|
}
|
||||||
|
if updates.Schedulable != nil {
|
||||||
|
setClauses = append(setClauses, "schedulable = $"+itoa(idx))
|
||||||
|
args = append(args, *updates.Schedulable)
|
||||||
|
idx++
|
||||||
|
}
|
||||||
// JSONB 需要合并而非覆盖,使用 raw SQL 保持旧行为。
|
// JSONB 需要合并而非覆盖,使用 raw SQL 保持旧行为。
|
||||||
if len(updates.Credentials) > 0 {
|
if len(updates.Credentials) > 0 {
|
||||||
payload, err := json.Marshal(updates.Credentials)
|
payload, err := json.Marshal(updates.Credentials)
|
||||||
|
|||||||
@@ -26,13 +26,21 @@ func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) error {
|
func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) error {
|
||||||
created, err := r.client.APIKey.Create().
|
builder := r.client.APIKey.Create().
|
||||||
SetUserID(key.UserID).
|
SetUserID(key.UserID).
|
||||||
SetKey(key.Key).
|
SetKey(key.Key).
|
||||||
SetName(key.Name).
|
SetName(key.Name).
|
||||||
SetStatus(key.Status).
|
SetStatus(key.Status).
|
||||||
SetNillableGroupID(key.GroupID).
|
SetNillableGroupID(key.GroupID)
|
||||||
Save(ctx)
|
|
||||||
|
if len(key.IPWhitelist) > 0 {
|
||||||
|
builder.SetIPWhitelist(key.IPWhitelist)
|
||||||
|
}
|
||||||
|
if len(key.IPBlacklist) > 0 {
|
||||||
|
builder.SetIPBlacklist(key.IPBlacklist)
|
||||||
|
}
|
||||||
|
|
||||||
|
created, err := builder.Save(ctx)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
key.ID = created.ID
|
key.ID = created.ID
|
||||||
key.CreatedAt = created.CreatedAt
|
key.CreatedAt = created.CreatedAt
|
||||||
@@ -108,6 +116,18 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
|
|||||||
builder.ClearGroupID()
|
builder.ClearGroupID()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IP 限制字段
|
||||||
|
if len(key.IPWhitelist) > 0 {
|
||||||
|
builder.SetIPWhitelist(key.IPWhitelist)
|
||||||
|
} else {
|
||||||
|
builder.ClearIPWhitelist()
|
||||||
|
}
|
||||||
|
if len(key.IPBlacklist) > 0 {
|
||||||
|
builder.SetIPBlacklist(key.IPBlacklist)
|
||||||
|
} else {
|
||||||
|
builder.ClearIPBlacklist()
|
||||||
|
}
|
||||||
|
|
||||||
affected, err := builder.Save(ctx)
|
affected, err := builder.Save(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -268,14 +288,16 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
out := &service.APIKey{
|
out := &service.APIKey{
|
||||||
ID: m.ID,
|
ID: m.ID,
|
||||||
UserID: m.UserID,
|
UserID: m.UserID,
|
||||||
Key: m.Key,
|
Key: m.Key,
|
||||||
Name: m.Name,
|
Name: m.Name,
|
||||||
Status: m.Status,
|
Status: m.Status,
|
||||||
CreatedAt: m.CreatedAt,
|
IPWhitelist: m.IPWhitelist,
|
||||||
UpdatedAt: m.UpdatedAt,
|
IPBlacklist: m.IPBlacklist,
|
||||||
GroupID: m.GroupID,
|
CreatedAt: m.CreatedAt,
|
||||||
|
UpdatedAt: m.UpdatedAt,
|
||||||
|
GroupID: m.GroupID,
|
||||||
}
|
}
|
||||||
if m.Edges.User != nil {
|
if m.Edges.User != nil {
|
||||||
out.User = userEntityToService(m.Edges.User)
|
out.User = userEntityToService(m.Edges.User)
|
||||||
|
|||||||
@@ -112,10 +112,10 @@ func (r *groupRepository) Delete(ctx context.Context, id int64) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
|
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
|
||||||
return r.ListWithFilters(ctx, params, "", "", nil)
|
return r.ListWithFilters(ctx, params, "", "", "", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
|
func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
|
||||||
q := r.client.Group.Query()
|
q := r.client.Group.Query()
|
||||||
|
|
||||||
if platform != "" {
|
if platform != "" {
|
||||||
@@ -124,6 +124,12 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
|
|||||||
if status != "" {
|
if status != "" {
|
||||||
q = q.Where(group.StatusEQ(status))
|
q = q.Where(group.StatusEQ(status))
|
||||||
}
|
}
|
||||||
|
if search != "" {
|
||||||
|
q = q.Where(group.Or(
|
||||||
|
group.NameContainsFold(search),
|
||||||
|
group.DescriptionContainsFold(search),
|
||||||
|
))
|
||||||
|
}
|
||||||
if isExclusive != nil {
|
if isExclusive != nil {
|
||||||
q = q.Where(group.IsExclusiveEQ(*isExclusive))
|
q = q.Where(group.IsExclusiveEQ(*isExclusive))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -131,6 +131,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() {
|
|||||||
pagination.PaginationParams{Page: 1, PageSize: 10},
|
pagination.PaginationParams{Page: 1, PageSize: 10},
|
||||||
service.PlatformOpenAI,
|
service.PlatformOpenAI,
|
||||||
"",
|
"",
|
||||||
|
"",
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
s.Require().NoError(err, "ListWithFilters base")
|
s.Require().NoError(err, "ListWithFilters base")
|
||||||
@@ -152,7 +153,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() {
|
|||||||
SubscriptionType: service.SubscriptionTypeStandard,
|
SubscriptionType: service.SubscriptionTypeStandard,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", nil)
|
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", "", nil)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
s.Require().Len(groups, len(baseGroups)+1)
|
s.Require().Len(groups, len(baseGroups)+1)
|
||||||
// Verify all groups are OpenAI platform
|
// Verify all groups are OpenAI platform
|
||||||
@@ -179,7 +180,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Status() {
|
|||||||
SubscriptionType: service.SubscriptionTypeStandard,
|
SubscriptionType: service.SubscriptionTypeStandard,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, nil)
|
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, "", nil)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
s.Require().Len(groups, 1)
|
s.Require().Len(groups, 1)
|
||||||
s.Require().Equal(service.StatusDisabled, groups[0].Status)
|
s.Require().Equal(service.StatusDisabled, groups[0].Status)
|
||||||
@@ -204,12 +205,117 @@ func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
|
|||||||
}))
|
}))
|
||||||
|
|
||||||
isExclusive := true
|
isExclusive := true
|
||||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", &isExclusive)
|
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", &isExclusive)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
s.Require().Len(groups, 1)
|
s.Require().Len(groups, 1)
|
||||||
s.Require().True(groups[0].IsExclusive)
|
s.Require().True(groups[0].IsExclusive)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *GroupRepoSuite) TestListWithFilters_Search() {
|
||||||
|
newRepo := func() (*groupRepository, context.Context) {
|
||||||
|
tx := testEntTx(s.T())
|
||||||
|
return newGroupRepositoryWithSQL(tx.Client(), tx), context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
containsID := func(groups []service.Group, id int64) bool {
|
||||||
|
for i := range groups {
|
||||||
|
if groups[i].ID == id {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
mustCreate := func(repo *groupRepository, ctx context.Context, g *service.Group) *service.Group {
|
||||||
|
s.Require().NoError(repo.Create(ctx, g))
|
||||||
|
s.Require().NotZero(g.ID)
|
||||||
|
return g
|
||||||
|
}
|
||||||
|
|
||||||
|
newGroup := func(name string) *service.Group {
|
||||||
|
return &service.Group{
|
||||||
|
Name: name,
|
||||||
|
Platform: service.PlatformAnthropic,
|
||||||
|
RateMultiplier: 1.0,
|
||||||
|
IsExclusive: false,
|
||||||
|
Status: service.StatusActive,
|
||||||
|
SubscriptionType: service.SubscriptionTypeStandard,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.Run("search_name_should_match", func() {
|
||||||
|
repo, ctx := newRepo()
|
||||||
|
|
||||||
|
target := mustCreate(repo, ctx, newGroup("it-group-search-name-target"))
|
||||||
|
other := mustCreate(repo, ctx, newGroup("it-group-search-name-other"))
|
||||||
|
|
||||||
|
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "name-target", nil)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().True(containsID(groups, target.ID), "expected target group to match by name")
|
||||||
|
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
|
||||||
|
})
|
||||||
|
|
||||||
|
s.Run("search_description_should_match", func() {
|
||||||
|
repo, ctx := newRepo()
|
||||||
|
|
||||||
|
target := newGroup("it-group-search-desc-target")
|
||||||
|
target.Description = "something about desc-needle in here"
|
||||||
|
target = mustCreate(repo, ctx, target)
|
||||||
|
|
||||||
|
other := newGroup("it-group-search-desc-other")
|
||||||
|
other.Description = "nothing to see here"
|
||||||
|
other = mustCreate(repo, ctx, other)
|
||||||
|
|
||||||
|
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "desc-needle", nil)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().True(containsID(groups, target.ID), "expected target group to match by description")
|
||||||
|
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
|
||||||
|
})
|
||||||
|
|
||||||
|
s.Run("search_nonexistent_should_return_empty", func() {
|
||||||
|
repo, ctx := newRepo()
|
||||||
|
|
||||||
|
_ = mustCreate(repo, ctx, newGroup("it-group-search-nonexistent-baseline"))
|
||||||
|
|
||||||
|
search := s.T().Name() + "__no_such_group__"
|
||||||
|
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", search, nil)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Empty(groups)
|
||||||
|
})
|
||||||
|
|
||||||
|
s.Run("search_should_be_case_insensitive", func() {
|
||||||
|
repo, ctx := newRepo()
|
||||||
|
|
||||||
|
target := mustCreate(repo, ctx, newGroup("MiXeDCaSe-Needle"))
|
||||||
|
other := mustCreate(repo, ctx, newGroup("it-group-search-case-other"))
|
||||||
|
|
||||||
|
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "mixedcase-needle", nil)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().True(containsID(groups, target.ID), "expected case-insensitive match")
|
||||||
|
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
|
||||||
|
})
|
||||||
|
|
||||||
|
s.Run("search_should_escape_like_wildcards", func() {
|
||||||
|
repo, ctx := newRepo()
|
||||||
|
|
||||||
|
percentTarget := mustCreate(repo, ctx, newGroup("it-group-search-100%-target"))
|
||||||
|
percentOther := mustCreate(repo, ctx, newGroup("it-group-search-100X-other"))
|
||||||
|
|
||||||
|
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "100%", nil)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().True(containsID(groups, percentTarget.ID), "expected literal %% match")
|
||||||
|
s.Require().False(containsID(groups, percentOther.ID), "expected %% not to act as wildcard")
|
||||||
|
|
||||||
|
underscoreTarget := mustCreate(repo, ctx, newGroup("it-group-search-ab_cd-target"))
|
||||||
|
underscoreOther := mustCreate(repo, ctx, newGroup("it-group-search-abXcd-other"))
|
||||||
|
|
||||||
|
groups, _, err = repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "ab_cd", nil)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().True(containsID(groups, underscoreTarget.ID), "expected literal _ match")
|
||||||
|
s.Require().False(containsID(groups, underscoreOther.ID), "expected _ not to act as wildcard")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
|
func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
|
||||||
g1 := &service.Group{
|
g1 := &service.Group{
|
||||||
Name: "g1",
|
Name: "g1",
|
||||||
@@ -244,7 +350,7 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
|
|||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
|
|
||||||
isExclusive := true
|
isExclusive := true
|
||||||
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, &isExclusive)
|
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, "", &isExclusive)
|
||||||
s.Require().NoError(err, "ListWithFilters")
|
s.Require().NoError(err, "ListWithFilters")
|
||||||
s.Require().Equal(int64(1), page.Total)
|
s.Require().Equal(int64(1), page.Total)
|
||||||
s.Require().Len(groups, 1)
|
s.Require().Len(groups, 1)
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import (
|
|||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
)
|
)
|
||||||
|
|
||||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, image_count, image_size, created_at"
|
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at"
|
||||||
|
|
||||||
type usageLogRepository struct {
|
type usageLogRepository struct {
|
||||||
client *dbent.Client
|
client *dbent.Client
|
||||||
@@ -110,6 +110,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
|||||||
duration_ms,
|
duration_ms,
|
||||||
first_token_ms,
|
first_token_ms,
|
||||||
user_agent,
|
user_agent,
|
||||||
|
ip_address,
|
||||||
image_count,
|
image_count,
|
||||||
image_size,
|
image_size,
|
||||||
created_at
|
created_at
|
||||||
@@ -119,7 +120,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
|||||||
$8, $9, $10, $11,
|
$8, $9, $10, $11,
|
||||||
$12, $13,
|
$12, $13,
|
||||||
$14, $15, $16, $17, $18, $19,
|
$14, $15, $16, $17, $18, $19,
|
||||||
$20, $21, $22, $23, $24, $25, $26, $27, $28
|
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29
|
||||||
)
|
)
|
||||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||||
RETURNING id, created_at
|
RETURNING id, created_at
|
||||||
@@ -130,6 +131,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
|||||||
duration := nullInt(log.DurationMs)
|
duration := nullInt(log.DurationMs)
|
||||||
firstToken := nullInt(log.FirstTokenMs)
|
firstToken := nullInt(log.FirstTokenMs)
|
||||||
userAgent := nullString(log.UserAgent)
|
userAgent := nullString(log.UserAgent)
|
||||||
|
ipAddress := nullString(log.IPAddress)
|
||||||
imageSize := nullString(log.ImageSize)
|
imageSize := nullString(log.ImageSize)
|
||||||
|
|
||||||
var requestIDArg any
|
var requestIDArg any
|
||||||
@@ -163,6 +165,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
|||||||
duration,
|
duration,
|
||||||
firstToken,
|
firstToken,
|
||||||
userAgent,
|
userAgent,
|
||||||
|
ipAddress,
|
||||||
log.ImageCount,
|
log.ImageCount,
|
||||||
imageSize,
|
imageSize,
|
||||||
createdAt,
|
createdAt,
|
||||||
@@ -1873,6 +1876,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
|||||||
durationMs sql.NullInt64
|
durationMs sql.NullInt64
|
||||||
firstTokenMs sql.NullInt64
|
firstTokenMs sql.NullInt64
|
||||||
userAgent sql.NullString
|
userAgent sql.NullString
|
||||||
|
ipAddress sql.NullString
|
||||||
imageCount int
|
imageCount int
|
||||||
imageSize sql.NullString
|
imageSize sql.NullString
|
||||||
createdAt time.Time
|
createdAt time.Time
|
||||||
@@ -1905,6 +1909,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
|||||||
&durationMs,
|
&durationMs,
|
||||||
&firstTokenMs,
|
&firstTokenMs,
|
||||||
&userAgent,
|
&userAgent,
|
||||||
|
&ipAddress,
|
||||||
&imageCount,
|
&imageCount,
|
||||||
&imageSize,
|
&imageSize,
|
||||||
&createdAt,
|
&createdAt,
|
||||||
@@ -1959,6 +1964,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
|||||||
if userAgent.Valid {
|
if userAgent.Valid {
|
||||||
log.UserAgent = &userAgent.String
|
log.UserAgent = &userAgent.String
|
||||||
}
|
}
|
||||||
|
if ipAddress.Valid {
|
||||||
|
log.IPAddress = &ipAddress.String
|
||||||
|
}
|
||||||
if imageSize.Valid {
|
if imageSize.Valid {
|
||||||
log.ImageSize = &imageSize.String
|
log.ImageSize = &imageSize.String
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -82,6 +82,8 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"name": "Key One",
|
"name": "Key One",
|
||||||
"group_id": null,
|
"group_id": null,
|
||||||
"status": "active",
|
"status": "active",
|
||||||
|
"ip_whitelist": null,
|
||||||
|
"ip_blacklist": null,
|
||||||
"created_at": "2025-01-02T03:04:05Z",
|
"created_at": "2025-01-02T03:04:05Z",
|
||||||
"updated_at": "2025-01-02T03:04:05Z"
|
"updated_at": "2025-01-02T03:04:05Z"
|
||||||
}
|
}
|
||||||
@@ -116,6 +118,8 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"name": "Key One",
|
"name": "Key One",
|
||||||
"group_id": null,
|
"group_id": null,
|
||||||
"status": "active",
|
"status": "active",
|
||||||
|
"ip_whitelist": null,
|
||||||
|
"ip_blacklist": null,
|
||||||
"created_at": "2025-01-02T03:04:05Z",
|
"created_at": "2025-01-02T03:04:05Z",
|
||||||
"updated_at": "2025-01-02T03:04:05Z"
|
"updated_at": "2025-01-02T03:04:05Z"
|
||||||
}
|
}
|
||||||
@@ -304,6 +308,10 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"turnstile_enabled": true,
|
"turnstile_enabled": true,
|
||||||
"turnstile_site_key": "site-key",
|
"turnstile_site_key": "site-key",
|
||||||
"turnstile_secret_key_configured": true,
|
"turnstile_secret_key_configured": true,
|
||||||
|
"linuxdo_connect_enabled": false,
|
||||||
|
"linuxdo_connect_client_id": "",
|
||||||
|
"linuxdo_connect_client_secret_configured": false,
|
||||||
|
"linuxdo_connect_redirect_url": "",
|
||||||
"site_name": "Sub2API",
|
"site_name": "Sub2API",
|
||||||
"site_logo": "",
|
"site_logo": "",
|
||||||
"site_subtitle": "Subtitle",
|
"site_subtitle": "Subtitle",
|
||||||
@@ -390,7 +398,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
|||||||
settingRepo := newStubSettingRepo()
|
settingRepo := newStubSettingRepo()
|
||||||
settingService := service.NewSettingService(settingRepo, cfg)
|
settingService := service.NewSettingService(settingRepo, cfg)
|
||||||
|
|
||||||
authHandler := handler.NewAuthHandler(cfg, nil, userService)
|
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService)
|
||||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||||
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil)
|
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil)
|
||||||
@@ -583,7 +591,7 @@ func (stubGroupRepo) List(ctx context.Context, params pagination.PaginationParam
|
|||||||
return nil, nil, errors.New("not implemented")
|
return nil, nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
|
func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
|
||||||
return nil, nil, errors.New("not implemented")
|
return nil, nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -71,6 +72,17 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检查 IP 限制(白名单/黑名单)
|
||||||
|
// 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制
|
||||||
|
if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 {
|
||||||
|
clientIP := ip.GetClientIP(c)
|
||||||
|
allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist)
|
||||||
|
if !allowed {
|
||||||
|
AbortWithError(c, 403, "ACCESS_DENIED", "Access denied")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 检查关联的用户
|
// 检查关联的用户
|
||||||
if apiKey.User == nil {
|
if apiKey.User == nil {
|
||||||
AbortWithError(c, 401, "USER_NOT_FOUND", "User associated with API key not found")
|
AbortWithError(c, 401, "USER_NOT_FOUND", "User associated with API key not found")
|
||||||
|
|||||||
@@ -19,6 +19,8 @@ func RegisterAuthRoutes(
|
|||||||
auth.POST("/register", h.Auth.Register)
|
auth.POST("/register", h.Auth.Register)
|
||||||
auth.POST("/login", h.Auth.Login)
|
auth.POST("/login", h.Auth.Login)
|
||||||
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
|
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
|
||||||
|
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
|
||||||
|
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 公开设置(无需认证)
|
// 公开设置(无需认证)
|
||||||
|
|||||||
@@ -49,10 +49,12 @@ type AccountRepository interface {
|
|||||||
ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
|
ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
|
||||||
|
|
||||||
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
|
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
|
||||||
|
SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error
|
||||||
SetOverloaded(ctx context.Context, id int64, until time.Time) error
|
SetOverloaded(ctx context.Context, id int64, until time.Time) error
|
||||||
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
|
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
|
||||||
ClearTempUnschedulable(ctx context.Context, id int64) error
|
ClearTempUnschedulable(ctx context.Context, id int64) error
|
||||||
ClearRateLimit(ctx context.Context, id int64) error
|
ClearRateLimit(ctx context.Context, id int64) error
|
||||||
|
ClearAntigravityQuotaScopes(ctx context.Context, id int64) error
|
||||||
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
|
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
|
||||||
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
|
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
|
||||||
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
|
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
|
||||||
@@ -66,6 +68,7 @@ type AccountBulkUpdate struct {
|
|||||||
Concurrency *int
|
Concurrency *int
|
||||||
Priority *int
|
Priority *int
|
||||||
Status *string
|
Status *string
|
||||||
|
Schedulable *bool
|
||||||
Credentials map[string]any
|
Credentials map[string]any
|
||||||
Extra map[string]any
|
Extra map[string]any
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -139,6 +139,10 @@ func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt
|
|||||||
panic("unexpected SetRateLimited call")
|
panic("unexpected SetRateLimited call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *accountRepoStub) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
|
||||||
|
panic("unexpected SetAntigravityQuotaScopeLimit call")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||||
panic("unexpected SetOverloaded call")
|
panic("unexpected SetOverloaded call")
|
||||||
}
|
}
|
||||||
@@ -155,6 +159,10 @@ func (s *accountRepoStub) ClearRateLimit(ctx context.Context, id int64) error {
|
|||||||
panic("unexpected ClearRateLimit call")
|
panic("unexpected ClearRateLimit call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *accountRepoStub) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
|
||||||
|
panic("unexpected ClearAntigravityQuotaScopes call")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *accountRepoStub) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
func (s *accountRepoStub) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||||
panic("unexpected UpdateSessionWindow call")
|
panic("unexpected UpdateSessionWindow call")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ type AdminService interface {
|
|||||||
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
|
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
|
||||||
|
|
||||||
// Group management
|
// Group management
|
||||||
ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error)
|
ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error)
|
||||||
GetAllGroups(ctx context.Context) ([]Group, error)
|
GetAllGroups(ctx context.Context) ([]Group, error)
|
||||||
GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error)
|
GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error)
|
||||||
GetGroup(ctx context.Context, id int64) (*Group, error)
|
GetGroup(ctx context.Context, id int64) (*Group, error)
|
||||||
@@ -168,6 +168,7 @@ type BulkUpdateAccountsInput struct {
|
|||||||
Concurrency *int
|
Concurrency *int
|
||||||
Priority *int
|
Priority *int
|
||||||
Status string
|
Status string
|
||||||
|
Schedulable *bool
|
||||||
GroupIDs *[]int64
|
GroupIDs *[]int64
|
||||||
Credentials map[string]any
|
Credentials map[string]any
|
||||||
Extra map[string]any
|
Extra map[string]any
|
||||||
@@ -478,9 +479,9 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Group management implementations
|
// Group management implementations
|
||||||
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error) {
|
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) {
|
||||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||||
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive)
|
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, search, isExclusive)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
@@ -910,6 +911,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
|||||||
if input.Status != "" {
|
if input.Status != "" {
|
||||||
repoUpdates.Status = &input.Status
|
repoUpdates.Status = &input.Status
|
||||||
}
|
}
|
||||||
|
if input.Schedulable != nil {
|
||||||
|
repoUpdates.Schedulable = input.Schedulable
|
||||||
|
}
|
||||||
|
|
||||||
// Run bulk update for column/jsonb fields first.
|
// Run bulk update for column/jsonb fields first.
|
||||||
if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil {
|
if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil {
|
||||||
|
|||||||
@@ -124,7 +124,7 @@ func (s *groupRepoStub) List(ctx context.Context, params pagination.PaginationPa
|
|||||||
panic("unexpected List call")
|
panic("unexpected List call")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *groupRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
|
func (s *groupRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||||
panic("unexpected ListWithFilters call")
|
panic("unexpected ListWithFilters call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,16 @@ type groupRepoStubForAdmin struct {
|
|||||||
updated *Group // 记录 Update 调用的参数
|
updated *Group // 记录 Update 调用的参数
|
||||||
getByID *Group // GetByID 返回值
|
getByID *Group // GetByID 返回值
|
||||||
getErr error // GetByID 返回的错误
|
getErr error // GetByID 返回的错误
|
||||||
|
|
||||||
|
listWithFiltersCalls int
|
||||||
|
listWithFiltersParams pagination.PaginationParams
|
||||||
|
listWithFiltersPlatform string
|
||||||
|
listWithFiltersStatus string
|
||||||
|
listWithFiltersSearch string
|
||||||
|
listWithFiltersIsExclusive *bool
|
||||||
|
listWithFiltersGroups []Group
|
||||||
|
listWithFiltersResult *pagination.PaginationResult
|
||||||
|
listWithFiltersErr error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *groupRepoStubForAdmin) Create(_ context.Context, g *Group) error {
|
func (s *groupRepoStubForAdmin) Create(_ context.Context, g *Group) error {
|
||||||
@@ -47,8 +57,28 @@ func (s *groupRepoStubForAdmin) List(_ context.Context, _ pagination.PaginationP
|
|||||||
panic("unexpected List call")
|
panic("unexpected List call")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *groupRepoStubForAdmin) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) {
|
func (s *groupRepoStubForAdmin) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||||
panic("unexpected ListWithFilters call")
|
s.listWithFiltersCalls++
|
||||||
|
s.listWithFiltersParams = params
|
||||||
|
s.listWithFiltersPlatform = platform
|
||||||
|
s.listWithFiltersStatus = status
|
||||||
|
s.listWithFiltersSearch = search
|
||||||
|
s.listWithFiltersIsExclusive = isExclusive
|
||||||
|
|
||||||
|
if s.listWithFiltersErr != nil {
|
||||||
|
return nil, nil, s.listWithFiltersErr
|
||||||
|
}
|
||||||
|
|
||||||
|
result := s.listWithFiltersResult
|
||||||
|
if result == nil {
|
||||||
|
result = &pagination.PaginationResult{
|
||||||
|
Total: int64(len(s.listWithFiltersGroups)),
|
||||||
|
Page: params.Page,
|
||||||
|
PageSize: params.PageSize,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.listWithFiltersGroups, result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *groupRepoStubForAdmin) ListActive(_ context.Context) ([]Group, error) {
|
func (s *groupRepoStubForAdmin) ListActive(_ context.Context) ([]Group, error) {
|
||||||
@@ -195,3 +225,68 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
|
|||||||
require.InDelta(t, 0.15, *repo.updated.ImagePrice2K, 0.0001) // 原值保持
|
require.InDelta(t, 0.15, *repo.updated.ImagePrice2K, 0.0001) // 原值保持
|
||||||
require.Nil(t, repo.updated.ImagePrice4K)
|
require.Nil(t, repo.updated.ImagePrice4K)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAdminService_ListGroups_WithSearch(t *testing.T) {
|
||||||
|
// 测试:
|
||||||
|
// 1. search 参数正常传递到 repository 层
|
||||||
|
// 2. search 为空字符串时的行为
|
||||||
|
// 3. search 与其他过滤条件组合使用
|
||||||
|
|
||||||
|
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||||||
|
repo := &groupRepoStubForAdmin{
|
||||||
|
listWithFiltersGroups: []Group{{ID: 1, Name: "alpha"}},
|
||||||
|
listWithFiltersResult: &pagination.PaginationResult{Total: 1},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{groupRepo: repo}
|
||||||
|
|
||||||
|
groups, total, err := svc.ListGroups(context.Background(), 1, 20, "", "", "alpha", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, int64(1), total)
|
||||||
|
require.Equal(t, []Group{{ID: 1, Name: "alpha"}}, groups)
|
||||||
|
|
||||||
|
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||||
|
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
|
||||||
|
require.Equal(t, "alpha", repo.listWithFiltersSearch)
|
||||||
|
require.Nil(t, repo.listWithFiltersIsExclusive)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("search 为空字符串时传递空字符串", func(t *testing.T) {
|
||||||
|
repo := &groupRepoStubForAdmin{
|
||||||
|
listWithFiltersGroups: []Group{},
|
||||||
|
listWithFiltersResult: &pagination.PaginationResult{Total: 0},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{groupRepo: repo}
|
||||||
|
|
||||||
|
groups, total, err := svc.ListGroups(context.Background(), 2, 10, "", "", "", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Empty(t, groups)
|
||||||
|
require.Equal(t, int64(0), total)
|
||||||
|
|
||||||
|
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||||
|
require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersParams)
|
||||||
|
require.Equal(t, "", repo.listWithFiltersSearch)
|
||||||
|
require.Nil(t, repo.listWithFiltersIsExclusive)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("search 与其他过滤条件组合使用", func(t *testing.T) {
|
||||||
|
isExclusive := true
|
||||||
|
repo := &groupRepoStubForAdmin{
|
||||||
|
listWithFiltersGroups: []Group{{ID: 2, Name: "beta"}},
|
||||||
|
listWithFiltersResult: &pagination.PaginationResult{Total: 42},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{groupRepo: repo}
|
||||||
|
|
||||||
|
groups, total, err := svc.ListGroups(context.Background(), 3, 50, PlatformAntigravity, StatusActive, "beta", &isExclusive)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, int64(42), total)
|
||||||
|
require.Equal(t, []Group{{ID: 2, Name: "beta"}}, groups)
|
||||||
|
|
||||||
|
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||||
|
require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams)
|
||||||
|
require.Equal(t, PlatformAntigravity, repo.listWithFiltersPlatform)
|
||||||
|
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
|
||||||
|
require.Equal(t, "beta", repo.listWithFiltersSearch)
|
||||||
|
require.NotNil(t, repo.listWithFiltersIsExclusive)
|
||||||
|
require.True(t, *repo.listWithFiltersIsExclusive)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
238
backend/internal/service/admin_service_search_test.go
Normal file
238
backend/internal/service/admin_service_search_test.go
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type accountRepoStubForAdminList struct {
|
||||||
|
accountRepoStub
|
||||||
|
|
||||||
|
listWithFiltersCalls int
|
||||||
|
listWithFiltersParams pagination.PaginationParams
|
||||||
|
listWithFiltersPlatform string
|
||||||
|
listWithFiltersType string
|
||||||
|
listWithFiltersStatus string
|
||||||
|
listWithFiltersSearch string
|
||||||
|
listWithFiltersAccounts []Account
|
||||||
|
listWithFiltersResult *pagination.PaginationResult
|
||||||
|
listWithFiltersErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
|
||||||
|
s.listWithFiltersCalls++
|
||||||
|
s.listWithFiltersParams = params
|
||||||
|
s.listWithFiltersPlatform = platform
|
||||||
|
s.listWithFiltersType = accountType
|
||||||
|
s.listWithFiltersStatus = status
|
||||||
|
s.listWithFiltersSearch = search
|
||||||
|
|
||||||
|
if s.listWithFiltersErr != nil {
|
||||||
|
return nil, nil, s.listWithFiltersErr
|
||||||
|
}
|
||||||
|
|
||||||
|
result := s.listWithFiltersResult
|
||||||
|
if result == nil {
|
||||||
|
result = &pagination.PaginationResult{
|
||||||
|
Total: int64(len(s.listWithFiltersAccounts)),
|
||||||
|
Page: params.Page,
|
||||||
|
PageSize: params.PageSize,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.listWithFiltersAccounts, result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type proxyRepoStubForAdminList struct {
|
||||||
|
proxyRepoStub
|
||||||
|
|
||||||
|
listWithFiltersCalls int
|
||||||
|
listWithFiltersParams pagination.PaginationParams
|
||||||
|
listWithFiltersProtocol string
|
||||||
|
listWithFiltersStatus string
|
||||||
|
listWithFiltersSearch string
|
||||||
|
listWithFiltersProxies []Proxy
|
||||||
|
listWithFiltersResult *pagination.PaginationResult
|
||||||
|
listWithFiltersErr error
|
||||||
|
|
||||||
|
listWithFiltersAndAccountCountCalls int
|
||||||
|
listWithFiltersAndAccountCountParams pagination.PaginationParams
|
||||||
|
listWithFiltersAndAccountCountProtocol string
|
||||||
|
listWithFiltersAndAccountCountStatus string
|
||||||
|
listWithFiltersAndAccountCountSearch string
|
||||||
|
listWithFiltersAndAccountCountProxies []ProxyWithAccountCount
|
||||||
|
listWithFiltersAndAccountCountResult *pagination.PaginationResult
|
||||||
|
listWithFiltersAndAccountCountErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *proxyRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error) {
|
||||||
|
s.listWithFiltersCalls++
|
||||||
|
s.listWithFiltersParams = params
|
||||||
|
s.listWithFiltersProtocol = protocol
|
||||||
|
s.listWithFiltersStatus = status
|
||||||
|
s.listWithFiltersSearch = search
|
||||||
|
|
||||||
|
if s.listWithFiltersErr != nil {
|
||||||
|
return nil, nil, s.listWithFiltersErr
|
||||||
|
}
|
||||||
|
|
||||||
|
result := s.listWithFiltersResult
|
||||||
|
if result == nil {
|
||||||
|
result = &pagination.PaginationResult{
|
||||||
|
Total: int64(len(s.listWithFiltersProxies)),
|
||||||
|
Page: params.Page,
|
||||||
|
PageSize: params.PageSize,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.listWithFiltersProxies, result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *proxyRepoStubForAdminList) ListWithFiltersAndAccountCount(_ context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) {
|
||||||
|
s.listWithFiltersAndAccountCountCalls++
|
||||||
|
s.listWithFiltersAndAccountCountParams = params
|
||||||
|
s.listWithFiltersAndAccountCountProtocol = protocol
|
||||||
|
s.listWithFiltersAndAccountCountStatus = status
|
||||||
|
s.listWithFiltersAndAccountCountSearch = search
|
||||||
|
|
||||||
|
if s.listWithFiltersAndAccountCountErr != nil {
|
||||||
|
return nil, nil, s.listWithFiltersAndAccountCountErr
|
||||||
|
}
|
||||||
|
|
||||||
|
result := s.listWithFiltersAndAccountCountResult
|
||||||
|
if result == nil {
|
||||||
|
result = &pagination.PaginationResult{
|
||||||
|
Total: int64(len(s.listWithFiltersAndAccountCountProxies)),
|
||||||
|
Page: params.Page,
|
||||||
|
PageSize: params.PageSize,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.listWithFiltersAndAccountCountProxies, result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type redeemRepoStubForAdminList struct {
|
||||||
|
redeemRepoStub
|
||||||
|
|
||||||
|
listWithFiltersCalls int
|
||||||
|
listWithFiltersParams pagination.PaginationParams
|
||||||
|
listWithFiltersType string
|
||||||
|
listWithFiltersStatus string
|
||||||
|
listWithFiltersSearch string
|
||||||
|
listWithFiltersCodes []RedeemCode
|
||||||
|
listWithFiltersResult *pagination.PaginationResult
|
||||||
|
listWithFiltersErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *redeemRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, codeType, status, search string) ([]RedeemCode, *pagination.PaginationResult, error) {
|
||||||
|
s.listWithFiltersCalls++
|
||||||
|
s.listWithFiltersParams = params
|
||||||
|
s.listWithFiltersType = codeType
|
||||||
|
s.listWithFiltersStatus = status
|
||||||
|
s.listWithFiltersSearch = search
|
||||||
|
|
||||||
|
if s.listWithFiltersErr != nil {
|
||||||
|
return nil, nil, s.listWithFiltersErr
|
||||||
|
}
|
||||||
|
|
||||||
|
result := s.listWithFiltersResult
|
||||||
|
if result == nil {
|
||||||
|
result = &pagination.PaginationResult{
|
||||||
|
Total: int64(len(s.listWithFiltersCodes)),
|
||||||
|
Page: params.Page,
|
||||||
|
PageSize: params.PageSize,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.listWithFiltersCodes, result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
|
||||||
|
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||||||
|
repo := &accountRepoStubForAdminList{
|
||||||
|
listWithFiltersAccounts: []Account{{ID: 1, Name: "acc"}},
|
||||||
|
listWithFiltersResult: &pagination.PaginationResult{Total: 10},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{accountRepo: repo}
|
||||||
|
|
||||||
|
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, int64(10), total)
|
||||||
|
require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts)
|
||||||
|
|
||||||
|
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||||
|
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
|
||||||
|
require.Equal(t, PlatformGemini, repo.listWithFiltersPlatform)
|
||||||
|
require.Equal(t, AccountTypeOAuth, repo.listWithFiltersType)
|
||||||
|
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
|
||||||
|
require.Equal(t, "acc", repo.listWithFiltersSearch)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_ListProxies_WithSearch(t *testing.T) {
|
||||||
|
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||||||
|
repo := &proxyRepoStubForAdminList{
|
||||||
|
listWithFiltersProxies: []Proxy{{ID: 2, Name: "p1"}},
|
||||||
|
listWithFiltersResult: &pagination.PaginationResult{Total: 7},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{proxyRepo: repo}
|
||||||
|
|
||||||
|
proxies, total, err := svc.ListProxies(context.Background(), 3, 50, "http", StatusActive, "p1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, int64(7), total)
|
||||||
|
require.Equal(t, []Proxy{{ID: 2, Name: "p1"}}, proxies)
|
||||||
|
|
||||||
|
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||||
|
require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams)
|
||||||
|
require.Equal(t, "http", repo.listWithFiltersProtocol)
|
||||||
|
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
|
||||||
|
require.Equal(t, "p1", repo.listWithFiltersSearch)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_ListProxiesWithAccountCount_WithSearch(t *testing.T) {
|
||||||
|
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||||||
|
repo := &proxyRepoStubForAdminList{
|
||||||
|
listWithFiltersAndAccountCountProxies: []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}},
|
||||||
|
listWithFiltersAndAccountCountResult: &pagination.PaginationResult{Total: 9},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{proxyRepo: repo}
|
||||||
|
|
||||||
|
proxies, total, err := svc.ListProxiesWithAccountCount(context.Background(), 2, 10, "socks5", StatusDisabled, "p2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, int64(9), total)
|
||||||
|
require.Equal(t, []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}}, proxies)
|
||||||
|
|
||||||
|
require.Equal(t, 1, repo.listWithFiltersAndAccountCountCalls)
|
||||||
|
require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersAndAccountCountParams)
|
||||||
|
require.Equal(t, "socks5", repo.listWithFiltersAndAccountCountProtocol)
|
||||||
|
require.Equal(t, StatusDisabled, repo.listWithFiltersAndAccountCountStatus)
|
||||||
|
require.Equal(t, "p2", repo.listWithFiltersAndAccountCountSearch)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_ListRedeemCodes_WithSearch(t *testing.T) {
|
||||||
|
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||||||
|
repo := &redeemRepoStubForAdminList{
|
||||||
|
listWithFiltersCodes: []RedeemCode{{ID: 4, Code: "ABC"}},
|
||||||
|
listWithFiltersResult: &pagination.PaginationResult{Total: 3},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{redeemCodeRepo: repo}
|
||||||
|
|
||||||
|
codes, total, err := svc.ListRedeemCodes(context.Background(), 1, 20, RedeemTypeBalance, StatusUnused, "ABC")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, int64(3), total)
|
||||||
|
require.Equal(t, []RedeemCode{{ID: 4, Code: "ABC"}}, codes)
|
||||||
|
|
||||||
|
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||||
|
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
|
||||||
|
require.Equal(t, RedeemTypeBalance, repo.listWithFiltersType)
|
||||||
|
require.Equal(t, StatusUnused, repo.listWithFiltersStatus)
|
||||||
|
require.Equal(t, "ABC", repo.listWithFiltersSearch)
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -93,6 +93,7 @@ var antigravityPrefixMapping = []struct {
|
|||||||
// 长前缀优先
|
// 长前缀优先
|
||||||
{"gemini-2.5-flash-image", "gemini-3-pro-image"}, // gemini-2.5-flash-image → 3-pro-image
|
{"gemini-2.5-flash-image", "gemini-3-pro-image"}, // gemini-2.5-flash-image → 3-pro-image
|
||||||
{"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等
|
{"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等
|
||||||
|
{"gemini-3-flash", "gemini-3-flash"}, // gemini-3-flash-preview 等 → gemini-3-flash
|
||||||
{"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx
|
{"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx
|
||||||
{"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx
|
{"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx
|
||||||
{"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet
|
{"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet
|
||||||
@@ -502,6 +503,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
|
|
||||||
originalModel := claudeReq.Model
|
originalModel := claudeReq.Model
|
||||||
mappedModel := s.getMappedModel(account, claudeReq.Model)
|
mappedModel := s.getMappedModel(account, claudeReq.Model)
|
||||||
|
quotaScope, _ := resolveAntigravityQuotaScope(originalModel)
|
||||||
|
|
||||||
// 获取 access_token
|
// 获取 access_token
|
||||||
if s.tokenProvider == nil {
|
if s.tokenProvider == nil {
|
||||||
@@ -603,7 +605,7 @@ urlFallbackLoop:
|
|||||||
}
|
}
|
||||||
// 所有重试都失败,标记限流状态
|
// 所有重试都失败,标记限流状态
|
||||||
if resp.StatusCode == 429 {
|
if resp.StatusCode == 429 {
|
||||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
|
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
|
||||||
}
|
}
|
||||||
// 最后一次尝试也失败
|
// 最后一次尝试也失败
|
||||||
resp = &http.Response{
|
resp = &http.Response{
|
||||||
@@ -696,7 +698,7 @@ urlFallbackLoop:
|
|||||||
|
|
||||||
// 处理错误响应(重试后仍失败或不触发重试)
|
// 处理错误响应(重试后仍失败或不触发重试)
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
|
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
|
||||||
|
|
||||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||||
@@ -1021,6 +1023,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
|||||||
if len(body) == 0 {
|
if len(body) == 0 {
|
||||||
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty")
|
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty")
|
||||||
}
|
}
|
||||||
|
quotaScope, _ := resolveAntigravityQuotaScope(originalModel)
|
||||||
|
|
||||||
// 解析请求以获取 image_size(用于图片计费)
|
// 解析请求以获取 image_size(用于图片计费)
|
||||||
imageSize := s.extractImageSize(body)
|
imageSize := s.extractImageSize(body)
|
||||||
@@ -1146,7 +1149,7 @@ urlFallbackLoop:
|
|||||||
}
|
}
|
||||||
// 所有重试都失败,标记限流状态
|
// 所有重试都失败,标记限流状态
|
||||||
if resp.StatusCode == 429 {
|
if resp.StatusCode == 429 {
|
||||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
|
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
|
||||||
}
|
}
|
||||||
resp = &http.Response{
|
resp = &http.Response{
|
||||||
StatusCode: resp.StatusCode,
|
StatusCode: resp.StatusCode,
|
||||||
@@ -1200,7 +1203,7 @@ urlFallbackLoop:
|
|||||||
goto handleSuccess
|
goto handleSuccess
|
||||||
}
|
}
|
||||||
|
|
||||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
|
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
|
||||||
|
|
||||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||||
@@ -1314,7 +1317,7 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte) {
|
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
|
||||||
// 429 使用 Gemini 格式解析(从 body 解析重置时间)
|
// 429 使用 Gemini 格式解析(从 body 解析重置时间)
|
||||||
if statusCode == 429 {
|
if statusCode == 429 {
|
||||||
resetAt := ParseGeminiRateLimitResetTime(body)
|
resetAt := ParseGeminiRateLimitResetTime(body)
|
||||||
@@ -1325,13 +1328,23 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, pre
|
|||||||
defaultDur = 5 * time.Minute
|
defaultDur = 5 * time.Minute
|
||||||
}
|
}
|
||||||
ra := time.Now().Add(defaultDur)
|
ra := time.Now().Add(defaultDur)
|
||||||
log.Printf("%s status=429 rate_limited reset_in=%v (fallback)", prefix, defaultDur)
|
log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur)
|
||||||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, ra)
|
if quotaScope == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, ra); err != nil {
|
||||||
|
log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resetTime := time.Unix(*resetAt, 0)
|
resetTime := time.Unix(*resetAt, 0)
|
||||||
log.Printf("%s status=429 rate_limited reset_at=%v reset_in=%v", prefix, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second))
|
log.Printf("%s status=429 rate_limited scope=%s reset_at=%v reset_in=%v", prefix, quotaScope, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second))
|
||||||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, resetTime)
|
if quotaScope == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, resetTime); err != nil {
|
||||||
|
log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 其他错误码继续使用 rateLimitService
|
// 其他错误码继续使用 rateLimitService
|
||||||
|
|||||||
88
backend/internal/service/antigravity_quota_scope.go
Normal file
88
backend/internal/service/antigravity_quota_scope.go
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const antigravityQuotaScopesKey = "antigravity_quota_scopes"
|
||||||
|
|
||||||
|
// AntigravityQuotaScope 表示 Antigravity 的配额域
|
||||||
|
type AntigravityQuotaScope string
|
||||||
|
|
||||||
|
const (
|
||||||
|
AntigravityQuotaScopeClaude AntigravityQuotaScope = "claude"
|
||||||
|
AntigravityQuotaScopeGeminiText AntigravityQuotaScope = "gemini_text"
|
||||||
|
AntigravityQuotaScopeGeminiImage AntigravityQuotaScope = "gemini_image"
|
||||||
|
)
|
||||||
|
|
||||||
|
// resolveAntigravityQuotaScope 根据模型名称解析配额域
|
||||||
|
func resolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) {
|
||||||
|
model := normalizeAntigravityModelName(requestedModel)
|
||||||
|
if model == "" {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(model, "claude-"):
|
||||||
|
return AntigravityQuotaScopeClaude, true
|
||||||
|
case strings.HasPrefix(model, "gemini-"):
|
||||||
|
if isImageGenerationModel(model) {
|
||||||
|
return AntigravityQuotaScopeGeminiImage, true
|
||||||
|
}
|
||||||
|
return AntigravityQuotaScopeGeminiText, true
|
||||||
|
default:
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeAntigravityModelName(model string) string {
|
||||||
|
normalized := strings.ToLower(strings.TrimSpace(model))
|
||||||
|
normalized = strings.TrimPrefix(normalized, "models/")
|
||||||
|
return normalized
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsSchedulableForModel 结合 Antigravity 配额域限流判断是否可调度
|
||||||
|
func (a *Account) IsSchedulableForModel(requestedModel string) bool {
|
||||||
|
if a == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !a.IsSchedulable() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if a.Platform != PlatformAntigravity {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
scope, ok := resolveAntigravityQuotaScope(requestedModel)
|
||||||
|
if !ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
resetAt := a.antigravityQuotaScopeResetAt(scope)
|
||||||
|
if resetAt == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
return !now.Before(*resetAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Account) antigravityQuotaScopeResetAt(scope AntigravityQuotaScope) *time.Time {
|
||||||
|
if a == nil || a.Extra == nil || scope == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
rawScopes, ok := a.Extra[antigravityQuotaScopesKey].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
rawScope, ok := rawScopes[string(scope)].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
resetAtRaw, ok := rawScope["rate_limit_reset_at"].(string)
|
||||||
|
if !ok || strings.TrimSpace(resetAtRaw) == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
resetAt, err := time.Parse(time.RFC3339, resetAtRaw)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &resetAt
|
||||||
|
}
|
||||||
@@ -3,16 +3,18 @@ package service
|
|||||||
import "time"
|
import "time"
|
||||||
|
|
||||||
type APIKey struct {
|
type APIKey struct {
|
||||||
ID int64
|
ID int64
|
||||||
UserID int64
|
UserID int64
|
||||||
Key string
|
Key string
|
||||||
Name string
|
Name string
|
||||||
GroupID *int64
|
GroupID *int64
|
||||||
Status string
|
Status string
|
||||||
CreatedAt time.Time
|
IPWhitelist []string
|
||||||
UpdatedAt time.Time
|
IPBlacklist []string
|
||||||
User *User
|
CreatedAt time.Time
|
||||||
Group *Group
|
UpdatedAt time.Time
|
||||||
|
User *User
|
||||||
|
Group *Group
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *APIKey) IsActive() bool {
|
func (k *APIKey) IsActive() bool {
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||||
)
|
)
|
||||||
@@ -20,6 +21,7 @@ var (
|
|||||||
ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
|
ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
|
||||||
ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
|
ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
|
||||||
ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
|
ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
|
||||||
|
ErrInvalidIPPattern = infraerrors.BadRequest("INVALID_IP_PATTERN", "invalid IP or CIDR pattern")
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -57,16 +59,20 @@ type APIKeyCache interface {
|
|||||||
|
|
||||||
// CreateAPIKeyRequest 创建API Key请求
|
// CreateAPIKeyRequest 创建API Key请求
|
||||||
type CreateAPIKeyRequest struct {
|
type CreateAPIKeyRequest struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
GroupID *int64 `json:"group_id"`
|
GroupID *int64 `json:"group_id"`
|
||||||
CustomKey *string `json:"custom_key"` // 可选的自定义key
|
CustomKey *string `json:"custom_key"` // 可选的自定义key
|
||||||
|
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
|
||||||
|
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateAPIKeyRequest 更新API Key请求
|
// UpdateAPIKeyRequest 更新API Key请求
|
||||||
type UpdateAPIKeyRequest struct {
|
type UpdateAPIKeyRequest struct {
|
||||||
Name *string `json:"name"`
|
Name *string `json:"name"`
|
||||||
GroupID *int64 `json:"group_id"`
|
GroupID *int64 `json:"group_id"`
|
||||||
Status *string `json:"status"`
|
Status *string `json:"status"`
|
||||||
|
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单(空数组清空)
|
||||||
|
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单(空数组清空)
|
||||||
}
|
}
|
||||||
|
|
||||||
// APIKeyService API Key服务
|
// APIKeyService API Key服务
|
||||||
@@ -186,6 +192,20 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
|
|||||||
return nil, fmt.Errorf("get user: %w", err)
|
return nil, fmt.Errorf("get user: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 验证 IP 白名单格式
|
||||||
|
if len(req.IPWhitelist) > 0 {
|
||||||
|
if invalid := ip.ValidateIPPatterns(req.IPWhitelist); len(invalid) > 0 {
|
||||||
|
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证 IP 黑名单格式
|
||||||
|
if len(req.IPBlacklist) > 0 {
|
||||||
|
if invalid := ip.ValidateIPPatterns(req.IPBlacklist); len(invalid) > 0 {
|
||||||
|
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 验证分组权限(如果指定了分组)
|
// 验证分组权限(如果指定了分组)
|
||||||
if req.GroupID != nil {
|
if req.GroupID != nil {
|
||||||
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
|
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
|
||||||
@@ -236,11 +256,13 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
|
|||||||
|
|
||||||
// 创建API Key记录
|
// 创建API Key记录
|
||||||
apiKey := &APIKey{
|
apiKey := &APIKey{
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
Key: key,
|
Key: key,
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
GroupID: req.GroupID,
|
GroupID: req.GroupID,
|
||||||
Status: StatusActive,
|
Status: StatusActive,
|
||||||
|
IPWhitelist: req.IPWhitelist,
|
||||||
|
IPBlacklist: req.IPBlacklist,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil {
|
if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil {
|
||||||
@@ -312,6 +334,20 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
|
|||||||
return nil, ErrInsufficientPerms
|
return nil, ErrInsufficientPerms
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 验证 IP 白名单格式
|
||||||
|
if len(req.IPWhitelist) > 0 {
|
||||||
|
if invalid := ip.ValidateIPPatterns(req.IPWhitelist); len(invalid) > 0 {
|
||||||
|
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证 IP 黑名单格式
|
||||||
|
if len(req.IPBlacklist) > 0 {
|
||||||
|
if invalid := ip.ValidateIPPatterns(req.IPBlacklist); len(invalid) > 0 {
|
||||||
|
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 更新字段
|
// 更新字段
|
||||||
if req.Name != nil {
|
if req.Name != nil {
|
||||||
apiKey.Name = *req.Name
|
apiKey.Name = *req.Name
|
||||||
@@ -344,6 +380,10 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 更新 IP 限制(空数组会清空设置)
|
||||||
|
apiKey.IPWhitelist = req.IPWhitelist
|
||||||
|
apiKey.IPBlacklist = req.IPBlacklist
|
||||||
|
|
||||||
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
|
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
|
||||||
return nil, fmt.Errorf("update api key: %w", err)
|
return nil, fmt.Errorf("update api key: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,9 +2,13 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
"net/mail"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
@@ -18,6 +22,7 @@ var (
|
|||||||
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
|
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
|
||||||
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
|
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
|
||||||
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
|
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
|
||||||
|
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
|
||||||
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
|
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
|
||||||
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
|
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
|
||||||
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
|
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
|
||||||
@@ -80,6 +85,11 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
|||||||
return "", nil, ErrRegDisabled
|
return "", nil, ErrRegDisabled
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 防止用户注册 LinuxDo OAuth 合成邮箱,避免第三方登录与本地账号发生碰撞。
|
||||||
|
if isReservedEmail(email) {
|
||||||
|
return "", nil, ErrEmailReserved
|
||||||
|
}
|
||||||
|
|
||||||
// 检查是否需要邮件验证
|
// 检查是否需要邮件验证
|
||||||
if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) {
|
if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) {
|
||||||
// 如果邮件验证已开启但邮件服务未配置,拒绝注册
|
// 如果邮件验证已开启但邮件服务未配置,拒绝注册
|
||||||
@@ -161,6 +171,10 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
|
|||||||
return ErrRegDisabled
|
return ErrRegDisabled
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if isReservedEmail(email) {
|
||||||
|
return ErrEmailReserved
|
||||||
|
}
|
||||||
|
|
||||||
// 检查邮箱是否已存在
|
// 检查邮箱是否已存在
|
||||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -195,6 +209,10 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
|
|||||||
return nil, ErrRegDisabled
|
return nil, ErrRegDisabled
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if isReservedEmail(email) {
|
||||||
|
return nil, ErrEmailReserved
|
||||||
|
}
|
||||||
|
|
||||||
// 检查邮箱是否已存在
|
// 检查邮箱是否已存在
|
||||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -319,6 +337,102 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
|
|||||||
return token, user, nil
|
return token, user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LoginOrRegisterOAuth 用于第三方 OAuth/SSO 登录:
|
||||||
|
// - 如果邮箱已存在:直接登录(不需要本地密码)
|
||||||
|
// - 如果邮箱不存在:创建新用户并登录
|
||||||
|
//
|
||||||
|
// 注意:该函数用于“终端用户登录 Sub2API 本身”的场景(不同于上游账号的 OAuth,例如 OpenAI/Gemini)。
|
||||||
|
// 为了满足现有数据库约束(需要密码哈希),新用户会生成随机密码并进行哈希保存。
|
||||||
|
func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username string) (string, *User, error) {
|
||||||
|
email = strings.TrimSpace(email)
|
||||||
|
if email == "" || len(email) > 255 {
|
||||||
|
return "", nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
|
||||||
|
}
|
||||||
|
if _, err := mail.ParseAddress(email); err != nil {
|
||||||
|
return "", nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
|
||||||
|
}
|
||||||
|
|
||||||
|
username = strings.TrimSpace(username)
|
||||||
|
if len([]rune(username)) > 100 {
|
||||||
|
username = string([]rune(username)[:100])
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := s.userRepo.GetByEmail(ctx, email)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrUserNotFound) {
|
||||||
|
// OAuth 首次登录视为注册。
|
||||||
|
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
|
||||||
|
return "", nil, ErrRegDisabled
|
||||||
|
}
|
||||||
|
|
||||||
|
randomPassword, err := randomHexString(32)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[Auth] Failed to generate random password for oauth signup: %v", err)
|
||||||
|
return "", nil, ErrServiceUnavailable
|
||||||
|
}
|
||||||
|
hashedPassword, err := s.HashPassword(randomPassword)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, fmt.Errorf("hash password: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 新用户默认值。
|
||||||
|
defaultBalance := s.cfg.Default.UserBalance
|
||||||
|
defaultConcurrency := s.cfg.Default.UserConcurrency
|
||||||
|
if s.settingService != nil {
|
||||||
|
defaultBalance = s.settingService.GetDefaultBalance(ctx)
|
||||||
|
defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
newUser := &User{
|
||||||
|
Email: email,
|
||||||
|
Username: username,
|
||||||
|
PasswordHash: hashedPassword,
|
||||||
|
Role: RoleUser,
|
||||||
|
Balance: defaultBalance,
|
||||||
|
Concurrency: defaultConcurrency,
|
||||||
|
Status: StatusActive,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.userRepo.Create(ctx, newUser); err != nil {
|
||||||
|
if errors.Is(err, ErrEmailExists) {
|
||||||
|
// 并发场景:GetByEmail 与 Create 之间用户被创建。
|
||||||
|
user, err = s.userRepo.GetByEmail(ctx, email)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[Auth] Database error getting user after conflict: %v", err)
|
||||||
|
return "", nil, ErrServiceUnavailable
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Printf("[Auth] Database error creating oauth user: %v", err)
|
||||||
|
return "", nil, ErrServiceUnavailable
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
user = newUser
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Printf("[Auth] Database error during oauth login: %v", err)
|
||||||
|
return "", nil, ErrServiceUnavailable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !user.IsActive() {
|
||||||
|
return "", nil, ErrUserNotActive
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尽力补全:当用户名为空时,使用第三方返回的用户名回填。
|
||||||
|
if user.Username == "" && username != "" {
|
||||||
|
user.Username = username
|
||||||
|
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||||
|
log.Printf("[Auth] Failed to update username after oauth login: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := s.GenerateToken(user)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, fmt.Errorf("generate token: %w", err)
|
||||||
|
}
|
||||||
|
return token, user, nil
|
||||||
|
}
|
||||||
|
|
||||||
// ValidateToken 验证JWT token并返回用户声明
|
// ValidateToken 验证JWT token并返回用户声明
|
||||||
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
|
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
|
||||||
// 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。
|
// 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。
|
||||||
@@ -361,6 +475,22 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
|
|||||||
return nil, ErrInvalidToken
|
return nil, ErrInvalidToken
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func randomHexString(byteLength int) (string, error) {
|
||||||
|
if byteLength <= 0 {
|
||||||
|
byteLength = 16
|
||||||
|
}
|
||||||
|
buf := make([]byte, byteLength)
|
||||||
|
if _, err := rand.Read(buf); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(buf), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isReservedEmail(email string) bool {
|
||||||
|
normalized := strings.ToLower(strings.TrimSpace(email))
|
||||||
|
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain)
|
||||||
|
}
|
||||||
|
|
||||||
// GenerateToken 生成JWT token
|
// GenerateToken 生成JWT token
|
||||||
func (s *AuthService) GenerateToken(user *User) (string, error) {
|
func (s *AuthService) GenerateToken(user *User) (string, error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|||||||
@@ -182,6 +182,16 @@ func TestAuthService_Register_CheckEmailError(t *testing.T) {
|
|||||||
require.ErrorIs(t, err, ErrServiceUnavailable)
|
require.ErrorIs(t, err, ErrServiceUnavailable)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAuthService_Register_ReservedEmail(t *testing.T) {
|
||||||
|
repo := &userRepoStub{}
|
||||||
|
service := newAuthService(repo, map[string]string{
|
||||||
|
SettingKeyRegistrationEnabled: "true",
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
_, _, err := service.Register(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "password")
|
||||||
|
require.ErrorIs(t, err, ErrEmailReserved)
|
||||||
|
}
|
||||||
|
|
||||||
func TestAuthService_Register_CreateError(t *testing.T) {
|
func TestAuthService_Register_CreateError(t *testing.T) {
|
||||||
repo := &userRepoStub{createErr: errors.New("create failed")}
|
repo := &userRepoStub{createErr: errors.New("create failed")}
|
||||||
service := newAuthService(repo, map[string]string{
|
service := newAuthService(repo, map[string]string{
|
||||||
|
|||||||
@@ -105,7 +105,17 @@ const (
|
|||||||
// Request identity patch (Claude -> Gemini systemInstruction injection)
|
// Request identity patch (Claude -> Gemini systemInstruction injection)
|
||||||
SettingKeyEnableIdentityPatch = "enable_identity_patch"
|
SettingKeyEnableIdentityPatch = "enable_identity_patch"
|
||||||
SettingKeyIdentityPatchPrompt = "identity_patch_prompt"
|
SettingKeyIdentityPatchPrompt = "identity_patch_prompt"
|
||||||
|
|
||||||
|
// LinuxDo Connect OAuth 登录(终端用户 SSO)
|
||||||
|
SettingKeyLinuxDoConnectEnabled = "linuxdo_connect_enabled"
|
||||||
|
SettingKeyLinuxDoConnectClientID = "linuxdo_connect_client_id"
|
||||||
|
SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret"
|
||||||
|
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。
|
||||||
|
// 目的:避免第三方登录返回的用户标识与本地真实邮箱发生碰撞,进而造成账号被接管的风险。
|
||||||
|
const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
|
||||||
|
|
||||||
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
|
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
|
||||||
const AdminAPIKeyPrefix = "admin-"
|
const AdminAPIKeyPrefix = "admin-"
|
||||||
|
|||||||
@@ -136,6 +136,9 @@ func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx co
|
|||||||
func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
func (m *mockAccountRepoForPlatform) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -148,6 +151,9 @@ func (m *mockAccountRepoForPlatform) ClearTempUnschedulable(ctx context.Context,
|
|||||||
func (m *mockAccountRepoForPlatform) ClearRateLimit(ctx context.Context, id int64) error {
|
func (m *mockAccountRepoForPlatform) ClearRateLimit(ctx context.Context, id int64) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
func (m *mockAccountRepoForPlatform) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
func (m *mockAccountRepoForPlatform) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
func (m *mockAccountRepoForPlatform) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ const (
|
|||||||
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
|
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
|
||||||
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
|
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
|
||||||
stickySessionTTL = time.Hour // 粘性会话TTL
|
stickySessionTTL = time.Hour // 粘性会话TTL
|
||||||
defaultMaxLineSize = 10 * 1024 * 1024
|
defaultMaxLineSize = 40 * 1024 * 1024
|
||||||
claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude."
|
claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude."
|
||||||
maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量
|
maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量
|
||||||
)
|
)
|
||||||
@@ -481,7 +481,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||||
if err == nil && s.isAccountInGroup(account, groupID) &&
|
if err == nil && s.isAccountInGroup(account, groupID) &&
|
||||||
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
|
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
|
||||||
account.IsSchedulable() &&
|
account.IsSchedulableForModel(requestedModel) &&
|
||||||
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||||
if err == nil && result.Acquired {
|
if err == nil && result.Acquired {
|
||||||
@@ -519,6 +519,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
if !s.isAccountAllowedForPlatform(acc, platform, useMixed) {
|
if !s.isAccountAllowedForPlatform(acc, platform, useMixed) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if !acc.IsSchedulableForModel(requestedModel) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -812,7 +815,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
|||||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||||
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
|
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
|
||||||
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||||
}
|
}
|
||||||
@@ -844,6 +847,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
|||||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if !acc.IsSchedulableForModel(requestedModel) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -901,7 +907,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
|||||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||||
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
|
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
|
||||||
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||||
@@ -936,6 +942,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
|||||||
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if !acc.IsSchedulableForModel(requestedModel) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -2247,6 +2256,7 @@ type RecordUsageInput struct {
|
|||||||
Account *Account
|
Account *Account
|
||||||
Subscription *UserSubscription // 可选:订阅信息
|
Subscription *UserSubscription // 可选:订阅信息
|
||||||
UserAgent string // 请求的 User-Agent
|
UserAgent string // 请求的 User-Agent
|
||||||
|
IPAddress string // 请求的客户端 IP 地址
|
||||||
}
|
}
|
||||||
|
|
||||||
// RecordUsage 记录使用量并扣费(或更新订阅用量)
|
// RecordUsage 记录使用量并扣费(或更新订阅用量)
|
||||||
@@ -2337,6 +2347,11 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
usageLog.UserAgent = &input.UserAgent
|
usageLog.UserAgent = &input.UserAgent
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 添加 IPAddress
|
||||||
|
if input.IPAddress != "" {
|
||||||
|
usageLog.IPAddress = &input.IPAddress
|
||||||
|
}
|
||||||
|
|
||||||
// 添加分组和订阅关联
|
// 添加分组和订阅关联
|
||||||
if apiKey.GroupID != nil {
|
if apiKey.GroupID != nil {
|
||||||
usageLog.GroupID = apiKey.GroupID
|
usageLog.GroupID = apiKey.GroupID
|
||||||
|
|||||||
@@ -114,7 +114,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
|||||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||||
// 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度
|
// 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度
|
||||||
if err == nil && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
if err == nil && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||||
valid := false
|
valid := false
|
||||||
if account.Platform == platform {
|
if account.Platform == platform {
|
||||||
valid = true
|
valid = true
|
||||||
@@ -172,6 +172,9 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
|||||||
if useMixedScheduling && acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
if useMixedScheduling && acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if !acc.IsSchedulableForModel(requestedModel) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -121,6 +121,9 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx cont
|
|||||||
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
func (m *mockAccountRepoForGemini) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -131,6 +134,9 @@ func (m *mockAccountRepoForGemini) ClearTempUnschedulable(ctx context.Context, i
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
func (m *mockAccountRepoForGemini) ClearRateLimit(ctx context.Context, id int64) error { return nil }
|
func (m *mockAccountRepoForGemini) ClearRateLimit(ctx context.Context, id int64) error { return nil }
|
||||||
|
func (m *mockAccountRepoForGemini) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
func (m *mockAccountRepoForGemini) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
func (m *mockAccountRepoForGemini) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -166,7 +172,7 @@ func (m *mockGroupRepoForGemini) DeleteCascade(ctx context.Context, id int64) ([
|
|||||||
func (m *mockGroupRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
func (m *mockGroupRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
func (m *mockGroupRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
|
func (m *mockGroupRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
func (m *mockGroupRepoForGemini) ListActive(ctx context.Context) ([]Group, error) { return nil, nil }
|
func (m *mockGroupRepoForGemini) ListActive(ctx context.Context) ([]Group, error) { return nil, nil }
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ type GroupRepository interface {
|
|||||||
DeleteCascade(ctx context.Context, id int64) ([]int64, error)
|
DeleteCascade(ctx context.Context, id int64) ([]int64, error)
|
||||||
|
|
||||||
List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error)
|
List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error)
|
||||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error)
|
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error)
|
||||||
ListActive(ctx context.Context) ([]Group, error)
|
ListActive(ctx context.Context) ([]Group, error)
|
||||||
ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error)
|
ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error)
|
||||||
|
|
||||||
|
|||||||
@@ -1197,6 +1197,7 @@ type OpenAIRecordUsageInput struct {
|
|||||||
Account *Account
|
Account *Account
|
||||||
Subscription *UserSubscription
|
Subscription *UserSubscription
|
||||||
UserAgent string // 请求的 User-Agent
|
UserAgent string // 请求的 User-Agent
|
||||||
|
IPAddress string // 请求的客户端 IP 地址
|
||||||
}
|
}
|
||||||
|
|
||||||
// RecordUsage records usage and deducts balance
|
// RecordUsage records usage and deducts balance
|
||||||
@@ -1271,6 +1272,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
|||||||
usageLog.UserAgent = &input.UserAgent
|
usageLog.UserAgent = &input.UserAgent
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 添加 IPAddress
|
||||||
|
if input.IPAddress != "" {
|
||||||
|
usageLog.IPAddress = &input.IPAddress
|
||||||
|
}
|
||||||
|
|
||||||
if apiKey.GroupID != nil {
|
if apiKey.GroupID != nil {
|
||||||
usageLog.GroupID = apiKey.GroupID
|
usageLog.GroupID = apiKey.GroupID
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -345,7 +345,7 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc
|
|||||||
|
|
||||||
// 如果状态为allowed且之前有限流,说明窗口已重置,清除限流状态
|
// 如果状态为allowed且之前有限流,说明窗口已重置,清除限流状态
|
||||||
if status == "allowed" && account.IsRateLimited() {
|
if status == "allowed" && account.IsRateLimited() {
|
||||||
if err := s.accountRepo.ClearRateLimit(ctx, account.ID); err != nil {
|
if err := s.ClearRateLimit(ctx, account.ID); err != nil {
|
||||||
log.Printf("ClearRateLimit failed for account %d: %v", account.ID, err)
|
log.Printf("ClearRateLimit failed for account %d: %v", account.ID, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -353,7 +353,10 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc
|
|||||||
|
|
||||||
// ClearRateLimit 清除账号的限流状态
|
// ClearRateLimit 清除账号的限流状态
|
||||||
func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) error {
|
func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) error {
|
||||||
return s.accountRepo.ClearRateLimit(ctx, accountID)
|
if err := s.accountRepo.ClearRateLimit(ctx, accountID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return s.accountRepo.ClearAntigravityQuotaScopes(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID int64) error {
|
func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID int64) error {
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
@@ -64,6 +65,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
|||||||
SettingKeyAPIBaseURL,
|
SettingKeyAPIBaseURL,
|
||||||
SettingKeyContactInfo,
|
SettingKeyContactInfo,
|
||||||
SettingKeyDocURL,
|
SettingKeyDocURL,
|
||||||
|
SettingKeyLinuxDoConnectEnabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||||
@@ -71,6 +73,13 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
|||||||
return nil, fmt.Errorf("get public settings: %w", err)
|
return nil, fmt.Errorf("get public settings: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
linuxDoEnabled := false
|
||||||
|
if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok {
|
||||||
|
linuxDoEnabled = raw == "true"
|
||||||
|
} else {
|
||||||
|
linuxDoEnabled = s.cfg != nil && s.cfg.LinuxDo.Enabled
|
||||||
|
}
|
||||||
|
|
||||||
return &PublicSettings{
|
return &PublicSettings{
|
||||||
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
|
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
|
||||||
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
|
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
|
||||||
@@ -82,6 +91,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
|||||||
APIBaseURL: settings[SettingKeyAPIBaseURL],
|
APIBaseURL: settings[SettingKeyAPIBaseURL],
|
||||||
ContactInfo: settings[SettingKeyContactInfo],
|
ContactInfo: settings[SettingKeyContactInfo],
|
||||||
DocURL: settings[SettingKeyDocURL],
|
DocURL: settings[SettingKeyDocURL],
|
||||||
|
LinuxDoOAuthEnabled: linuxDoEnabled,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -111,6 +121,14 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
|||||||
updates[SettingKeyTurnstileSecretKey] = settings.TurnstileSecretKey
|
updates[SettingKeyTurnstileSecretKey] = settings.TurnstileSecretKey
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LinuxDo Connect OAuth 登录(终端用户 SSO)
|
||||||
|
updates[SettingKeyLinuxDoConnectEnabled] = strconv.FormatBool(settings.LinuxDoConnectEnabled)
|
||||||
|
updates[SettingKeyLinuxDoConnectClientID] = settings.LinuxDoConnectClientID
|
||||||
|
updates[SettingKeyLinuxDoConnectRedirectURL] = settings.LinuxDoConnectRedirectURL
|
||||||
|
if settings.LinuxDoConnectClientSecret != "" {
|
||||||
|
updates[SettingKeyLinuxDoConnectClientSecret] = settings.LinuxDoConnectClientSecret
|
||||||
|
}
|
||||||
|
|
||||||
// OEM设置
|
// OEM设置
|
||||||
updates[SettingKeySiteName] = settings.SiteName
|
updates[SettingKeySiteName] = settings.SiteName
|
||||||
updates[SettingKeySiteLogo] = settings.SiteLogo
|
updates[SettingKeySiteLogo] = settings.SiteLogo
|
||||||
@@ -271,6 +289,38 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
|||||||
result.SMTPPassword = settings[SettingKeySMTPPassword]
|
result.SMTPPassword = settings[SettingKeySMTPPassword]
|
||||||
result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey]
|
result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey]
|
||||||
|
|
||||||
|
// LinuxDo Connect 设置:
|
||||||
|
// - 兼容 config.yaml/env(避免老部署因为未迁移到数据库设置而被意外关闭)
|
||||||
|
// - 支持在后台“系统设置”中覆盖并持久化(存储于 DB)
|
||||||
|
linuxDoBase := config.LinuxDoConnectConfig{}
|
||||||
|
if s.cfg != nil {
|
||||||
|
linuxDoBase = s.cfg.LinuxDo
|
||||||
|
}
|
||||||
|
|
||||||
|
if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok {
|
||||||
|
result.LinuxDoConnectEnabled = raw == "true"
|
||||||
|
} else {
|
||||||
|
result.LinuxDoConnectEnabled = linuxDoBase.Enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
if v, ok := settings[SettingKeyLinuxDoConnectClientID]; ok && strings.TrimSpace(v) != "" {
|
||||||
|
result.LinuxDoConnectClientID = strings.TrimSpace(v)
|
||||||
|
} else {
|
||||||
|
result.LinuxDoConnectClientID = linuxDoBase.ClientID
|
||||||
|
}
|
||||||
|
|
||||||
|
if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" {
|
||||||
|
result.LinuxDoConnectRedirectURL = strings.TrimSpace(v)
|
||||||
|
} else {
|
||||||
|
result.LinuxDoConnectRedirectURL = linuxDoBase.RedirectURL
|
||||||
|
}
|
||||||
|
|
||||||
|
result.LinuxDoConnectClientSecret = strings.TrimSpace(settings[SettingKeyLinuxDoConnectClientSecret])
|
||||||
|
if result.LinuxDoConnectClientSecret == "" {
|
||||||
|
result.LinuxDoConnectClientSecret = strings.TrimSpace(linuxDoBase.ClientSecret)
|
||||||
|
}
|
||||||
|
result.LinuxDoConnectClientSecretConfigured = result.LinuxDoConnectClientSecret != ""
|
||||||
|
|
||||||
// Model fallback settings
|
// Model fallback settings
|
||||||
result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true"
|
result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true"
|
||||||
result.FallbackModelAnthropic = s.getStringOrDefault(settings, SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022")
|
result.FallbackModelAnthropic = s.getStringOrDefault(settings, SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022")
|
||||||
@@ -289,6 +339,99 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetLinuxDoConnectOAuthConfig 返回用于登录的“最终生效” LinuxDo Connect 配置。
|
||||||
|
//
|
||||||
|
// 优先级:
|
||||||
|
// - 若对应系统设置键存在,则覆盖 config.yaml/env 的值
|
||||||
|
// - 否则回退到 config.yaml/env 的值
|
||||||
|
func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) {
|
||||||
|
if s == nil || s.cfg == nil {
|
||||||
|
return config.LinuxDoConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded")
|
||||||
|
}
|
||||||
|
|
||||||
|
effective := s.cfg.LinuxDo
|
||||||
|
|
||||||
|
keys := []string{
|
||||||
|
SettingKeyLinuxDoConnectEnabled,
|
||||||
|
SettingKeyLinuxDoConnectClientID,
|
||||||
|
SettingKeyLinuxDoConnectClientSecret,
|
||||||
|
SettingKeyLinuxDoConnectRedirectURL,
|
||||||
|
}
|
||||||
|
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||||
|
if err != nil {
|
||||||
|
return config.LinuxDoConnectConfig{}, fmt.Errorf("get linuxdo connect settings: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok {
|
||||||
|
effective.Enabled = raw == "true"
|
||||||
|
}
|
||||||
|
if v, ok := settings[SettingKeyLinuxDoConnectClientID]; ok && strings.TrimSpace(v) != "" {
|
||||||
|
effective.ClientID = strings.TrimSpace(v)
|
||||||
|
}
|
||||||
|
if v, ok := settings[SettingKeyLinuxDoConnectClientSecret]; ok && strings.TrimSpace(v) != "" {
|
||||||
|
effective.ClientSecret = strings.TrimSpace(v)
|
||||||
|
}
|
||||||
|
if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" {
|
||||||
|
effective.RedirectURL = strings.TrimSpace(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !effective.Enabled {
|
||||||
|
return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 基础健壮性校验(避免把用户重定向到一个必然失败或不安全的 OAuth 流程里)。
|
||||||
|
if strings.TrimSpace(effective.ClientID) == "" {
|
||||||
|
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client id not configured")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(effective.AuthorizeURL) == "" {
|
||||||
|
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth authorize url not configured")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(effective.TokenURL) == "" {
|
||||||
|
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token url not configured")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(effective.UserInfoURL) == "" {
|
||||||
|
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth userinfo url not configured")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(effective.RedirectURL) == "" {
|
||||||
|
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url not configured")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(effective.FrontendRedirectURL) == "" {
|
||||||
|
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth frontend redirect url not configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := config.ValidateAbsoluteHTTPURL(effective.AuthorizeURL); err != nil {
|
||||||
|
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth authorize url invalid")
|
||||||
|
}
|
||||||
|
if err := config.ValidateAbsoluteHTTPURL(effective.TokenURL); err != nil {
|
||||||
|
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token url invalid")
|
||||||
|
}
|
||||||
|
if err := config.ValidateAbsoluteHTTPURL(effective.UserInfoURL); err != nil {
|
||||||
|
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth userinfo url invalid")
|
||||||
|
}
|
||||||
|
if err := config.ValidateAbsoluteHTTPURL(effective.RedirectURL); err != nil {
|
||||||
|
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url invalid")
|
||||||
|
}
|
||||||
|
if err := config.ValidateFrontendRedirectURL(effective.FrontendRedirectURL); err != nil {
|
||||||
|
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth frontend redirect url invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
method := strings.ToLower(strings.TrimSpace(effective.TokenAuthMethod))
|
||||||
|
switch method {
|
||||||
|
case "", "client_secret_post", "client_secret_basic":
|
||||||
|
if strings.TrimSpace(effective.ClientSecret) == "" {
|
||||||
|
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured")
|
||||||
|
}
|
||||||
|
case "none":
|
||||||
|
if !effective.UsePKCE {
|
||||||
|
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none")
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
return effective, nil
|
||||||
|
}
|
||||||
|
|
||||||
// getStringOrDefault 获取字符串值或默认值
|
// getStringOrDefault 获取字符串值或默认值
|
||||||
func (s *SettingService) getStringOrDefault(settings map[string]string, key, defaultValue string) string {
|
func (s *SettingService) getStringOrDefault(settings map[string]string, key, defaultValue string) string {
|
||||||
if value, ok := settings[key]; ok && value != "" {
|
if value, ok := settings[key]; ok && value != "" {
|
||||||
|
|||||||
@@ -18,6 +18,13 @@ type SystemSettings struct {
|
|||||||
TurnstileSecretKey string
|
TurnstileSecretKey string
|
||||||
TurnstileSecretKeyConfigured bool
|
TurnstileSecretKeyConfigured bool
|
||||||
|
|
||||||
|
// LinuxDo Connect OAuth 登录(终端用户 SSO)
|
||||||
|
LinuxDoConnectEnabled bool
|
||||||
|
LinuxDoConnectClientID string
|
||||||
|
LinuxDoConnectClientSecret string
|
||||||
|
LinuxDoConnectClientSecretConfigured bool
|
||||||
|
LinuxDoConnectRedirectURL string
|
||||||
|
|
||||||
SiteName string
|
SiteName string
|
||||||
SiteLogo string
|
SiteLogo string
|
||||||
SiteSubtitle string
|
SiteSubtitle string
|
||||||
@@ -51,5 +58,6 @@ type PublicSettings struct {
|
|||||||
APIBaseURL string
|
APIBaseURL string
|
||||||
ContactInfo string
|
ContactInfo string
|
||||||
DocURL string
|
DocURL string
|
||||||
|
LinuxDoOAuthEnabled bool
|
||||||
Version string
|
Version string
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ type UsageLog struct {
|
|||||||
DurationMs *int
|
DurationMs *int
|
||||||
FirstTokenMs *int
|
FirstTokenMs *int
|
||||||
UserAgent *string
|
UserAgent *string
|
||||||
|
IPAddress *string
|
||||||
|
|
||||||
// 图片生成字段
|
// 图片生成字段
|
||||||
ImageCount int
|
ImageCount int
|
||||||
|
|||||||
5
backend/migrations/031_add_ip_address.sql
Normal file
5
backend/migrations/031_add_ip_address.sql
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
-- Add IP address field to usage_logs table for request tracking (admin-only visibility)
|
||||||
|
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS ip_address VARCHAR(45);
|
||||||
|
|
||||||
|
-- Create index for IP address queries
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_usage_logs_ip_address ON usage_logs(ip_address);
|
||||||
9
backend/migrations/032_add_api_key_ip_restriction.sql
Normal file
9
backend/migrations/032_add_api_key_ip_restriction.sql
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
-- Add IP restriction fields to api_keys table
|
||||||
|
-- ip_whitelist: JSON array of allowed IPs/CIDRs (if set, only these IPs can use the key)
|
||||||
|
-- ip_blacklist: JSON array of blocked IPs/CIDRs (these IPs are always blocked)
|
||||||
|
|
||||||
|
ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS ip_whitelist JSONB DEFAULT NULL;
|
||||||
|
ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS ip_blacklist JSONB DEFAULT NULL;
|
||||||
|
|
||||||
|
COMMENT ON COLUMN api_keys.ip_whitelist IS 'JSON array of allowed IPs/CIDRs, e.g. ["192.168.1.100", "10.0.0.0/8"]';
|
||||||
|
COMMENT ON COLUMN api_keys.ip_blacklist IS 'JSON array of blocked IPs/CIDRs, e.g. ["1.2.3.4", "5.6.0.0/16"]';
|
||||||
BIN
backend/repository.test
Executable file
BIN
backend/repository.test
Executable file
Binary file not shown.
@@ -154,9 +154,9 @@ gateway:
|
|||||||
# Stream keepalive interval (seconds), 0=disable
|
# Stream keepalive interval (seconds), 0=disable
|
||||||
# 流式 keepalive 间隔(秒),0=禁用
|
# 流式 keepalive 间隔(秒),0=禁用
|
||||||
stream_keepalive_interval: 10
|
stream_keepalive_interval: 10
|
||||||
# SSE max line size in bytes (default: 10MB)
|
# SSE max line size in bytes (default: 40MB)
|
||||||
# SSE 单行最大字节数(默认 10MB)
|
# SSE 单行最大字节数(默认 40MB)
|
||||||
max_line_size: 10485760
|
max_line_size: 41943040
|
||||||
# Log upstream error response body summary (safe/truncated; does not log request content)
|
# Log upstream error response body summary (safe/truncated; does not log request content)
|
||||||
# 记录上游错误响应体摘要(安全/截断;不记录请求内容)
|
# 记录上游错误响应体摘要(安全/截断;不记录请求内容)
|
||||||
log_upstream_error_body: false
|
log_upstream_error_body: false
|
||||||
|
|||||||
@@ -154,9 +154,9 @@ gateway:
|
|||||||
# Stream keepalive interval (seconds), 0=disable
|
# Stream keepalive interval (seconds), 0=disable
|
||||||
# 流式 keepalive 间隔(秒),0=禁用
|
# 流式 keepalive 间隔(秒),0=禁用
|
||||||
stream_keepalive_interval: 10
|
stream_keepalive_interval: 10
|
||||||
# SSE max line size in bytes (default: 10MB)
|
# SSE max line size in bytes (default: 40MB)
|
||||||
# SSE 单行最大字节数(默认 10MB)
|
# SSE 单行最大字节数(默认 40MB)
|
||||||
max_line_size: 10485760
|
max_line_size: 41943040
|
||||||
# Log upstream error response body summary (safe/truncated; does not log request content)
|
# Log upstream error response body summary (safe/truncated; does not log request content)
|
||||||
# 记录上游错误响应体摘要(安全/截断;不记录请求内容)
|
# 记录上游错误响应体摘要(安全/截断;不记录请求内容)
|
||||||
log_upstream_error_body: false
|
log_upstream_error_body: false
|
||||||
@@ -234,6 +234,31 @@ jwt:
|
|||||||
# 令牌过期时间(小时,最大 24)
|
# 令牌过期时间(小时,最大 24)
|
||||||
expire_hour: 24
|
expire_hour: 24
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# LinuxDo Connect OAuth Login (SSO)
|
||||||
|
# LinuxDo Connect OAuth 登录(用于 Sub2API 用户登录)
|
||||||
|
# =============================================================================
|
||||||
|
linuxdo_connect:
|
||||||
|
enabled: false
|
||||||
|
client_id: ""
|
||||||
|
client_secret: ""
|
||||||
|
authorize_url: "https://connect.linux.do/oauth2/authorize"
|
||||||
|
token_url: "https://connect.linux.do/oauth2/token"
|
||||||
|
userinfo_url: "https://connect.linux.do/api/user"
|
||||||
|
scopes: "user"
|
||||||
|
# 示例: "https://your-domain.com/api/v1/auth/oauth/linuxdo/callback"
|
||||||
|
redirect_url: ""
|
||||||
|
# 安全提示:
|
||||||
|
# - 建议使用同源相对路径(以 / 开头),避免把 token 重定向到意外的第三方域名
|
||||||
|
# - 该地址不应包含 #fragment(本实现使用 URL fragment 传递 access_token)
|
||||||
|
frontend_redirect_url: "/auth/linuxdo/callback"
|
||||||
|
token_auth_method: "client_secret_post" # client_secret_post | client_secret_basic | none
|
||||||
|
# 注意:当 token_auth_method=none(public client)时,必须启用 PKCE
|
||||||
|
use_pkce: false
|
||||||
|
userinfo_email_path: ""
|
||||||
|
userinfo_id_path: ""
|
||||||
|
userinfo_username_path: ""
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Default Settings
|
# Default Settings
|
||||||
# 默认设置
|
# 默认设置
|
||||||
|
|||||||
93
deploy/docker-compose.standalone.yml
Normal file
93
deploy/docker-compose.standalone.yml
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
# =============================================================================
|
||||||
|
# Sub2API Docker Compose - Standalone Configuration
|
||||||
|
# =============================================================================
|
||||||
|
# This configuration runs only the Sub2API application.
|
||||||
|
# PostgreSQL and Redis must be provided externally.
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
# 1. Copy .env.example to .env and configure database/redis connection
|
||||||
|
# 2. docker-compose -f docker-compose.standalone.yml up -d
|
||||||
|
# 3. Access: http://localhost:8080
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
services:
|
||||||
|
sub2api:
|
||||||
|
image: weishaw/sub2api:latest
|
||||||
|
container_name: sub2api
|
||||||
|
restart: unless-stopped
|
||||||
|
ulimits:
|
||||||
|
nofile:
|
||||||
|
soft: 100000
|
||||||
|
hard: 100000
|
||||||
|
ports:
|
||||||
|
- "${BIND_HOST:-0.0.0.0}:${SERVER_PORT:-8080}:8080"
|
||||||
|
volumes:
|
||||||
|
- sub2api_data:/app/data
|
||||||
|
extra_hosts:
|
||||||
|
- "host.docker.internal:host-gateway"
|
||||||
|
environment:
|
||||||
|
# =======================================================================
|
||||||
|
# Auto Setup
|
||||||
|
# =======================================================================
|
||||||
|
- AUTO_SETUP=true
|
||||||
|
|
||||||
|
# =======================================================================
|
||||||
|
# Server Configuration
|
||||||
|
# =======================================================================
|
||||||
|
- SERVER_HOST=0.0.0.0
|
||||||
|
- SERVER_PORT=8080
|
||||||
|
- SERVER_MODE=${SERVER_MODE:-release}
|
||||||
|
- RUN_MODE=${RUN_MODE:-standard}
|
||||||
|
|
||||||
|
# =======================================================================
|
||||||
|
# Database Configuration (PostgreSQL) - Required
|
||||||
|
# =======================================================================
|
||||||
|
- DATABASE_HOST=${DATABASE_HOST:?DATABASE_HOST is required}
|
||||||
|
- DATABASE_PORT=${DATABASE_PORT:-5432}
|
||||||
|
- DATABASE_USER=${DATABASE_USER:-sub2api}
|
||||||
|
- DATABASE_PASSWORD=${DATABASE_PASSWORD:?DATABASE_PASSWORD is required}
|
||||||
|
- DATABASE_DBNAME=${DATABASE_DBNAME:-sub2api}
|
||||||
|
- DATABASE_SSLMODE=${DATABASE_SSLMODE:-disable}
|
||||||
|
|
||||||
|
# =======================================================================
|
||||||
|
# Redis Configuration - Required
|
||||||
|
# =======================================================================
|
||||||
|
- REDIS_HOST=${REDIS_HOST:?REDIS_HOST is required}
|
||||||
|
- REDIS_PORT=${REDIS_PORT:-6379}
|
||||||
|
- REDIS_PASSWORD=${REDIS_PASSWORD:-}
|
||||||
|
- REDIS_DB=${REDIS_DB:-0}
|
||||||
|
|
||||||
|
# =======================================================================
|
||||||
|
# Admin Account (auto-created on first run)
|
||||||
|
# =======================================================================
|
||||||
|
- ADMIN_EMAIL=${ADMIN_EMAIL:-admin@sub2api.local}
|
||||||
|
- ADMIN_PASSWORD=${ADMIN_PASSWORD:-}
|
||||||
|
|
||||||
|
# =======================================================================
|
||||||
|
# JWT Configuration
|
||||||
|
# =======================================================================
|
||||||
|
- JWT_SECRET=${JWT_SECRET:-}
|
||||||
|
- JWT_EXPIRE_HOUR=${JWT_EXPIRE_HOUR:-24}
|
||||||
|
|
||||||
|
# =======================================================================
|
||||||
|
# Timezone Configuration
|
||||||
|
# =======================================================================
|
||||||
|
- TZ=${TZ:-Asia/Shanghai}
|
||||||
|
|
||||||
|
# =======================================================================
|
||||||
|
# Gemini OAuth Configuration (optional)
|
||||||
|
# =======================================================================
|
||||||
|
- GEMINI_OAUTH_CLIENT_ID=${GEMINI_OAUTH_CLIENT_ID:-}
|
||||||
|
- GEMINI_OAUTH_CLIENT_SECRET=${GEMINI_OAUTH_CLIENT_SECRET:-}
|
||||||
|
- GEMINI_OAUTH_SCOPES=${GEMINI_OAUTH_SCOPES:-}
|
||||||
|
- GEMINI_QUOTA_POLICY=${GEMINI_QUOTA_POLICY:-}
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "curl", "-f", "http://localhost:8080/health"]
|
||||||
|
interval: 30s
|
||||||
|
timeout: 10s
|
||||||
|
retries: 3
|
||||||
|
start_period: 30s
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
sub2api_data:
|
||||||
|
driver: local
|
||||||
@@ -173,11 +173,12 @@ services:
|
|||||||
volumes:
|
volumes:
|
||||||
- redis_data:/data
|
- redis_data:/data
|
||||||
command: >
|
command: >
|
||||||
redis-server
|
sh -c '
|
||||||
--save 60 1
|
redis-server
|
||||||
--appendonly yes
|
--save 60 1
|
||||||
--appendfsync everysec
|
--appendonly yes
|
||||||
${REDIS_PASSWORD:+--requirepass ${REDIS_PASSWORD}}
|
--appendfsync everysec
|
||||||
|
${REDIS_PASSWORD:+--requirepass "$REDIS_PASSWORD"}'
|
||||||
environment:
|
environment:
|
||||||
- TZ=${TZ:-Asia/Shanghai}
|
- TZ=${TZ:-Asia/Shanghai}
|
||||||
# REDISCLI_AUTH is used by redis-cli for authentication (safer than -a flag)
|
# REDISCLI_AUTH is used by redis-cli for authentication (safer than -a flag)
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import type {
|
|||||||
* List all groups with pagination
|
* List all groups with pagination
|
||||||
* @param page - Page number (default: 1)
|
* @param page - Page number (default: 1)
|
||||||
* @param pageSize - Items per page (default: 20)
|
* @param pageSize - Items per page (default: 20)
|
||||||
* @param filters - Optional filters (platform, status, is_exclusive)
|
* @param filters - Optional filters (platform, status, is_exclusive, search)
|
||||||
* @returns Paginated list of groups
|
* @returns Paginated list of groups
|
||||||
*/
|
*/
|
||||||
export async function list(
|
export async function list(
|
||||||
@@ -26,6 +26,7 @@ export async function list(
|
|||||||
platform?: GroupPlatform
|
platform?: GroupPlatform
|
||||||
status?: 'active' | 'inactive'
|
status?: 'active' | 'inactive'
|
||||||
is_exclusive?: boolean
|
is_exclusive?: boolean
|
||||||
|
search?: string
|
||||||
},
|
},
|
||||||
options?: {
|
options?: {
|
||||||
signal?: AbortSignal
|
signal?: AbortSignal
|
||||||
|
|||||||
@@ -34,6 +34,11 @@ export interface SystemSettings {
|
|||||||
turnstile_enabled: boolean
|
turnstile_enabled: boolean
|
||||||
turnstile_site_key: string
|
turnstile_site_key: string
|
||||||
turnstile_secret_key_configured: boolean
|
turnstile_secret_key_configured: boolean
|
||||||
|
// LinuxDo Connect OAuth 登录(终端用户 SSO)
|
||||||
|
linuxdo_connect_enabled: boolean
|
||||||
|
linuxdo_connect_client_id: string
|
||||||
|
linuxdo_connect_client_secret_configured: boolean
|
||||||
|
linuxdo_connect_redirect_url: string
|
||||||
// Identity patch configuration (Claude -> Gemini)
|
// Identity patch configuration (Claude -> Gemini)
|
||||||
enable_identity_patch: boolean
|
enable_identity_patch: boolean
|
||||||
identity_patch_prompt: string
|
identity_patch_prompt: string
|
||||||
@@ -60,6 +65,10 @@ export interface UpdateSettingsRequest {
|
|||||||
turnstile_enabled?: boolean
|
turnstile_enabled?: boolean
|
||||||
turnstile_site_key?: string
|
turnstile_site_key?: string
|
||||||
turnstile_secret_key?: string
|
turnstile_secret_key?: string
|
||||||
|
linuxdo_connect_enabled?: boolean
|
||||||
|
linuxdo_connect_client_id?: string
|
||||||
|
linuxdo_connect_client_secret?: string
|
||||||
|
linuxdo_connect_redirect_url?: string
|
||||||
enable_identity_patch?: boolean
|
enable_identity_patch?: boolean
|
||||||
identity_patch_prompt?: string
|
identity_patch_prompt?: string
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -64,7 +64,6 @@ export async function getStats(params: {
|
|||||||
group_id?: number
|
group_id?: number
|
||||||
model?: string
|
model?: string
|
||||||
stream?: boolean
|
stream?: boolean
|
||||||
billing_type?: number
|
|
||||||
period?: string
|
period?: string
|
||||||
start_date?: string
|
start_date?: string
|
||||||
end_date?: string
|
end_date?: string
|
||||||
|
|||||||
@@ -42,12 +42,16 @@ export async function getById(id: number): Promise<ApiKey> {
|
|||||||
* @param name - Key name
|
* @param name - Key name
|
||||||
* @param groupId - Optional group ID
|
* @param groupId - Optional group ID
|
||||||
* @param customKey - Optional custom key value
|
* @param customKey - Optional custom key value
|
||||||
|
* @param ipWhitelist - Optional IP whitelist
|
||||||
|
* @param ipBlacklist - Optional IP blacklist
|
||||||
* @returns Created API key
|
* @returns Created API key
|
||||||
*/
|
*/
|
||||||
export async function create(
|
export async function create(
|
||||||
name: string,
|
name: string,
|
||||||
groupId?: number | null,
|
groupId?: number | null,
|
||||||
customKey?: string
|
customKey?: string,
|
||||||
|
ipWhitelist?: string[],
|
||||||
|
ipBlacklist?: string[]
|
||||||
): Promise<ApiKey> {
|
): Promise<ApiKey> {
|
||||||
const payload: CreateApiKeyRequest = { name }
|
const payload: CreateApiKeyRequest = { name }
|
||||||
if (groupId !== undefined) {
|
if (groupId !== undefined) {
|
||||||
@@ -56,6 +60,12 @@ export async function create(
|
|||||||
if (customKey) {
|
if (customKey) {
|
||||||
payload.custom_key = customKey
|
payload.custom_key = customKey
|
||||||
}
|
}
|
||||||
|
if (ipWhitelist && ipWhitelist.length > 0) {
|
||||||
|
payload.ip_whitelist = ipWhitelist
|
||||||
|
}
|
||||||
|
if (ipBlacklist && ipBlacklist.length > 0) {
|
||||||
|
payload.ip_blacklist = ipBlacklist
|
||||||
|
}
|
||||||
|
|
||||||
const { data } = await apiClient.post<ApiKey>('/keys', payload)
|
const { data } = await apiClient.post<ApiKey>('/keys', payload)
|
||||||
return data
|
return data
|
||||||
|
|||||||
@@ -1,8 +1,27 @@
|
|||||||
<template>
|
<template>
|
||||||
<div v-if="selectedIds.length > 0" class="mb-4 flex items-center justify-between p-3 bg-primary-50 rounded-lg">
|
<div v-if="selectedIds.length > 0" class="mb-4 flex items-center justify-between p-3 bg-primary-50 rounded-lg dark:bg-primary-900/20">
|
||||||
<span class="text-sm font-medium">{{ t('admin.accounts.bulkActions.selected', { count: selectedIds.length }) }}</span>
|
<div class="flex flex-wrap items-center gap-2">
|
||||||
|
<span class="text-sm font-medium text-primary-900 dark:text-primary-100">
|
||||||
|
{{ t('admin.accounts.bulkActions.selected', { count: selectedIds.length }) }}
|
||||||
|
</span>
|
||||||
|
<button
|
||||||
|
@click="$emit('select-page')"
|
||||||
|
class="text-xs font-medium text-primary-700 hover:text-primary-800 dark:text-primary-300 dark:hover:text-primary-200"
|
||||||
|
>
|
||||||
|
{{ t('admin.accounts.bulkActions.selectCurrentPage') }}
|
||||||
|
</button>
|
||||||
|
<span class="text-gray-300 dark:text-primary-800">•</span>
|
||||||
|
<button
|
||||||
|
@click="$emit('clear')"
|
||||||
|
class="text-xs font-medium text-primary-700 hover:text-primary-800 dark:text-primary-300 dark:hover:text-primary-200"
|
||||||
|
>
|
||||||
|
{{ t('admin.accounts.bulkActions.clear') }}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
<div class="flex gap-2">
|
<div class="flex gap-2">
|
||||||
<button @click="$emit('delete')" class="btn btn-danger btn-sm">{{ t('admin.accounts.bulkActions.delete') }}</button>
|
<button @click="$emit('delete')" class="btn btn-danger btn-sm">{{ t('admin.accounts.bulkActions.delete') }}</button>
|
||||||
|
<button @click="$emit('toggle-schedulable', true)" class="btn btn-success btn-sm">{{ t('admin.accounts.bulkActions.enableScheduling') }}</button>
|
||||||
|
<button @click="$emit('toggle-schedulable', false)" class="btn btn-warning btn-sm">{{ t('admin.accounts.bulkActions.disableScheduling') }}</button>
|
||||||
<button @click="$emit('edit')" class="btn btn-primary btn-sm">{{ t('admin.accounts.bulkActions.edit') }}</button>
|
<button @click="$emit('edit')" class="btn btn-primary btn-sm">{{ t('admin.accounts.bulkActions.edit') }}</button>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -10,5 +29,5 @@
|
|||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { useI18n } from 'vue-i18n'
|
import { useI18n } from 'vue-i18n'
|
||||||
defineProps(['selectedIds']); defineEmits(['delete', 'edit']); const { t } = useI18n()
|
defineProps(['selectedIds']); defineEmits(['delete', 'edit', 'clear', 'select-page', 'toggle-schedulable']); const { t } = useI18n()
|
||||||
</script>
|
</script>
|
||||||
@@ -127,12 +127,6 @@
|
|||||||
<Select v-model="filters.stream" :options="streamTypeOptions" @change="emitChange" />
|
<Select v-model="filters.stream" :options="streamTypeOptions" @change="emitChange" />
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Billing Type Filter -->
|
|
||||||
<div class="w-full sm:w-auto sm:min-w-[180px]">
|
|
||||||
<label class="input-label">{{ t('usage.billingType') }}</label>
|
|
||||||
<Select v-model="filters.billing_type" :options="billingTypeOptions" @change="emitChange" />
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<!-- Group Filter -->
|
<!-- Group Filter -->
|
||||||
<div class="w-full sm:w-auto sm:min-w-[200px]">
|
<div class="w-full sm:w-auto sm:min-w-[200px]">
|
||||||
<label class="input-label">{{ t('admin.usage.group') }}</label>
|
<label class="input-label">{{ t('admin.usage.group') }}</label>
|
||||||
@@ -227,12 +221,6 @@ const streamTypeOptions = ref<SelectOption[]>([
|
|||||||
{ value: false, label: t('usage.sync') }
|
{ value: false, label: t('usage.sync') }
|
||||||
])
|
])
|
||||||
|
|
||||||
const billingTypeOptions = ref<SelectOption[]>([
|
|
||||||
{ value: null, label: t('admin.usage.allBillingTypes') },
|
|
||||||
{ value: 1, label: t('usage.subscription') },
|
|
||||||
{ value: 0, label: t('usage.balance') }
|
|
||||||
])
|
|
||||||
|
|
||||||
const emitChange = () => emit('change')
|
const emitChange = () => emit('change')
|
||||||
|
|
||||||
const updateStartDate = (value: string) => {
|
const updateStartDate = (value: string) => {
|
||||||
|
|||||||
@@ -96,12 +96,6 @@
|
|||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
<template #cell-billing_type="{ row }">
|
|
||||||
<span class="inline-flex items-center rounded px-2 py-0.5 text-xs font-medium" :class="row.billing_type === 1 ? 'bg-purple-100 text-purple-800 dark:bg-purple-900 dark:text-purple-200' : 'bg-emerald-100 text-emerald-800 dark:bg-emerald-900 dark:text-emerald-200'">
|
|
||||||
{{ row.billing_type === 1 ? t('usage.subscription') : t('usage.balance') }}
|
|
||||||
</span>
|
|
||||||
</template>
|
|
||||||
|
|
||||||
<template #cell-first_token="{ row }">
|
<template #cell-first_token="{ row }">
|
||||||
<span v-if="row.first_token_ms != null" class="text-sm text-gray-600 dark:text-gray-400">{{ formatDuration(row.first_token_ms) }}</span>
|
<span v-if="row.first_token_ms != null" class="text-sm text-gray-600 dark:text-gray-400">{{ formatDuration(row.first_token_ms) }}</span>
|
||||||
<span v-else class="text-sm text-gray-400 dark:text-gray-500">-</span>
|
<span v-else class="text-sm text-gray-400 dark:text-gray-500">-</span>
|
||||||
@@ -120,6 +114,11 @@
|
|||||||
<span v-else class="text-sm text-gray-400 dark:text-gray-500">-</span>
|
<span v-else class="text-sm text-gray-400 dark:text-gray-500">-</span>
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
|
<template #cell-ip_address="{ row }">
|
||||||
|
<span v-if="row.ip_address" class="text-sm font-mono text-gray-600 dark:text-gray-400">{{ row.ip_address }}</span>
|
||||||
|
<span v-else class="text-sm text-gray-400 dark:text-gray-500">-</span>
|
||||||
|
</template>
|
||||||
|
|
||||||
<template #empty><EmptyState :message="t('usage.noRecords')" /></template>
|
<template #empty><EmptyState :message="t('usage.noRecords')" /></template>
|
||||||
</DataTable>
|
</DataTable>
|
||||||
</div>
|
</div>
|
||||||
@@ -249,11 +248,11 @@ const cols = computed(() => [
|
|||||||
{ key: 'stream', label: t('usage.type'), sortable: false },
|
{ key: 'stream', label: t('usage.type'), sortable: false },
|
||||||
{ key: 'tokens', label: t('usage.tokens'), sortable: false },
|
{ key: 'tokens', label: t('usage.tokens'), sortable: false },
|
||||||
{ key: 'cost', label: t('usage.cost'), sortable: false },
|
{ key: 'cost', label: t('usage.cost'), sortable: false },
|
||||||
{ key: 'billing_type', label: t('usage.billingType'), sortable: false },
|
|
||||||
{ key: 'first_token', label: t('usage.firstToken'), sortable: false },
|
{ key: 'first_token', label: t('usage.firstToken'), sortable: false },
|
||||||
{ key: 'duration', label: t('usage.duration'), sortable: false },
|
{ key: 'duration', label: t('usage.duration'), sortable: false },
|
||||||
{ key: 'created_at', label: t('usage.time'), sortable: true },
|
{ key: 'created_at', label: t('usage.time'), sortable: true },
|
||||||
{ key: 'user_agent', label: t('usage.userAgent'), sortable: false }
|
{ key: 'user_agent', label: t('usage.userAgent'), sortable: false },
|
||||||
|
{ key: 'ip_address', label: t('admin.usage.ipAddress'), sortable: false }
|
||||||
])
|
])
|
||||||
|
|
||||||
const formatCacheTokens = (tokens: number): string => {
|
const formatCacheTokens = (tokens: number): string => {
|
||||||
|
|||||||
61
frontend/src/components/auth/LinuxDoOAuthSection.vue
Normal file
61
frontend/src/components/auth/LinuxDoOAuthSection.vue
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
<template>
|
||||||
|
<div class="space-y-4">
|
||||||
|
<button type="button" :disabled="disabled" class="btn btn-secondary w-full" @click="startLogin">
|
||||||
|
<svg
|
||||||
|
class="icon mr-2"
|
||||||
|
viewBox="0 0 16 16"
|
||||||
|
version="1.1"
|
||||||
|
xmlns="http://www.w3.org/2000/svg"
|
||||||
|
width="1em"
|
||||||
|
height="1em"
|
||||||
|
style="color: rgb(233, 84, 32); width: 20px; height: 20px"
|
||||||
|
aria-hidden="true"
|
||||||
|
>
|
||||||
|
<g id="linuxdo_icon" data-name="linuxdo_icon">
|
||||||
|
<path
|
||||||
|
d="m7.44,0s.09,0,.13,0c.09,0,.19,0,.28,0,.14,0,.29,0,.43,0,.09,0,.18,0,.27,0q.12,0,.25,0t.26.08c.15.03.29.06.44.08,1.97.38,3.78,1.47,4.95,3.11.04.06.09.12.13.18.67.96,1.15,2.11,1.3,3.28q0,.19.09.26c0,.15,0,.29,0,.44,0,.04,0,.09,0,.13,0,.09,0,.19,0,.28,0,.14,0,.29,0,.43,0,.09,0,.18,0,.27,0,.08,0,.17,0,.25q0,.19-.08.26c-.03.15-.06.29-.08.44-.38,1.97-1.47,3.78-3.11,4.95-.06.04-.12.09-.18.13-.96.67-2.11,1.15-3.28,1.3q-.19,0-.26.09c-.15,0-.29,0-.44,0-.04,0-.09,0-.13,0-.09,0-.19,0-.28,0-.14,0-.29,0-.43,0-.09,0-.18,0-.27,0-.08,0-.17,0-.25,0q-.19,0-.26-.08c-.15-.03-.29-.06-.44-.08-1.97-.38-3.78-1.47-4.95-3.11q-.07-.09-.13-.18c-.67-.96-1.15-2.11-1.3-3.28q0-.19-.09-.26c0-.15,0-.29,0-.44,0-.04,0-.09,0-.13,0-.09,0-.19,0-.28,0-.14,0-.29,0-.43,0-.09,0-.18,0-.27,0-.08,0-.17,0-.25q0-.19.08-.26c.03-.15.06-.29.08-.44.38-1.97,1.47-3.78,3.11-4.95.06-.04.12-.09.18-.13C4.42.73,5.57.26,6.74.1,7,.07,7.15,0,7.44,0Z"
|
||||||
|
fill="#EFEFEF"
|
||||||
|
></path>
|
||||||
|
<path
|
||||||
|
d="m1.27,11.33h13.45c-.94,1.89-2.51,3.21-4.51,3.88-1.99.59-3.96.37-5.8-.57-1.25-.7-2.67-1.9-3.14-3.3Z"
|
||||||
|
fill="#FEB005"
|
||||||
|
></path>
|
||||||
|
<path
|
||||||
|
d="m12.54,1.99c.87.7,1.82,1.59,2.18,2.68H1.27c.87-1.74,2.33-3.13,4.2-3.78,2.44-.79,5-.47,7.07,1.1Z"
|
||||||
|
fill="#1D1D1F"
|
||||||
|
></path>
|
||||||
|
</g>
|
||||||
|
</svg>
|
||||||
|
{{ t('auth.linuxdo.signIn') }}
|
||||||
|
</button>
|
||||||
|
|
||||||
|
<div class="flex items-center gap-3">
|
||||||
|
<div class="h-px flex-1 bg-gray-200 dark:bg-dark-700"></div>
|
||||||
|
<span class="text-xs text-gray-500 dark:text-dark-400">
|
||||||
|
{{ t('auth.linuxdo.orContinue') }}
|
||||||
|
</span>
|
||||||
|
<div class="h-px flex-1 bg-gray-200 dark:bg-dark-700"></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<script setup lang="ts">
|
||||||
|
import { useRoute } from 'vue-router'
|
||||||
|
import { useI18n } from 'vue-i18n'
|
||||||
|
|
||||||
|
defineProps<{
|
||||||
|
disabled?: boolean
|
||||||
|
}>()
|
||||||
|
|
||||||
|
const route = useRoute()
|
||||||
|
const { t } = useI18n()
|
||||||
|
|
||||||
|
function startLogin(): void {
|
||||||
|
const redirectTo = (route.query.redirect as string) || '/dashboard'
|
||||||
|
const apiBase = (import.meta.env.VITE_API_BASE_URL as string | undefined) || '/api/v1'
|
||||||
|
const normalized = apiBase.replace(/\/$/, '')
|
||||||
|
const startURL = `${normalized}/auth/oauth/linuxdo/start?redirect=${encodeURIComponent(redirectTo)}`
|
||||||
|
window.location.href = startURL
|
||||||
|
}
|
||||||
|
</script>
|
||||||
|
|
||||||
@@ -43,7 +43,8 @@ export function useTableLoader<T, P extends Record<string, any>>(options: TableL
|
|||||||
if (abortController) {
|
if (abortController) {
|
||||||
abortController.abort()
|
abortController.abort()
|
||||||
}
|
}
|
||||||
abortController = new AbortController()
|
const currentController = new AbortController()
|
||||||
|
abortController = currentController
|
||||||
loading.value = true
|
loading.value = true
|
||||||
|
|
||||||
try {
|
try {
|
||||||
@@ -51,7 +52,7 @@ export function useTableLoader<T, P extends Record<string, any>>(options: TableL
|
|||||||
pagination.page,
|
pagination.page,
|
||||||
pagination.page_size,
|
pagination.page_size,
|
||||||
toRaw(params) as P,
|
toRaw(params) as P,
|
||||||
{ signal: abortController.signal }
|
{ signal: currentController.signal }
|
||||||
)
|
)
|
||||||
|
|
||||||
items.value = response.items || []
|
items.value = response.items || []
|
||||||
@@ -63,7 +64,7 @@ export function useTableLoader<T, P extends Record<string, any>>(options: TableL
|
|||||||
throw error
|
throw error
|
||||||
}
|
}
|
||||||
} finally {
|
} finally {
|
||||||
if (abortController && !abortController.signal.aborted) {
|
if (abortController === currentController) {
|
||||||
loading.value = false
|
loading.value = false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -77,7 +78,9 @@ export function useTableLoader<T, P extends Record<string, any>>(options: TableL
|
|||||||
const debouncedReload = useDebounceFn(reload, debounceMs)
|
const debouncedReload = useDebounceFn(reload, debounceMs)
|
||||||
|
|
||||||
const handlePageChange = (page: number) => {
|
const handlePageChange = (page: number) => {
|
||||||
pagination.page = page
|
// 确保页码在有效范围内
|
||||||
|
const validPage = Math.max(1, Math.min(page, pagination.pages || 1))
|
||||||
|
pagination.page = validPage
|
||||||
load()
|
load()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -229,6 +229,15 @@ export default {
|
|||||||
sendingCode: 'Sending...',
|
sendingCode: 'Sending...',
|
||||||
clickToResend: 'Click to resend code',
|
clickToResend: 'Click to resend code',
|
||||||
resendCode: 'Resend verification code',
|
resendCode: 'Resend verification code',
|
||||||
|
linuxdo: {
|
||||||
|
signIn: 'Continue with Linux.do',
|
||||||
|
orContinue: 'or continue with email',
|
||||||
|
callbackTitle: 'Signing you in',
|
||||||
|
callbackProcessing: 'Completing login, please wait...',
|
||||||
|
callbackHint: 'If you are not redirected automatically, go back to the login page and try again.',
|
||||||
|
callbackMissingToken: 'Missing login token, please try again.',
|
||||||
|
backToLogin: 'Back to Login'
|
||||||
|
},
|
||||||
oauth: {
|
oauth: {
|
||||||
code: 'Code',
|
code: 'Code',
|
||||||
state: 'State',
|
state: 'State',
|
||||||
@@ -361,6 +370,14 @@ export default {
|
|||||||
customKeyTooShort: 'Custom key must be at least 16 characters',
|
customKeyTooShort: 'Custom key must be at least 16 characters',
|
||||||
customKeyInvalidChars: 'Custom key can only contain letters, numbers, underscores, and hyphens',
|
customKeyInvalidChars: 'Custom key can only contain letters, numbers, underscores, and hyphens',
|
||||||
customKeyRequired: 'Please enter a custom key',
|
customKeyRequired: 'Please enter a custom key',
|
||||||
|
ipRestriction: 'IP Restriction',
|
||||||
|
ipWhitelist: 'IP Whitelist',
|
||||||
|
ipWhitelistPlaceholder: '192.168.1.100\n10.0.0.0/8',
|
||||||
|
ipWhitelistHint: 'One IP or CIDR per line. Only these IPs can use this key when set.',
|
||||||
|
ipBlacklist: 'IP Blacklist',
|
||||||
|
ipBlacklistPlaceholder: '1.2.3.4\n5.6.0.0/16',
|
||||||
|
ipBlacklistHint: 'One IP or CIDR per line. These IPs will be blocked from using this key.',
|
||||||
|
ipRestrictionEnabled: 'IP restriction enabled',
|
||||||
ccSwitchNotInstalled: 'CC-Switch is not installed or the protocol handler is not registered. Please install CC-Switch first or manually copy the API key.',
|
ccSwitchNotInstalled: 'CC-Switch is not installed or the protocol handler is not registered. Please install CC-Switch first or manually copy the API key.',
|
||||||
ccsClientSelect: {
|
ccsClientSelect: {
|
||||||
title: 'Select Client',
|
title: 'Select Client',
|
||||||
@@ -421,9 +438,6 @@ export default {
|
|||||||
exportFailed: 'Failed to export usage data',
|
exportFailed: 'Failed to export usage data',
|
||||||
exportExcelSuccess: 'Usage data exported successfully (Excel format)',
|
exportExcelSuccess: 'Usage data exported successfully (Excel format)',
|
||||||
exportExcelFailed: 'Failed to export usage data',
|
exportExcelFailed: 'Failed to export usage data',
|
||||||
billingType: 'Billing',
|
|
||||||
balance: 'Balance',
|
|
||||||
subscription: 'Subscription',
|
|
||||||
imageUnit: ' images',
|
imageUnit: ' images',
|
||||||
userAgent: 'User-Agent'
|
userAgent: 'User-Agent'
|
||||||
},
|
},
|
||||||
@@ -1076,12 +1090,16 @@ export default {
|
|||||||
tokenRefreshed: 'Token refreshed successfully',
|
tokenRefreshed: 'Token refreshed successfully',
|
||||||
accountDeleted: 'Account deleted successfully',
|
accountDeleted: 'Account deleted successfully',
|
||||||
rateLimitCleared: 'Rate limit cleared successfully',
|
rateLimitCleared: 'Rate limit cleared successfully',
|
||||||
|
bulkSchedulableEnabled: 'Successfully enabled scheduling for {count} account(s)',
|
||||||
|
bulkSchedulableDisabled: 'Successfully disabled scheduling for {count} account(s)',
|
||||||
bulkActions: {
|
bulkActions: {
|
||||||
selected: '{count} account(s) selected',
|
selected: '{count} account(s) selected',
|
||||||
selectCurrentPage: 'Select this page',
|
selectCurrentPage: 'Select this page',
|
||||||
clear: 'Clear selection',
|
clear: 'Clear selection',
|
||||||
edit: 'Bulk Edit',
|
edit: 'Bulk Edit',
|
||||||
delete: 'Bulk Delete'
|
delete: 'Bulk Delete',
|
||||||
|
enableScheduling: 'Enable Scheduling',
|
||||||
|
disableScheduling: 'Disable Scheduling'
|
||||||
},
|
},
|
||||||
bulkEdit: {
|
bulkEdit: {
|
||||||
title: 'Bulk Edit Accounts',
|
title: 'Bulk Edit Accounts',
|
||||||
@@ -1486,6 +1504,7 @@ export default {
|
|||||||
testing: 'Testing...',
|
testing: 'Testing...',
|
||||||
retry: 'Retry',
|
retry: 'Retry',
|
||||||
copyOutput: 'Copy output',
|
copyOutput: 'Copy output',
|
||||||
|
outputCopied: 'Output copied',
|
||||||
startingTestForAccount: 'Starting test for account: {name}',
|
startingTestForAccount: 'Starting test for account: {name}',
|
||||||
testAccountTypeLabel: 'Account type: {type}',
|
testAccountTypeLabel: 'Account type: {type}',
|
||||||
selectTestModel: 'Select Test Model',
|
selectTestModel: 'Select Test Model',
|
||||||
@@ -1721,7 +1740,6 @@ export default {
|
|||||||
allAccounts: 'All Accounts',
|
allAccounts: 'All Accounts',
|
||||||
allGroups: 'All Groups',
|
allGroups: 'All Groups',
|
||||||
allTypes: 'All Types',
|
allTypes: 'All Types',
|
||||||
allBillingTypes: 'All Billing',
|
|
||||||
inputCost: 'Input Cost',
|
inputCost: 'Input Cost',
|
||||||
outputCost: 'Output Cost',
|
outputCost: 'Output Cost',
|
||||||
cacheCreationCost: 'Cache Creation Cost',
|
cacheCreationCost: 'Cache Creation Cost',
|
||||||
@@ -1730,7 +1748,8 @@ export default {
|
|||||||
outputTokens: 'Output Tokens',
|
outputTokens: 'Output Tokens',
|
||||||
cacheCreationTokens: 'Cache Creation Tokens',
|
cacheCreationTokens: 'Cache Creation Tokens',
|
||||||
cacheReadTokens: 'Cache Read Tokens',
|
cacheReadTokens: 'Cache Read Tokens',
|
||||||
failedToLoad: 'Failed to load usage records'
|
failedToLoad: 'Failed to load usage records',
|
||||||
|
ipAddress: 'IP'
|
||||||
},
|
},
|
||||||
|
|
||||||
// Settings
|
// Settings
|
||||||
@@ -1756,6 +1775,26 @@ export default {
|
|||||||
cloudflareDashboard: 'Cloudflare Dashboard',
|
cloudflareDashboard: 'Cloudflare Dashboard',
|
||||||
secretKeyHint: 'Server-side verification key (keep this secret)',
|
secretKeyHint: 'Server-side verification key (keep this secret)',
|
||||||
secretKeyConfiguredHint: 'Secret key configured. Leave empty to keep the current value.' },
|
secretKeyConfiguredHint: 'Secret key configured. Leave empty to keep the current value.' },
|
||||||
|
linuxdo: {
|
||||||
|
title: 'LinuxDo Connect Login',
|
||||||
|
description: 'Configure LinuxDo Connect OAuth for Sub2API end-user login',
|
||||||
|
enable: 'Enable LinuxDo Login',
|
||||||
|
enableHint: 'Show LinuxDo login on the login/register pages',
|
||||||
|
clientId: 'Client ID',
|
||||||
|
clientIdPlaceholder: 'e.g., hprJ5pC3...',
|
||||||
|
clientIdHint: 'Get this from Connect.Linux.Do',
|
||||||
|
clientSecret: 'Client Secret',
|
||||||
|
clientSecretPlaceholder: '********',
|
||||||
|
clientSecretHint: 'Used by backend to exchange tokens (keep it secret)',
|
||||||
|
clientSecretConfiguredPlaceholder: '********',
|
||||||
|
clientSecretConfiguredHint: 'Secret configured. Leave empty to keep the current value.',
|
||||||
|
redirectUrl: 'Redirect URL',
|
||||||
|
redirectUrlPlaceholder: 'https://your-domain.com/api/v1/auth/oauth/linuxdo/callback',
|
||||||
|
redirectUrlHint:
|
||||||
|
'Must match the redirect URL configured in Connect.Linux.Do (must be an absolute http(s) URL)',
|
||||||
|
quickSetCopy: 'Generate & Copy (current site)',
|
||||||
|
redirectUrlSetAndCopied: 'Redirect URL generated and copied to clipboard'
|
||||||
|
},
|
||||||
defaults: {
|
defaults: {
|
||||||
title: 'Default User Settings',
|
title: 'Default User Settings',
|
||||||
description: 'Default values for new users',
|
description: 'Default values for new users',
|
||||||
|
|||||||
@@ -227,6 +227,15 @@ export default {
|
|||||||
sendingCode: '发送中...',
|
sendingCode: '发送中...',
|
||||||
clickToResend: '点击重新发送验证码',
|
clickToResend: '点击重新发送验证码',
|
||||||
resendCode: '重新发送验证码',
|
resendCode: '重新发送验证码',
|
||||||
|
linuxdo: {
|
||||||
|
signIn: '使用 Linux.do 登录',
|
||||||
|
orContinue: '或使用邮箱密码继续',
|
||||||
|
callbackTitle: '正在完成登录',
|
||||||
|
callbackProcessing: '正在验证登录信息,请稍候...',
|
||||||
|
callbackHint: '如果页面未自动跳转,请返回登录页重试。',
|
||||||
|
callbackMissingToken: '登录信息缺失,请返回重试。',
|
||||||
|
backToLogin: '返回登录'
|
||||||
|
},
|
||||||
oauth: {
|
oauth: {
|
||||||
code: '授权码',
|
code: '授权码',
|
||||||
state: '状态',
|
state: '状态',
|
||||||
@@ -358,6 +367,14 @@ export default {
|
|||||||
customKeyTooShort: '自定义密钥至少需要16个字符',
|
customKeyTooShort: '自定义密钥至少需要16个字符',
|
||||||
customKeyInvalidChars: '自定义密钥只能包含字母、数字、下划线和连字符',
|
customKeyInvalidChars: '自定义密钥只能包含字母、数字、下划线和连字符',
|
||||||
customKeyRequired: '请输入自定义密钥',
|
customKeyRequired: '请输入自定义密钥',
|
||||||
|
ipRestriction: 'IP 限制',
|
||||||
|
ipWhitelist: 'IP 白名单',
|
||||||
|
ipWhitelistPlaceholder: '192.168.1.100\n10.0.0.0/8',
|
||||||
|
ipWhitelistHint: '每行一个 IP 或 CIDR,设置后仅允许这些 IP 使用此密钥',
|
||||||
|
ipBlacklist: 'IP 黑名单',
|
||||||
|
ipBlacklistPlaceholder: '1.2.3.4\n5.6.0.0/16',
|
||||||
|
ipBlacklistHint: '每行一个 IP 或 CIDR,这些 IP 将被禁止使用此密钥',
|
||||||
|
ipRestrictionEnabled: '已配置 IP 限制',
|
||||||
ccSwitchNotInstalled: 'CC-Switch 未安装或协议处理程序未注册。请先安装 CC-Switch 或手动复制 API 密钥。',
|
ccSwitchNotInstalled: 'CC-Switch 未安装或协议处理程序未注册。请先安装 CC-Switch 或手动复制 API 密钥。',
|
||||||
ccsClientSelect: {
|
ccsClientSelect: {
|
||||||
title: '选择客户端',
|
title: '选择客户端',
|
||||||
@@ -418,9 +435,6 @@ export default {
|
|||||||
exportFailed: '使用数据导出失败',
|
exportFailed: '使用数据导出失败',
|
||||||
exportExcelSuccess: '使用数据导出成功(Excel格式)',
|
exportExcelSuccess: '使用数据导出成功(Excel格式)',
|
||||||
exportExcelFailed: '使用数据导出失败',
|
exportExcelFailed: '使用数据导出失败',
|
||||||
billingType: '消费类型',
|
|
||||||
balance: '余额',
|
|
||||||
subscription: '订阅',
|
|
||||||
imageUnit: '张',
|
imageUnit: '张',
|
||||||
userAgent: 'User-Agent'
|
userAgent: 'User-Agent'
|
||||||
},
|
},
|
||||||
@@ -1212,12 +1226,16 @@ export default {
|
|||||||
accountCreatedSuccess: '账号添加成功',
|
accountCreatedSuccess: '账号添加成功',
|
||||||
accountUpdatedSuccess: '账号更新成功',
|
accountUpdatedSuccess: '账号更新成功',
|
||||||
accountDeletedSuccess: '账号删除成功',
|
accountDeletedSuccess: '账号删除成功',
|
||||||
|
bulkSchedulableEnabled: '成功启用 {count} 个账号的调度',
|
||||||
|
bulkSchedulableDisabled: '成功停止 {count} 个账号的调度',
|
||||||
bulkActions: {
|
bulkActions: {
|
||||||
selected: '已选择 {count} 个账号',
|
selected: '已选择 {count} 个账号',
|
||||||
selectCurrentPage: '本页全选',
|
selectCurrentPage: '本页全选',
|
||||||
clear: '清除选择',
|
clear: '清除选择',
|
||||||
edit: '批量编辑账号',
|
edit: '批量编辑账号',
|
||||||
delete: '批量删除'
|
delete: '批量删除',
|
||||||
|
enableScheduling: '批量启用调度',
|
||||||
|
disableScheduling: '批量停止调度'
|
||||||
},
|
},
|
||||||
bulkEdit: {
|
bulkEdit: {
|
||||||
title: '批量编辑账号',
|
title: '批量编辑账号',
|
||||||
@@ -1601,6 +1619,7 @@ export default {
|
|||||||
startTest: '开始测试',
|
startTest: '开始测试',
|
||||||
retry: '重试',
|
retry: '重试',
|
||||||
copyOutput: '复制输出',
|
copyOutput: '复制输出',
|
||||||
|
outputCopied: '输出已复制',
|
||||||
startingTestForAccount: '开始测试账号:{name}',
|
startingTestForAccount: '开始测试账号:{name}',
|
||||||
testAccountTypeLabel: '账号类型:{type}',
|
testAccountTypeLabel: '账号类型:{type}',
|
||||||
selectTestModel: '选择测试模型',
|
selectTestModel: '选择测试模型',
|
||||||
@@ -1866,7 +1885,6 @@ export default {
|
|||||||
allAccounts: '全部账户',
|
allAccounts: '全部账户',
|
||||||
allGroups: '全部分组',
|
allGroups: '全部分组',
|
||||||
allTypes: '全部类型',
|
allTypes: '全部类型',
|
||||||
allBillingTypes: '全部计费',
|
|
||||||
inputCost: '输入成本',
|
inputCost: '输入成本',
|
||||||
outputCost: '输出成本',
|
outputCost: '输出成本',
|
||||||
cacheCreationCost: '缓存创建成本',
|
cacheCreationCost: '缓存创建成本',
|
||||||
@@ -1875,7 +1893,8 @@ export default {
|
|||||||
outputTokens: '输出 Token',
|
outputTokens: '输出 Token',
|
||||||
cacheCreationTokens: '缓存创建 Token',
|
cacheCreationTokens: '缓存创建 Token',
|
||||||
cacheReadTokens: '缓存读取 Token',
|
cacheReadTokens: '缓存读取 Token',
|
||||||
failedToLoad: '加载使用记录失败'
|
failedToLoad: '加载使用记录失败',
|
||||||
|
ipAddress: 'IP'
|
||||||
},
|
},
|
||||||
|
|
||||||
// Settings
|
// Settings
|
||||||
@@ -1901,6 +1920,25 @@ export default {
|
|||||||
cloudflareDashboard: 'Cloudflare Dashboard',
|
cloudflareDashboard: 'Cloudflare Dashboard',
|
||||||
secretKeyHint: '服务端验证密钥(请保密)',
|
secretKeyHint: '服务端验证密钥(请保密)',
|
||||||
secretKeyConfiguredHint: '密钥已配置,留空以保留当前值。' },
|
secretKeyConfiguredHint: '密钥已配置,留空以保留当前值。' },
|
||||||
|
linuxdo: {
|
||||||
|
title: 'LinuxDo Connect 登录',
|
||||||
|
description: '配置 LinuxDo Connect OAuth,用于 Sub2API 用户登录',
|
||||||
|
enable: '启用 LinuxDo 登录',
|
||||||
|
enableHint: '在登录/注册页面显示 LinuxDo 登录入口',
|
||||||
|
clientId: 'Client ID',
|
||||||
|
clientIdPlaceholder: '例如:hprJ5pC3...',
|
||||||
|
clientIdHint: '从 Connect.Linux.Do 后台获取',
|
||||||
|
clientSecret: 'Client Secret',
|
||||||
|
clientSecretPlaceholder: '********',
|
||||||
|
clientSecretHint: '用于后端交换 token(请保密)',
|
||||||
|
clientSecretConfiguredPlaceholder: '********',
|
||||||
|
clientSecretConfiguredHint: '密钥已配置,留空以保留当前值。',
|
||||||
|
redirectUrl: '回调地址(Redirect URL)',
|
||||||
|
redirectUrlPlaceholder: 'https://your-domain.com/api/v1/auth/oauth/linuxdo/callback',
|
||||||
|
redirectUrlHint: '需与 Connect.Linux.Do 中配置的回调地址一致(必须是 http(s) 完整 URL)',
|
||||||
|
quickSetCopy: '使用当前站点生成并复制',
|
||||||
|
redirectUrlSetAndCopied: '已使用当前站点生成回调地址并复制到剪贴板'
|
||||||
|
},
|
||||||
defaults: {
|
defaults: {
|
||||||
title: '用户默认设置',
|
title: '用户默认设置',
|
||||||
description: '新用户的默认值',
|
description: '新用户的默认值',
|
||||||
|
|||||||
@@ -67,6 +67,15 @@ const routes: RouteRecordRaw[] = [
|
|||||||
title: 'OAuth Callback'
|
title: 'OAuth Callback'
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
path: '/auth/linuxdo/callback',
|
||||||
|
name: 'LinuxDoOAuthCallback',
|
||||||
|
component: () => import('@/views/auth/LinuxDoCallbackView.vue'),
|
||||||
|
meta: {
|
||||||
|
requiresAuth: false,
|
||||||
|
title: 'LinuxDo OAuth Callback'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
// ==================== User Routes ====================
|
// ==================== User Routes ====================
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ export const useAppStore = defineStore('app', () => {
|
|||||||
const contactInfo = ref<string>('')
|
const contactInfo = ref<string>('')
|
||||||
const apiBaseUrl = ref<string>('')
|
const apiBaseUrl = ref<string>('')
|
||||||
const docUrl = ref<string>('')
|
const docUrl = ref<string>('')
|
||||||
|
const cachedPublicSettings = ref<PublicSettings | null>(null)
|
||||||
|
|
||||||
// Version cache state
|
// Version cache state
|
||||||
const versionLoaded = ref<boolean>(false)
|
const versionLoaded = ref<boolean>(false)
|
||||||
@@ -285,6 +286,9 @@ export const useAppStore = defineStore('app', () => {
|
|||||||
async function fetchPublicSettings(force = false): Promise<PublicSettings | null> {
|
async function fetchPublicSettings(force = false): Promise<PublicSettings | null> {
|
||||||
// Return cached data if available and not forcing refresh
|
// Return cached data if available and not forcing refresh
|
||||||
if (publicSettingsLoaded.value && !force) {
|
if (publicSettingsLoaded.value && !force) {
|
||||||
|
if (cachedPublicSettings.value) {
|
||||||
|
return { ...cachedPublicSettings.value }
|
||||||
|
}
|
||||||
return {
|
return {
|
||||||
registration_enabled: false,
|
registration_enabled: false,
|
||||||
email_verify_enabled: false,
|
email_verify_enabled: false,
|
||||||
@@ -296,6 +300,7 @@ export const useAppStore = defineStore('app', () => {
|
|||||||
api_base_url: apiBaseUrl.value,
|
api_base_url: apiBaseUrl.value,
|
||||||
contact_info: contactInfo.value,
|
contact_info: contactInfo.value,
|
||||||
doc_url: docUrl.value,
|
doc_url: docUrl.value,
|
||||||
|
linuxdo_oauth_enabled: false,
|
||||||
version: siteVersion.value
|
version: siteVersion.value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -308,6 +313,7 @@ export const useAppStore = defineStore('app', () => {
|
|||||||
publicSettingsLoading.value = true
|
publicSettingsLoading.value = true
|
||||||
try {
|
try {
|
||||||
const data = await fetchPublicSettingsAPI()
|
const data = await fetchPublicSettingsAPI()
|
||||||
|
cachedPublicSettings.value = data
|
||||||
siteName.value = data.site_name || 'Sub2API'
|
siteName.value = data.site_name || 'Sub2API'
|
||||||
siteLogo.value = data.site_logo || ''
|
siteLogo.value = data.site_logo || ''
|
||||||
siteVersion.value = data.version || ''
|
siteVersion.value = data.version || ''
|
||||||
@@ -329,6 +335,7 @@ export const useAppStore = defineStore('app', () => {
|
|||||||
*/
|
*/
|
||||||
function clearPublicSettingsCache(): void {
|
function clearPublicSettingsCache(): void {
|
||||||
publicSettingsLoaded.value = false
|
publicSettingsLoaded.value = false
|
||||||
|
cachedPublicSettings.value = null
|
||||||
}
|
}
|
||||||
|
|
||||||
// ==================== Return Store API ====================
|
// ==================== Return Store API ====================
|
||||||
|
|||||||
@@ -159,6 +159,27 @@ export const useAuthStore = defineStore('auth', () => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 直接设置 token(用于 OAuth/SSO 回调),并加载当前用户信息。
|
||||||
|
* @param newToken - 后端签发的 JWT access token
|
||||||
|
*/
|
||||||
|
async function setToken(newToken: string): Promise<User> {
|
||||||
|
// Clear any previous state first (avoid mixing sessions)
|
||||||
|
clearAuth()
|
||||||
|
|
||||||
|
token.value = newToken
|
||||||
|
localStorage.setItem(AUTH_TOKEN_KEY, newToken)
|
||||||
|
|
||||||
|
try {
|
||||||
|
const userData = await refreshUser()
|
||||||
|
startAutoRefresh()
|
||||||
|
return userData
|
||||||
|
} catch (error) {
|
||||||
|
clearAuth()
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* User logout
|
* User logout
|
||||||
* Clears all authentication state and persisted data
|
* Clears all authentication state and persisted data
|
||||||
@@ -233,6 +254,7 @@ export const useAuthStore = defineStore('auth', () => {
|
|||||||
// Actions
|
// Actions
|
||||||
login,
|
login,
|
||||||
register,
|
register,
|
||||||
|
setToken,
|
||||||
logout,
|
logout,
|
||||||
checkAuth,
|
checkAuth,
|
||||||
refreshUser
|
refreshUser
|
||||||
|
|||||||
@@ -73,6 +73,7 @@ export interface PublicSettings {
|
|||||||
api_base_url: string
|
api_base_url: string
|
||||||
contact_info: string
|
contact_info: string
|
||||||
doc_url: string
|
doc_url: string
|
||||||
|
linuxdo_oauth_enabled: boolean
|
||||||
version: string
|
version: string
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -278,6 +279,8 @@ export interface ApiKey {
|
|||||||
name: string
|
name: string
|
||||||
group_id: number | null
|
group_id: number | null
|
||||||
status: 'active' | 'inactive'
|
status: 'active' | 'inactive'
|
||||||
|
ip_whitelist: string[]
|
||||||
|
ip_blacklist: string[]
|
||||||
created_at: string
|
created_at: string
|
||||||
updated_at: string
|
updated_at: string
|
||||||
group?: Group
|
group?: Group
|
||||||
@@ -287,12 +290,16 @@ export interface CreateApiKeyRequest {
|
|||||||
name: string
|
name: string
|
||||||
group_id?: number | null
|
group_id?: number | null
|
||||||
custom_key?: string // Optional custom API Key
|
custom_key?: string // Optional custom API Key
|
||||||
|
ip_whitelist?: string[]
|
||||||
|
ip_blacklist?: string[]
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface UpdateApiKeyRequest {
|
export interface UpdateApiKeyRequest {
|
||||||
name?: string
|
name?: string
|
||||||
group_id?: number | null
|
group_id?: number | null
|
||||||
status?: 'active' | 'inactive'
|
status?: 'active' | 'inactive'
|
||||||
|
ip_whitelist?: string[]
|
||||||
|
ip_blacklist?: string[]
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface CreateGroupRequest {
|
export interface CreateGroupRequest {
|
||||||
@@ -559,9 +566,6 @@ export interface UpdateProxyRequest {
|
|||||||
|
|
||||||
export type RedeemCodeType = 'balance' | 'concurrency' | 'subscription'
|
export type RedeemCodeType = 'balance' | 'concurrency' | 'subscription'
|
||||||
|
|
||||||
// 消费类型: 0=钱包余额, 1=订阅套餐
|
|
||||||
export type BillingType = 0 | 1
|
|
||||||
|
|
||||||
export interface UsageLog {
|
export interface UsageLog {
|
||||||
id: number
|
id: number
|
||||||
user_id: number
|
user_id: number
|
||||||
@@ -588,7 +592,6 @@ export interface UsageLog {
|
|||||||
actual_cost: number
|
actual_cost: number
|
||||||
rate_multiplier: number
|
rate_multiplier: number
|
||||||
|
|
||||||
billing_type: BillingType
|
|
||||||
stream: boolean
|
stream: boolean
|
||||||
duration_ms: number
|
duration_ms: number
|
||||||
first_token_ms: number | null
|
first_token_ms: number | null
|
||||||
@@ -600,6 +603,9 @@ export interface UsageLog {
|
|||||||
// User-Agent
|
// User-Agent
|
||||||
user_agent: string | null
|
user_agent: string | null
|
||||||
|
|
||||||
|
// IP 地址(仅管理员可见)
|
||||||
|
ip_address: string | null
|
||||||
|
|
||||||
created_at: string
|
created_at: string
|
||||||
|
|
||||||
user?: User
|
user?: User
|
||||||
@@ -829,7 +835,6 @@ export interface UsageQueryParams {
|
|||||||
group_id?: number
|
group_id?: number
|
||||||
model?: string
|
model?: string
|
||||||
stream?: boolean
|
stream?: boolean
|
||||||
billing_type?: number
|
|
||||||
start_date?: string
|
start_date?: string
|
||||||
end_date?: string
|
end_date?: string
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,7 +7,7 @@
|
|||||||
v-model:searchQuery="params.search"
|
v-model:searchQuery="params.search"
|
||||||
:filters="params"
|
:filters="params"
|
||||||
@update:filters="(newFilters) => Object.assign(params, newFilters)"
|
@update:filters="(newFilters) => Object.assign(params, newFilters)"
|
||||||
@change="reload"
|
@change="debouncedReload"
|
||||||
@update:searchQuery="debouncedReload"
|
@update:searchQuery="debouncedReload"
|
||||||
/>
|
/>
|
||||||
<AccountTableActions
|
<AccountTableActions
|
||||||
@@ -19,7 +19,7 @@
|
|||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
<template #table>
|
<template #table>
|
||||||
<AccountBulkActionsBar :selected-ids="selIds" @delete="handleBulkDelete" @edit="showBulkEdit = true" @clear="selIds = []" @select-page="selectPage" />
|
<AccountBulkActionsBar :selected-ids="selIds" @delete="handleBulkDelete" @edit="showBulkEdit = true" @clear="selIds = []" @select-page="selectPage" @toggle-schedulable="handleBulkToggleSchedulable" />
|
||||||
<DataTable :columns="cols" :data="accounts" :loading="loading">
|
<DataTable :columns="cols" :data="accounts" :loading="loading">
|
||||||
<template #cell-select="{ row }">
|
<template #cell-select="{ row }">
|
||||||
<input type="checkbox" :checked="selIds.includes(row.id)" @change="toggleSel(row.id)" class="rounded border-gray-300 text-primary-600 focus:ring-primary-500" />
|
<input type="checkbox" :checked="selIds.includes(row.id)" @change="toggleSel(row.id)" class="rounded border-gray-300 text-primary-600 focus:ring-primary-500" />
|
||||||
@@ -107,7 +107,7 @@
|
|||||||
</template>
|
</template>
|
||||||
</DataTable>
|
</DataTable>
|
||||||
</template>
|
</template>
|
||||||
<template #pagination><Pagination v-if="pagination.total > 0" :page="pagination.page" :total="pagination.total" :page-size="pagination.page_size" @update:page="handlePageChange" /></template>
|
<template #pagination><Pagination v-if="pagination.total > 0" :page="pagination.page" :total="pagination.total" :page-size="pagination.page_size" @update:page="handlePageChange" @update:pageSize="handlePageSizeChange" /></template>
|
||||||
</TablePageLayout>
|
</TablePageLayout>
|
||||||
<CreateAccountModal :show="showCreate" :proxies="proxies" :groups="groups" @close="showCreate = false" @created="reload" />
|
<CreateAccountModal :show="showCreate" :proxies="proxies" :groups="groups" @close="showCreate = false" @created="reload" />
|
||||||
<EditAccountModal :show="showEdit" :account="edAcc" :proxies="proxies" :groups="groups" @close="showEdit = false" @updated="load" />
|
<EditAccountModal :show="showEdit" :account="edAcc" :proxies="proxies" :groups="groups" @close="showEdit = false" @updated="load" />
|
||||||
@@ -175,7 +175,7 @@ const statsAcc = ref<Account | null>(null)
|
|||||||
const togglingSchedulable = ref<number | null>(null)
|
const togglingSchedulable = ref<number | null>(null)
|
||||||
const menu = reactive<{show:boolean, acc:Account|null, pos:{top:number, left:number}|null}>({ show: false, acc: null, pos: null })
|
const menu = reactive<{show:boolean, acc:Account|null, pos:{top:number, left:number}|null}>({ show: false, acc: null, pos: null })
|
||||||
|
|
||||||
const { items: accounts, loading, params, pagination, load, reload, debouncedReload, handlePageChange } = useTableLoader<Account, any>({
|
const { items: accounts, loading, params, pagination, load, reload, debouncedReload, handlePageChange, handlePageSizeChange } = useTableLoader<Account, any>({
|
||||||
fetchFn: adminAPI.accounts.list,
|
fetchFn: adminAPI.accounts.list,
|
||||||
initialParams: { platform: '', type: '', status: '', search: '' }
|
initialParams: { platform: '', type: '', status: '', search: '' }
|
||||||
})
|
})
|
||||||
@@ -209,6 +209,21 @@ const openMenu = (a: Account, e: MouseEvent) => { menu.acc = a; menu.pos = { top
|
|||||||
const toggleSel = (id: number) => { const i = selIds.value.indexOf(id); if(i === -1) selIds.value.push(id); else selIds.value.splice(i, 1) }
|
const toggleSel = (id: number) => { const i = selIds.value.indexOf(id); if(i === -1) selIds.value.push(id); else selIds.value.splice(i, 1) }
|
||||||
const selectPage = () => { selIds.value = [...new Set([...selIds.value, ...accounts.value.map(a => a.id)])] }
|
const selectPage = () => { selIds.value = [...new Set([...selIds.value, ...accounts.value.map(a => a.id)])] }
|
||||||
const handleBulkDelete = async () => { if(!confirm(t('common.confirm'))) return; try { await Promise.all(selIds.value.map(id => adminAPI.accounts.delete(id))); selIds.value = []; reload() } catch (error) { console.error('Failed to bulk delete accounts:', error) } }
|
const handleBulkDelete = async () => { if(!confirm(t('common.confirm'))) return; try { await Promise.all(selIds.value.map(id => adminAPI.accounts.delete(id))); selIds.value = []; reload() } catch (error) { console.error('Failed to bulk delete accounts:', error) } }
|
||||||
|
const handleBulkToggleSchedulable = async (schedulable: boolean) => {
|
||||||
|
const count = selIds.value.length
|
||||||
|
try {
|
||||||
|
const result = await adminAPI.accounts.bulkUpdate(selIds.value, { schedulable });
|
||||||
|
const message = schedulable
|
||||||
|
? t('admin.accounts.bulkSchedulableEnabled', { count: result.success || count })
|
||||||
|
: t('admin.accounts.bulkSchedulableDisabled', { count: result.success || count });
|
||||||
|
appStore.showSuccess(message);
|
||||||
|
selIds.value = [];
|
||||||
|
reload()
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Failed to bulk toggle schedulable:', error);
|
||||||
|
appStore.showError(t('common.error'))
|
||||||
|
}
|
||||||
|
}
|
||||||
const handleBulkUpdated = () => { showBulkEdit.value = false; selIds.value = []; reload() }
|
const handleBulkUpdated = () => { showBulkEdit.value = false; selIds.value = []; reload() }
|
||||||
const closeTestModal = () => { showTest.value = false; testingAcc.value = null }
|
const closeTestModal = () => { showTest.value = false; testingAcc.value = null }
|
||||||
const closeStatsModal = () => { showStats.value = false; statsAcc.value = null }
|
const closeStatsModal = () => { showStats.value = false; statsAcc.value = null }
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
type="text"
|
type="text"
|
||||||
:placeholder="t('admin.groups.searchGroups')"
|
:placeholder="t('admin.groups.searchGroups')"
|
||||||
class="input pl-10"
|
class="input pl-10"
|
||||||
|
@input="handleSearch"
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
<Select
|
<Select
|
||||||
@@ -64,7 +65,7 @@
|
|||||||
</template>
|
</template>
|
||||||
|
|
||||||
<template #table>
|
<template #table>
|
||||||
<DataTable :columns="columns" :data="displayedGroups" :loading="loading">
|
<DataTable :columns="columns" :data="groups" :loading="loading">
|
||||||
<template #cell-name="{ value }">
|
<template #cell-name="{ value }">
|
||||||
<span class="font-medium text-gray-900 dark:text-white">{{ value }}</span>
|
<span class="font-medium text-gray-900 dark:text-white">{{ value }}</span>
|
||||||
</template>
|
</template>
|
||||||
@@ -932,16 +933,6 @@ const pagination = reactive({
|
|||||||
|
|
||||||
let abortController: AbortController | null = null
|
let abortController: AbortController | null = null
|
||||||
|
|
||||||
const displayedGroups = computed(() => {
|
|
||||||
const q = searchQuery.value.trim().toLowerCase()
|
|
||||||
if (!q) return groups.value
|
|
||||||
return groups.value.filter((group) => {
|
|
||||||
const name = group.name?.toLowerCase?.() ?? ''
|
|
||||||
const description = group.description?.toLowerCase?.() ?? ''
|
|
||||||
return name.includes(q) || description.includes(q)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
const showCreateModal = ref(false)
|
const showCreateModal = ref(false)
|
||||||
const showEditModal = ref(false)
|
const showEditModal = ref(false)
|
||||||
const showDeleteDialog = ref(false)
|
const showDeleteDialog = ref(false)
|
||||||
@@ -1011,7 +1002,8 @@ const loadGroups = async () => {
|
|||||||
const response = await adminAPI.groups.list(pagination.page, pagination.page_size, {
|
const response = await adminAPI.groups.list(pagination.page, pagination.page_size, {
|
||||||
platform: (filters.platform as GroupPlatform) || undefined,
|
platform: (filters.platform as GroupPlatform) || undefined,
|
||||||
status: filters.status as any,
|
status: filters.status as any,
|
||||||
is_exclusive: filters.is_exclusive ? filters.is_exclusive === 'true' : undefined
|
is_exclusive: filters.is_exclusive ? filters.is_exclusive === 'true' : undefined,
|
||||||
|
search: searchQuery.value.trim() || undefined
|
||||||
}, { signal })
|
}, { signal })
|
||||||
if (signal.aborted) return
|
if (signal.aborted) return
|
||||||
groups.value = response.items
|
groups.value = response.items
|
||||||
@@ -1030,6 +1022,15 @@ const loadGroups = async () => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let searchTimeout: ReturnType<typeof setTimeout>
|
||||||
|
const handleSearch = () => {
|
||||||
|
clearTimeout(searchTimeout)
|
||||||
|
searchTimeout = setTimeout(() => {
|
||||||
|
pagination.page = 1
|
||||||
|
loadGroups()
|
||||||
|
}, 300)
|
||||||
|
}
|
||||||
|
|
||||||
const handlePageChange = (page: number) => {
|
const handlePageChange = (page: number) => {
|
||||||
pagination.page = page
|
pagination.page = page
|
||||||
loadGroups()
|
loadGroups()
|
||||||
|
|||||||
@@ -519,7 +519,7 @@
|
|||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, reactive, computed, onMounted } from 'vue'
|
import { ref, reactive, computed, onMounted, onUnmounted } from 'vue'
|
||||||
import { useI18n } from 'vue-i18n'
|
import { useI18n } from 'vue-i18n'
|
||||||
import { useAppStore } from '@/stores/app'
|
import { useAppStore } from '@/stores/app'
|
||||||
import { adminAPI } from '@/api/admin'
|
import { adminAPI } from '@/api/admin'
|
||||||
@@ -942,4 +942,9 @@ const confirmDelete = async () => {
|
|||||||
onMounted(() => {
|
onMounted(() => {
|
||||||
loadProxies()
|
loadProxies()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
onUnmounted(() => {
|
||||||
|
clearTimeout(searchTimeout)
|
||||||
|
abortController?.abort()
|
||||||
|
})
|
||||||
</script>
|
</script>
|
||||||
|
|||||||
@@ -364,7 +364,7 @@
|
|||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, reactive, computed, onMounted } from 'vue'
|
import { ref, reactive, computed, onMounted, onUnmounted } from 'vue'
|
||||||
import { useI18n } from 'vue-i18n'
|
import { useI18n } from 'vue-i18n'
|
||||||
import { useAppStore } from '@/stores/app'
|
import { useAppStore } from '@/stores/app'
|
||||||
import { useClipboard } from '@/composables/useClipboard'
|
import { useClipboard } from '@/composables/useClipboard'
|
||||||
@@ -693,4 +693,9 @@ onMounted(() => {
|
|||||||
loadCodes()
|
loadCodes()
|
||||||
loadSubscriptionGroups()
|
loadSubscriptionGroups()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
onUnmounted(() => {
|
||||||
|
clearTimeout(searchTimeout)
|
||||||
|
abortController?.abort()
|
||||||
|
})
|
||||||
</script>
|
</script>
|
||||||
|
|||||||
@@ -261,6 +261,106 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- LinuxDo Connect OAuth 登录 -->
|
||||||
|
<div class="card">
|
||||||
|
<div class="border-b border-gray-100 px-6 py-4 dark:border-dark-700">
|
||||||
|
<h2 class="text-lg font-semibold text-gray-900 dark:text-white">
|
||||||
|
{{ t('admin.settings.linuxdo.title') }}
|
||||||
|
</h2>
|
||||||
|
<p class="mt-1 text-sm text-gray-500 dark:text-gray-400">
|
||||||
|
{{ t('admin.settings.linuxdo.description') }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<div class="space-y-5 p-6">
|
||||||
|
<div class="flex items-center justify-between">
|
||||||
|
<div>
|
||||||
|
<label class="font-medium text-gray-900 dark:text-white">{{
|
||||||
|
t('admin.settings.linuxdo.enable')
|
||||||
|
}}</label>
|
||||||
|
<p class="text-sm text-gray-500 dark:text-gray-400">
|
||||||
|
{{ t('admin.settings.linuxdo.enableHint') }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<Toggle v-model="form.linuxdo_connect_enabled" />
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div
|
||||||
|
v-if="form.linuxdo_connect_enabled"
|
||||||
|
class="border-t border-gray-100 pt-4 dark:border-dark-700"
|
||||||
|
>
|
||||||
|
<div class="grid grid-cols-1 gap-6">
|
||||||
|
<div>
|
||||||
|
<label class="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||||
|
{{ t('admin.settings.linuxdo.clientId') }}
|
||||||
|
</label>
|
||||||
|
<input
|
||||||
|
v-model="form.linuxdo_connect_client_id"
|
||||||
|
type="text"
|
||||||
|
class="input font-mono text-sm"
|
||||||
|
:placeholder="t('admin.settings.linuxdo.clientIdPlaceholder')"
|
||||||
|
/>
|
||||||
|
<p class="mt-1.5 text-xs text-gray-500 dark:text-gray-400">
|
||||||
|
{{ t('admin.settings.linuxdo.clientIdHint') }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div>
|
||||||
|
<label class="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||||
|
{{ t('admin.settings.linuxdo.clientSecret') }}
|
||||||
|
</label>
|
||||||
|
<input
|
||||||
|
v-model="form.linuxdo_connect_client_secret"
|
||||||
|
type="password"
|
||||||
|
class="input font-mono text-sm"
|
||||||
|
:placeholder="
|
||||||
|
form.linuxdo_connect_client_secret_configured
|
||||||
|
? t('admin.settings.linuxdo.clientSecretConfiguredPlaceholder')
|
||||||
|
: t('admin.settings.linuxdo.clientSecretPlaceholder')
|
||||||
|
"
|
||||||
|
/>
|
||||||
|
<p class="mt-1.5 text-xs text-gray-500 dark:text-gray-400">
|
||||||
|
{{
|
||||||
|
form.linuxdo_connect_client_secret_configured
|
||||||
|
? t('admin.settings.linuxdo.clientSecretConfiguredHint')
|
||||||
|
: t('admin.settings.linuxdo.clientSecretHint')
|
||||||
|
}}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div>
|
||||||
|
<label class="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||||
|
{{ t('admin.settings.linuxdo.redirectUrl') }}
|
||||||
|
</label>
|
||||||
|
<input
|
||||||
|
v-model="form.linuxdo_connect_redirect_url"
|
||||||
|
type="url"
|
||||||
|
class="input font-mono text-sm"
|
||||||
|
:placeholder="t('admin.settings.linuxdo.redirectUrlPlaceholder')"
|
||||||
|
/>
|
||||||
|
<div class="mt-2 flex flex-col gap-2 sm:flex-row sm:items-center sm:gap-3">
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
class="btn btn-secondary btn-sm w-fit"
|
||||||
|
@click="setAndCopyLinuxdoRedirectUrl"
|
||||||
|
>
|
||||||
|
{{ t('admin.settings.linuxdo.quickSetCopy') }}
|
||||||
|
</button>
|
||||||
|
<code
|
||||||
|
v-if="linuxdoRedirectUrlSuggestion"
|
||||||
|
class="select-all break-all rounded bg-gray-50 px-2 py-1 font-mono text-xs text-gray-600 dark:bg-dark-800 dark:text-gray-300"
|
||||||
|
>
|
||||||
|
{{ linuxdoRedirectUrlSuggestion }}
|
||||||
|
</code>
|
||||||
|
</div>
|
||||||
|
<p class="mt-1.5 text-xs text-gray-500 dark:text-gray-400">
|
||||||
|
{{ t('admin.settings.linuxdo.redirectUrlHint') }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
<!-- Default Settings -->
|
<!-- Default Settings -->
|
||||||
<div class="card">
|
<div class="card">
|
||||||
<div class="border-b border-gray-100 px-6 py-4 dark:border-dark-700">
|
<div class="border-b border-gray-100 px-6 py-4 dark:border-dark-700">
|
||||||
@@ -692,17 +792,19 @@
|
|||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, reactive, onMounted } from 'vue'
|
import { ref, reactive, computed, onMounted } from 'vue'
|
||||||
import { useI18n } from 'vue-i18n'
|
import { useI18n } from 'vue-i18n'
|
||||||
import { adminAPI } from '@/api'
|
import { adminAPI } from '@/api'
|
||||||
import type { SystemSettings, UpdateSettingsRequest } from '@/api/admin/settings'
|
import type { SystemSettings, UpdateSettingsRequest } from '@/api/admin/settings'
|
||||||
import AppLayout from '@/components/layout/AppLayout.vue'
|
import AppLayout from '@/components/layout/AppLayout.vue'
|
||||||
import Icon from '@/components/icons/Icon.vue'
|
import Icon from '@/components/icons/Icon.vue'
|
||||||
import Toggle from '@/components/common/Toggle.vue'
|
import Toggle from '@/components/common/Toggle.vue'
|
||||||
|
import { useClipboard } from '@/composables/useClipboard'
|
||||||
import { useAppStore } from '@/stores'
|
import { useAppStore } from '@/stores'
|
||||||
|
|
||||||
const { t } = useI18n()
|
const { t } = useI18n()
|
||||||
const appStore = useAppStore()
|
const appStore = useAppStore()
|
||||||
|
const { copyToClipboard } = useClipboard()
|
||||||
|
|
||||||
const loading = ref(true)
|
const loading = ref(true)
|
||||||
const saving = ref(false)
|
const saving = ref(false)
|
||||||
@@ -721,6 +823,7 @@ const newAdminApiKey = ref('')
|
|||||||
type SettingsForm = SystemSettings & {
|
type SettingsForm = SystemSettings & {
|
||||||
smtp_password: string
|
smtp_password: string
|
||||||
turnstile_secret_key: string
|
turnstile_secret_key: string
|
||||||
|
linuxdo_connect_client_secret: string
|
||||||
}
|
}
|
||||||
|
|
||||||
const form = reactive<SettingsForm>({
|
const form = reactive<SettingsForm>({
|
||||||
@@ -747,11 +850,32 @@ const form = reactive<SettingsForm>({
|
|||||||
turnstile_site_key: '',
|
turnstile_site_key: '',
|
||||||
turnstile_secret_key: '',
|
turnstile_secret_key: '',
|
||||||
turnstile_secret_key_configured: false,
|
turnstile_secret_key_configured: false,
|
||||||
|
// LinuxDo Connect OAuth(终端用户登录)
|
||||||
|
linuxdo_connect_enabled: false,
|
||||||
|
linuxdo_connect_client_id: '',
|
||||||
|
linuxdo_connect_client_secret: '',
|
||||||
|
linuxdo_connect_client_secret_configured: false,
|
||||||
|
linuxdo_connect_redirect_url: '',
|
||||||
// Identity patch (Claude -> Gemini)
|
// Identity patch (Claude -> Gemini)
|
||||||
enable_identity_patch: true,
|
enable_identity_patch: true,
|
||||||
identity_patch_prompt: ''
|
identity_patch_prompt: ''
|
||||||
})
|
})
|
||||||
|
|
||||||
|
const linuxdoRedirectUrlSuggestion = computed(() => {
|
||||||
|
if (typeof window === 'undefined') return ''
|
||||||
|
const origin =
|
||||||
|
window.location.origin || `${window.location.protocol}//${window.location.host}`
|
||||||
|
return `${origin}/api/v1/auth/oauth/linuxdo/callback`
|
||||||
|
})
|
||||||
|
|
||||||
|
async function setAndCopyLinuxdoRedirectUrl() {
|
||||||
|
const url = linuxdoRedirectUrlSuggestion.value
|
||||||
|
if (!url) return
|
||||||
|
|
||||||
|
form.linuxdo_connect_redirect_url = url
|
||||||
|
await copyToClipboard(url, t('admin.settings.linuxdo.redirectUrlSetAndCopied'))
|
||||||
|
}
|
||||||
|
|
||||||
function handleLogoUpload(event: Event) {
|
function handleLogoUpload(event: Event) {
|
||||||
const input = event.target as HTMLInputElement
|
const input = event.target as HTMLInputElement
|
||||||
const file = input.files?.[0]
|
const file = input.files?.[0]
|
||||||
@@ -797,6 +921,7 @@ async function loadSettings() {
|
|||||||
Object.assign(form, settings)
|
Object.assign(form, settings)
|
||||||
form.smtp_password = ''
|
form.smtp_password = ''
|
||||||
form.turnstile_secret_key = ''
|
form.turnstile_secret_key = ''
|
||||||
|
form.linuxdo_connect_client_secret = ''
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
appStore.showError(
|
appStore.showError(
|
||||||
t('admin.settings.failedToLoad') + ': ' + (error.message || t('common.unknownError'))
|
t('admin.settings.failedToLoad') + ': ' + (error.message || t('common.unknownError'))
|
||||||
@@ -829,12 +954,17 @@ async function saveSettings() {
|
|||||||
smtp_use_tls: form.smtp_use_tls,
|
smtp_use_tls: form.smtp_use_tls,
|
||||||
turnstile_enabled: form.turnstile_enabled,
|
turnstile_enabled: form.turnstile_enabled,
|
||||||
turnstile_site_key: form.turnstile_site_key,
|
turnstile_site_key: form.turnstile_site_key,
|
||||||
turnstile_secret_key: form.turnstile_secret_key || undefined
|
turnstile_secret_key: form.turnstile_secret_key || undefined,
|
||||||
|
linuxdo_connect_enabled: form.linuxdo_connect_enabled,
|
||||||
|
linuxdo_connect_client_id: form.linuxdo_connect_client_id,
|
||||||
|
linuxdo_connect_client_secret: form.linuxdo_connect_client_secret || undefined,
|
||||||
|
linuxdo_connect_redirect_url: form.linuxdo_connect_redirect_url
|
||||||
}
|
}
|
||||||
const updated = await adminAPI.settings.updateSettings(payload)
|
const updated = await adminAPI.settings.updateSettings(payload)
|
||||||
Object.assign(form, updated)
|
Object.assign(form, updated)
|
||||||
form.smtp_password = ''
|
form.smtp_password = ''
|
||||||
form.turnstile_secret_key = ''
|
form.turnstile_secret_key = ''
|
||||||
|
form.linuxdo_connect_client_secret = ''
|
||||||
// Refresh cached public settings so sidebar/header update immediately
|
// Refresh cached public settings so sidebar/header update immediately
|
||||||
await appStore.fetchPublicSettings(true)
|
await appStore.fetchPublicSettings(true)
|
||||||
appStore.showSuccess(t('admin.settings.settingsSaved'))
|
appStore.showSuccess(t('admin.settings.settingsSaved'))
|
||||||
|
|||||||
@@ -95,8 +95,8 @@ const exportToExcel = async () => {
|
|||||||
t('admin.usage.inputCost'), t('admin.usage.outputCost'),
|
t('admin.usage.inputCost'), t('admin.usage.outputCost'),
|
||||||
t('admin.usage.cacheReadCost'), t('admin.usage.cacheCreationCost'),
|
t('admin.usage.cacheReadCost'), t('admin.usage.cacheCreationCost'),
|
||||||
t('usage.rate'), t('usage.original'), t('usage.billed'),
|
t('usage.rate'), t('usage.original'), t('usage.billed'),
|
||||||
t('usage.billingType'), t('usage.firstToken'), t('usage.duration'),
|
t('usage.firstToken'), t('usage.duration'),
|
||||||
t('admin.usage.requestId'), t('usage.userAgent')
|
t('admin.usage.requestId'), t('usage.userAgent'), t('admin.usage.ipAddress')
|
||||||
]
|
]
|
||||||
const rows = all.map(log => [
|
const rows = all.map(log => [
|
||||||
log.created_at,
|
log.created_at,
|
||||||
@@ -117,11 +117,11 @@ const exportToExcel = async () => {
|
|||||||
log.rate_multiplier?.toFixed(2) || '1.00',
|
log.rate_multiplier?.toFixed(2) || '1.00',
|
||||||
log.total_cost?.toFixed(6) || '0.000000',
|
log.total_cost?.toFixed(6) || '0.000000',
|
||||||
log.actual_cost?.toFixed(6) || '0.000000',
|
log.actual_cost?.toFixed(6) || '0.000000',
|
||||||
log.billing_type === 1 ? t('usage.subscription') : t('usage.balance'),
|
|
||||||
log.first_token_ms ?? '',
|
log.first_token_ms ?? '',
|
||||||
log.duration_ms,
|
log.duration_ms,
|
||||||
log.request_id || '',
|
log.request_id || '',
|
||||||
log.user_agent || ''
|
log.user_agent || '',
|
||||||
|
log.ip_address || ''
|
||||||
])
|
])
|
||||||
const ws = XLSX.utils.aoa_to_sheet([headers, ...rows])
|
const ws = XLSX.utils.aoa_to_sheet([headers, ...rows])
|
||||||
const wb = XLSX.utils.book_new()
|
const wb = XLSX.utils.book_new()
|
||||||
|
|||||||
@@ -893,12 +893,13 @@ const loadUsers = async () => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error: any) {
|
||||||
const errorInfo = error as { name?: string; code?: string }
|
const errorInfo = error as { name?: string; code?: string }
|
||||||
if (errorInfo?.name === 'AbortError' || errorInfo?.name === 'CanceledError' || errorInfo?.code === 'ERR_CANCELED') {
|
if (errorInfo?.name === 'AbortError' || errorInfo?.name === 'CanceledError' || errorInfo?.code === 'ERR_CANCELED') {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
appStore.showError(t('admin.users.failedToLoad'))
|
const message = error.response?.data?.detail || error.message || t('admin.users.failedToLoad')
|
||||||
|
appStore.showError(message)
|
||||||
console.error('Error loading users:', error)
|
console.error('Error loading users:', error)
|
||||||
} finally {
|
} finally {
|
||||||
if (abortController === currentAbortController) {
|
if (abortController === currentAbortController) {
|
||||||
@@ -917,7 +918,9 @@ const handleSearch = () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const handlePageChange = (page: number) => {
|
const handlePageChange = (page: number) => {
|
||||||
pagination.page = page
|
// 确保页码在有效范围内
|
||||||
|
const validPage = Math.max(1, Math.min(page, pagination.pages || 1))
|
||||||
|
pagination.page = validPage
|
||||||
loadUsers()
|
loadUsers()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -943,6 +946,7 @@ const toggleBuiltInFilter = (key: string) => {
|
|||||||
visibleFilters.add(key)
|
visibleFilters.add(key)
|
||||||
}
|
}
|
||||||
saveFiltersToStorage()
|
saveFiltersToStorage()
|
||||||
|
pagination.page = 1
|
||||||
loadUsers()
|
loadUsers()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -957,6 +961,7 @@ const toggleAttributeFilter = (attr: UserAttributeDefinition) => {
|
|||||||
activeAttributeFilters[attr.id] = ''
|
activeAttributeFilters[attr.id] = ''
|
||||||
}
|
}
|
||||||
saveFiltersToStorage()
|
saveFiltersToStorage()
|
||||||
|
pagination.page = 1
|
||||||
loadUsers()
|
loadUsers()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1059,5 +1064,7 @@ onMounted(async () => {
|
|||||||
|
|
||||||
onUnmounted(() => {
|
onUnmounted(() => {
|
||||||
document.removeEventListener('click', handleClickOutside)
|
document.removeEventListener('click', handleClickOutside)
|
||||||
|
clearTimeout(searchTimeout)
|
||||||
|
abortController?.abort()
|
||||||
})
|
})
|
||||||
</script>
|
</script>
|
||||||
|
|||||||
119
frontend/src/views/auth/LinuxDoCallbackView.vue
Normal file
119
frontend/src/views/auth/LinuxDoCallbackView.vue
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
<template>
|
||||||
|
<AuthLayout>
|
||||||
|
<div class="space-y-6">
|
||||||
|
<div class="text-center">
|
||||||
|
<h2 class="text-2xl font-bold text-gray-900 dark:text-white">
|
||||||
|
{{ t('auth.linuxdo.callbackTitle') }}
|
||||||
|
</h2>
|
||||||
|
<p class="mt-2 text-sm text-gray-500 dark:text-dark-400">
|
||||||
|
{{ isProcessing ? t('auth.linuxdo.callbackProcessing') : t('auth.linuxdo.callbackHint') }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<transition name="fade">
|
||||||
|
<div
|
||||||
|
v-if="errorMessage"
|
||||||
|
class="rounded-xl border border-red-200 bg-red-50 p-4 dark:border-red-800/50 dark:bg-red-900/20"
|
||||||
|
>
|
||||||
|
<div class="flex items-start gap-3">
|
||||||
|
<div class="flex-shrink-0">
|
||||||
|
<Icon name="exclamationCircle" size="md" class="text-red-500" />
|
||||||
|
</div>
|
||||||
|
<div class="space-y-2">
|
||||||
|
<p class="text-sm text-red-700 dark:text-red-400">
|
||||||
|
{{ errorMessage }}
|
||||||
|
</p>
|
||||||
|
<router-link to="/login" class="btn btn-primary">
|
||||||
|
{{ t('auth.linuxdo.backToLogin') }}
|
||||||
|
</router-link>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</transition>
|
||||||
|
</div>
|
||||||
|
</AuthLayout>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<script setup lang="ts">
|
||||||
|
import { onMounted, ref } from 'vue'
|
||||||
|
import { useRoute, useRouter } from 'vue-router'
|
||||||
|
import { useI18n } from 'vue-i18n'
|
||||||
|
import { AuthLayout } from '@/components/layout'
|
||||||
|
import Icon from '@/components/icons/Icon.vue'
|
||||||
|
import { useAuthStore, useAppStore } from '@/stores'
|
||||||
|
|
||||||
|
const route = useRoute()
|
||||||
|
const router = useRouter()
|
||||||
|
const { t } = useI18n()
|
||||||
|
|
||||||
|
const authStore = useAuthStore()
|
||||||
|
const appStore = useAppStore()
|
||||||
|
|
||||||
|
const isProcessing = ref(true)
|
||||||
|
const errorMessage = ref('')
|
||||||
|
|
||||||
|
function parseFragmentParams(): URLSearchParams {
|
||||||
|
const raw = typeof window !== 'undefined' ? window.location.hash : ''
|
||||||
|
const hash = raw.startsWith('#') ? raw.slice(1) : raw
|
||||||
|
return new URLSearchParams(hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
function sanitizeRedirectPath(path: string | null | undefined): string {
|
||||||
|
if (!path) return '/dashboard'
|
||||||
|
if (!path.startsWith('/')) return '/dashboard'
|
||||||
|
if (path.startsWith('//')) return '/dashboard'
|
||||||
|
if (path.includes('://')) return '/dashboard'
|
||||||
|
if (path.includes('\n') || path.includes('\r')) return '/dashboard'
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
|
||||||
|
onMounted(async () => {
|
||||||
|
const params = parseFragmentParams()
|
||||||
|
|
||||||
|
const token = params.get('access_token') || ''
|
||||||
|
const redirect = sanitizeRedirectPath(
|
||||||
|
params.get('redirect') || (route.query.redirect as string | undefined) || '/dashboard'
|
||||||
|
)
|
||||||
|
const error = params.get('error')
|
||||||
|
const errorDesc = params.get('error_description') || params.get('error_message') || ''
|
||||||
|
|
||||||
|
if (error) {
|
||||||
|
errorMessage.value = errorDesc || error
|
||||||
|
appStore.showError(errorMessage.value)
|
||||||
|
isProcessing.value = false
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!token) {
|
||||||
|
errorMessage.value = t('auth.linuxdo.callbackMissingToken')
|
||||||
|
appStore.showError(errorMessage.value)
|
||||||
|
isProcessing.value = false
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
await authStore.setToken(token)
|
||||||
|
appStore.showSuccess(t('auth.loginSuccess'))
|
||||||
|
await router.replace(redirect)
|
||||||
|
} catch (e: unknown) {
|
||||||
|
const err = e as { message?: string; response?: { data?: { detail?: string } } }
|
||||||
|
errorMessage.value = err.response?.data?.detail || err.message || t('auth.loginFailed')
|
||||||
|
appStore.showError(errorMessage.value)
|
||||||
|
isProcessing.value = false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<style scoped>
|
||||||
|
.fade-enter-active,
|
||||||
|
.fade-leave-active {
|
||||||
|
transition: all 0.3s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.fade-enter-from,
|
||||||
|
.fade-leave-to {
|
||||||
|
opacity: 0;
|
||||||
|
transform: translateY(-8px);
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
|
||||||
@@ -11,6 +11,9 @@
|
|||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- LinuxDo Connect OAuth 登录 -->
|
||||||
|
<LinuxDoOAuthSection v-if="linuxdoOAuthEnabled" :disabled="isLoading" />
|
||||||
|
|
||||||
<!-- Login Form -->
|
<!-- Login Form -->
|
||||||
<form @submit.prevent="handleLogin" class="space-y-5">
|
<form @submit.prevent="handleLogin" class="space-y-5">
|
||||||
<!-- Email Input -->
|
<!-- Email Input -->
|
||||||
@@ -157,6 +160,7 @@ import { ref, reactive, onMounted } from 'vue'
|
|||||||
import { useRouter } from 'vue-router'
|
import { useRouter } from 'vue-router'
|
||||||
import { useI18n } from 'vue-i18n'
|
import { useI18n } from 'vue-i18n'
|
||||||
import { AuthLayout } from '@/components/layout'
|
import { AuthLayout } from '@/components/layout'
|
||||||
|
import LinuxDoOAuthSection from '@/components/auth/LinuxDoOAuthSection.vue'
|
||||||
import Icon from '@/components/icons/Icon.vue'
|
import Icon from '@/components/icons/Icon.vue'
|
||||||
import TurnstileWidget from '@/components/TurnstileWidget.vue'
|
import TurnstileWidget from '@/components/TurnstileWidget.vue'
|
||||||
import { useAuthStore, useAppStore } from '@/stores'
|
import { useAuthStore, useAppStore } from '@/stores'
|
||||||
@@ -179,6 +183,7 @@ const showPassword = ref<boolean>(false)
|
|||||||
// Public settings
|
// Public settings
|
||||||
const turnstileEnabled = ref<boolean>(false)
|
const turnstileEnabled = ref<boolean>(false)
|
||||||
const turnstileSiteKey = ref<string>('')
|
const turnstileSiteKey = ref<string>('')
|
||||||
|
const linuxdoOAuthEnabled = ref<boolean>(false)
|
||||||
|
|
||||||
// Turnstile
|
// Turnstile
|
||||||
const turnstileRef = ref<InstanceType<typeof TurnstileWidget> | null>(null)
|
const turnstileRef = ref<InstanceType<typeof TurnstileWidget> | null>(null)
|
||||||
@@ -210,6 +215,7 @@ onMounted(async () => {
|
|||||||
const settings = await getPublicSettings()
|
const settings = await getPublicSettings()
|
||||||
turnstileEnabled.value = settings.turnstile_enabled
|
turnstileEnabled.value = settings.turnstile_enabled
|
||||||
turnstileSiteKey.value = settings.turnstile_site_key || ''
|
turnstileSiteKey.value = settings.turnstile_site_key || ''
|
||||||
|
linuxdoOAuthEnabled.value = settings.linuxdo_oauth_enabled
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Failed to load public settings:', error)
|
console.error('Failed to load public settings:', error)
|
||||||
}
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user