diff --git a/Linux DO Connect.md b/Linux DO Connect.md new file mode 100644 index 00000000..7ca1260f --- /dev/null +++ b/Linux DO Connect.md @@ -0,0 +1,368 @@ +# Linux DO Connect + +OAuth(Open Authorization)是一个开放的网络授权标准,目前最新版本为 OAuth 2.0。我们日常使用的第三方登录(如 Google 账号登录)就采用了该标准。OAuth 允许用户授权第三方应用访问存储在其他服务提供商(如 Google)上的信息,无需在不同平台上重复填写注册信息。用户授权后,平台可以直接访问用户的账户信息进行身份验证,而用户无需向第三方应用提供密码。 + +目前系统已实现完整的 OAuth2 授权码(code)方式鉴权,但界面等配套功能还在持续完善中。让我们一起打造一个更完善的共享方案。 + +## 基本介绍 + +这是一套标准的 OAuth2 鉴权系统,可以让开发者共享论坛的用户基本信息。 + +- 可获取字段: + +| 参数 | 说明 | +| ----------------- | ------------------------------- | +| `id` | 用户唯一标识(不可变) | +| `username` | 论坛用户名 | +| `name` | 论坛用户昵称(可变) | +| `avatar_template` | 用户头像模板URL(支持多种尺寸) | +| `active` | 账号活跃状态 | +| `trust_level` | 信任等级(0-4) | +| `silenced` | 禁言状态 | +| `external_ids` | 外部ID关联信息 | +| `api_key` | API访问密钥 | + +通过这些信息,公益网站/接口可以实现: + +1. 基于 `id` 的服务频率限制 +2. 基于 `trust_level` 的服务额度分配 +3. 基于用户信息的滥用举报机制 + +## 相关端点 + +- Authorize 端点: `https://connect.linux.do/oauth2/authorize` +- Token 端点:`https://connect.linux.do/oauth2/token` +- 用户信息 端点:`https://connect.linux.do/api/user` + +## 申请使用 + +- 访问 [Connect.Linux.Do](https://connect.linux.do/) 申请接入你的应用。 + +![linuxdoconnect_1](https://wiki.linux.do/_next/image?url=%2Flinuxdoconnect_1.png&w=1080&q=75) + +- 点击 **`我的应用接入`** - **`申请新接入`**,填写相关信息。其中 **`回调地址`** 是你的应用接收用户信息的地址。 + +![linuxdoconnect_2](https://wiki.linux.do/_next/image?url=%2Flinuxdoconnect_2.png&w=1080&q=75) + +- 申请成功后,你将获得 **`Client Id`** 和 **`Client Secret`**,这是你应用的唯一身份凭证。 + +![linuxdoconnect_3](https://wiki.linux.do/_next/image?url=%2Flinuxdoconnect_3.png&w=1080&q=75) + +## 接入 Linux Do + +JavaScript +```JavaScript +// 安装第三方请求库(或使用原生的 Fetch API),本例中使用 axios +// npm install axios + +// 通过 OAuth2 获取 Linux Do 用户信息的参考流程 +const axios = require('axios'); +const readline = require('readline'); + +// 配置信息(建议通过环境变量配置,避免使用硬编码) +const CLIENT_ID = '你的 Client ID'; +const CLIENT_SECRET = '你的 Client Secret'; +const REDIRECT_URI = '你的回调地址'; +const AUTH_URL = 'https://connect.linux.do/oauth2/authorize'; +const TOKEN_URL = 'https://connect.linux.do/oauth2/token'; +const USER_INFO_URL = 'https://connect.linux.do/api/user'; + +// 第一步:生成授权 URL +function getAuthUrl() { + const params = new URLSearchParams({ + client_id: CLIENT_ID, + redirect_uri: REDIRECT_URI, + response_type: 'code', + scope: 'user' + }); + + return `${AUTH_URL}?${params.toString()}`; +} + +// 第二步:获取 code 参数 +function getCode() { + return new Promise((resolve) => { + // 本例中使用终端输入来模拟流程,仅供本地测试 + // 请在实际应用中替换为真实的处理逻辑 + const rl = readline.createInterface({ input: process.stdin, output: process.stdout }); + rl.question('从回调 URL 中提取出 code,粘贴到此处并按回车:', (answer) => { + rl.close(); + resolve(answer.trim()); + }); + }); +} + +// 第三步:使用 code 参数获取访问令牌 +async function getAccessToken(code) { + try { + const form = new URLSearchParams({ + client_id: CLIENT_ID, + client_secret: CLIENT_SECRET, + code: code, + redirect_uri: REDIRECT_URI, + grant_type: 'authorization_code' + }).toString(); + + const response = await axios.post(TOKEN_URL, form, { + // 提醒:需正确配置请求头,否则无法正常获取访问令牌 + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + 'Accept': 'application/json' + } + }); + + return response.data; + } catch (error) { + console.error(`获取访问令牌失败:${error.response ? JSON.stringify(error.response.data) : error.message}`); + throw error; + } +} + +// 第四步:使用访问令牌获取用户信息 +async function getUserInfo(accessToken) { + try { + const response = await axios.get(USER_INFO_URL, { + headers: { + Authorization: `Bearer ${accessToken}` + } + }); + + return response.data; + } catch (error) { + console.error(`获取用户信息失败:${error.response ? JSON.stringify(error.response.data) : error.message}`); + throw error; + } +} + +// 主流程 +async function main() { + // 1. 生成授权 URL,前端引导用户访问授权页 + const authUrl = getAuthUrl(); + console.log(`请访问此 URL 授权:${authUrl} +`); + + // 2. 用户授权后,从回调 URL 获取 code 参数 + const code = await getCode(); + + try { + // 3. 使用 code 参数获取访问令牌 + const tokenData = await getAccessToken(code); + const accessToken = tokenData.access_token; + + // 4. 使用访问令牌获取用户信息 + if (accessToken) { + const userInfo = await getUserInfo(accessToken); + console.log(` +获取用户信息成功:${JSON.stringify(userInfo, null, 2)}`); + } else { + console.log(` +获取访问令牌失败:${JSON.stringify(tokenData)}`); + } + } catch (error) { + console.error('发生错误:', error); + } +} +``` +Python +```python +# 安装第三方请求库,本例中使用 requests +# pip install requests + +# 通过 OAuth2 获取 Linux Do 用户信息的参考流程 +import requests +import json + +# 配置信息(建议通过环境变量配置,避免使用硬编码) +CLIENT_ID = '你的 Client ID' +CLIENT_SECRET = '你的 Client Secret' +REDIRECT_URI = '你的回调地址' +AUTH_URL = 'https://connect.linux.do/oauth2/authorize' +TOKEN_URL = 'https://connect.linux.do/oauth2/token' +USER_INFO_URL = 'https://connect.linux.do/api/user' + +# 第一步:生成授权 URL +def get_auth_url(): + params = { + 'client_id': CLIENT_ID, + 'redirect_uri': REDIRECT_URI, + 'response_type': 'code', + 'scope': 'user' + } + auth_url = f"{AUTH_URL}?{'&'.join(f'{k}={v}' for k, v in params.items())}" + return auth_url + +# 第二步:获取 code 参数 +def get_code(): + # 本例中使用终端输入来模拟流程,仅供本地测试 + # 请在实际应用中替换为真实的处理逻辑 + return input('从回调 URL 中提取出 code,粘贴到此处并按回车:').strip() + +# 第三步:使用 code 参数获取访问令牌 +def get_access_token(code): + try: + data = { + 'client_id': CLIENT_ID, + 'client_secret': CLIENT_SECRET, + 'code': code, + 'redirect_uri': REDIRECT_URI, + 'grant_type': 'authorization_code' + } + # 提醒:需正确配置请求头,否则无法正常获取访问令牌 + headers = { + 'Content-Type': 'application/x-www-form-urlencoded', + 'Accept': 'application/json' + } + response = requests.post(TOKEN_URL, data=data, headers=headers) + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as e: + print(f"获取访问令牌失败:{e}") + return None + +# 第四步:使用访问令牌获取用户信息 +def get_user_info(access_token): + try: + headers = { + 'Authorization': f'Bearer {access_token}' + } + response = requests.get(USER_INFO_URL, headers=headers) + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as e: + print(f"获取用户信息失败:{e}") + return None + +# 主流程 +if __name__ == '__main__': + # 1. 生成授权 URL,前端引导用户访问授权页 + auth_url = get_auth_url() + print(f'请访问此 URL 授权:{auth_url} +') + + # 2. 用户授权后,从回调 URL 获取 code 参数 + code = get_code() + + # 3. 使用 code 参数获取访问令牌 + token_data = get_access_token(code) + if token_data: + access_token = token_data.get('access_token') + + # 4. 使用访问令牌获取用户信息 + if access_token: + user_info = get_user_info(access_token) + if user_info: + print(f" +获取用户信息成功:{json.dumps(user_info, indent=2)}") + else: + print(" +获取用户信息失败") + else: + print(f" +获取访问令牌失败:{json.dumps(token_data, indent=2)}") + else: + print(" +获取访问令牌失败") +``` +PHP +```php +// 通过 OAuth2 获取 Linux Do 用户信息的参考流程 + +// 配置信息 +$CLIENT_ID = '你的 Client ID'; +$CLIENT_SECRET = '你的 Client Secret'; +$REDIRECT_URI = '你的回调地址'; +$AUTH_URL = 'https://connect.linux.do/oauth2/authorize'; +$TOKEN_URL = 'https://connect.linux.do/oauth2/token'; +$USER_INFO_URL = 'https://connect.linux.do/api/user'; + +// 生成授权 URL +function getAuthUrl($clientId, $redirectUri) { + global $AUTH_URL; + return $AUTH_URL . '?' . http_build_query([ + 'client_id' => $clientId, + 'redirect_uri' => $redirectUri, + 'response_type' => 'code', + 'scope' => 'user' + ]); +} + +// 使用 code 参数获取用户信息(合并获取令牌和获取用户信息的步骤) +function getUserInfoWithCode($code, $clientId, $clientSecret, $redirectUri) { + global $TOKEN_URL, $USER_INFO_URL; + + // 1. 获取访问令牌 + $ch = curl_init($TOKEN_URL); + curl_setopt($ch, CURLOPT_RETURNTRANSFER, true); + curl_setopt($ch, CURLOPT_POST, true); + curl_setopt($ch, CURLOPT_POSTFIELDS, http_build_query([ + 'client_id' => $clientId, + 'client_secret' => $clientSecret, + 'code' => $code, + 'redirect_uri' => $redirectUri, + 'grant_type' => 'authorization_code' + ])); + curl_setopt($ch, CURLOPT_HTTPHEADER, [ + 'Content-Type: application/x-www-form-urlencoded', + 'Accept: application/json' + ]); + + $tokenResponse = curl_exec($ch); + curl_close($ch); + + $tokenData = json_decode($tokenResponse, true); + if (!isset($tokenData['access_token'])) { + return ['error' => '获取访问令牌失败', 'details' => $tokenData]; + } + + // 2. 获取用户信息 + $ch = curl_init($USER_INFO_URL); + curl_setopt($ch, CURLOPT_RETURNTRANSFER, true); + curl_setopt($ch, CURLOPT_HTTPHEADER, [ + 'Authorization: Bearer ' . $tokenData['access_token'] + ]); + + $userResponse = curl_exec($ch); + curl_close($ch); + + return json_decode($userResponse, true); +} + +// 主流程 +// 1. 生成授权 URL +$authUrl = getAuthUrl($CLIENT_ID, $REDIRECT_URI); +echo "使用 Linux Do 登录"; + +// 2. 处理回调并获取用户信息 +if (isset($_GET['code'])) { + $userInfo = getUserInfoWithCode( + $_GET['code'], + $CLIENT_ID, + $CLIENT_SECRET, + $REDIRECT_URI + ); + + if (isset($userInfo['error'])) { + echo '错误: ' . $userInfo['error']; + } else { + echo '欢迎, ' . $userInfo['name'] . '!'; + // 处理用户登录逻辑... + } +} +``` + +## 使用说明 + +### 授权流程 + +1. 用户点击应用中的’使用 Linux Do 登录’按钮 +2. 系统将用户重定向至 Linux Do 的授权页面 +3. 用户完成授权后,系统自动重定向回应用并携带授权码 +4. 应用使用授权码获取访问令牌 +5. 使用访问令牌获取用户信息 + +### 安全建议 + +- 切勿在前端代码中暴露 Client Secret +- 对所有用户输入数据进行严格验证 +- 确保使用 HTTPS 协议传输数据 +- 定期更新并妥善保管 Client Secret \ No newline at end of file diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 17e51c38..79e0dd8a 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.1 +0.1.46 diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 31dc3682..85bed3f3 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -53,7 +53,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { emailQueueService := service.ProvideEmailQueueService(emailService) authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService) userService := service.NewUserService(userRepository) - authHandler := handler.NewAuthHandler(configConfig, authService, userService) + authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService) userHandler := handler.NewUserHandler(userService) apiKeyRepository := repository.NewAPIKeyRepository(client) groupRepository := repository.NewGroupRepository(client, db) diff --git a/backend/ent/apikey.go b/backend/ent/apikey.go index fe3ad0cf..95586017 100644 --- a/backend/ent/apikey.go +++ b/backend/ent/apikey.go @@ -3,6 +3,7 @@ package ent import ( + "encoding/json" "fmt" "strings" "time" @@ -35,6 +36,10 @@ type APIKey struct { GroupID *int64 `json:"group_id,omitempty"` // Status holds the value of the "status" field. Status string `json:"status,omitempty"` + // Allowed IPs/CIDRs, e.g. ["192.168.1.100", "10.0.0.0/8"] + IPWhitelist []string `json:"ip_whitelist,omitempty"` + // Blocked IPs/CIDRs + IPBlacklist []string `json:"ip_blacklist,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the APIKeyQuery when eager-loading is set. Edges APIKeyEdges `json:"edges"` @@ -90,6 +95,8 @@ func (*APIKey) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { + case apikey.FieldIPWhitelist, apikey.FieldIPBlacklist: + values[i] = new([]byte) case apikey.FieldID, apikey.FieldUserID, apikey.FieldGroupID: values[i] = new(sql.NullInt64) case apikey.FieldKey, apikey.FieldName, apikey.FieldStatus: @@ -167,6 +174,22 @@ func (_m *APIKey) assignValues(columns []string, values []any) error { } else if value.Valid { _m.Status = value.String } + case apikey.FieldIPWhitelist: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field ip_whitelist", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.IPWhitelist); err != nil { + return fmt.Errorf("unmarshal field ip_whitelist: %w", err) + } + } + case apikey.FieldIPBlacklist: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field ip_blacklist", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.IPBlacklist); err != nil { + return fmt.Errorf("unmarshal field ip_blacklist: %w", err) + } + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -245,6 +268,12 @@ func (_m *APIKey) String() string { builder.WriteString(", ") builder.WriteString("status=") builder.WriteString(_m.Status) + builder.WriteString(", ") + builder.WriteString("ip_whitelist=") + builder.WriteString(fmt.Sprintf("%v", _m.IPWhitelist)) + builder.WriteString(", ") + builder.WriteString("ip_blacklist=") + builder.WriteString(fmt.Sprintf("%v", _m.IPBlacklist)) builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/apikey/apikey.go b/backend/ent/apikey/apikey.go index 91f7d620..564cddb1 100644 --- a/backend/ent/apikey/apikey.go +++ b/backend/ent/apikey/apikey.go @@ -31,6 +31,10 @@ const ( FieldGroupID = "group_id" // FieldStatus holds the string denoting the status field in the database. FieldStatus = "status" + // FieldIPWhitelist holds the string denoting the ip_whitelist field in the database. + FieldIPWhitelist = "ip_whitelist" + // FieldIPBlacklist holds the string denoting the ip_blacklist field in the database. + FieldIPBlacklist = "ip_blacklist" // EdgeUser holds the string denoting the user edge name in mutations. EdgeUser = "user" // EdgeGroup holds the string denoting the group edge name in mutations. @@ -73,6 +77,8 @@ var Columns = []string{ FieldName, FieldGroupID, FieldStatus, + FieldIPWhitelist, + FieldIPBlacklist, } // ValidColumn reports if the column name is valid (part of the table columns). diff --git a/backend/ent/apikey/where.go b/backend/ent/apikey/where.go index 5e739006..5152867f 100644 --- a/backend/ent/apikey/where.go +++ b/backend/ent/apikey/where.go @@ -470,6 +470,26 @@ func StatusContainsFold(v string) predicate.APIKey { return predicate.APIKey(sql.FieldContainsFold(FieldStatus, v)) } +// IPWhitelistIsNil applies the IsNil predicate on the "ip_whitelist" field. +func IPWhitelistIsNil() predicate.APIKey { + return predicate.APIKey(sql.FieldIsNull(FieldIPWhitelist)) +} + +// IPWhitelistNotNil applies the NotNil predicate on the "ip_whitelist" field. +func IPWhitelistNotNil() predicate.APIKey { + return predicate.APIKey(sql.FieldNotNull(FieldIPWhitelist)) +} + +// IPBlacklistIsNil applies the IsNil predicate on the "ip_blacklist" field. +func IPBlacklistIsNil() predicate.APIKey { + return predicate.APIKey(sql.FieldIsNull(FieldIPBlacklist)) +} + +// IPBlacklistNotNil applies the NotNil predicate on the "ip_blacklist" field. +func IPBlacklistNotNil() predicate.APIKey { + return predicate.APIKey(sql.FieldNotNull(FieldIPBlacklist)) +} + // HasUser applies the HasEdge predicate on the "user" edge. func HasUser() predicate.APIKey { return predicate.APIKey(func(s *sql.Selector) { diff --git a/backend/ent/apikey_create.go b/backend/ent/apikey_create.go index 2098872c..d5363be5 100644 --- a/backend/ent/apikey_create.go +++ b/backend/ent/apikey_create.go @@ -113,6 +113,18 @@ func (_c *APIKeyCreate) SetNillableStatus(v *string) *APIKeyCreate { return _c } +// SetIPWhitelist sets the "ip_whitelist" field. +func (_c *APIKeyCreate) SetIPWhitelist(v []string) *APIKeyCreate { + _c.mutation.SetIPWhitelist(v) + return _c +} + +// SetIPBlacklist sets the "ip_blacklist" field. +func (_c *APIKeyCreate) SetIPBlacklist(v []string) *APIKeyCreate { + _c.mutation.SetIPBlacklist(v) + return _c +} + // SetUser sets the "user" edge to the User entity. func (_c *APIKeyCreate) SetUser(v *User) *APIKeyCreate { return _c.SetUserID(v.ID) @@ -285,6 +297,14 @@ func (_c *APIKeyCreate) createSpec() (*APIKey, *sqlgraph.CreateSpec) { _spec.SetField(apikey.FieldStatus, field.TypeString, value) _node.Status = value } + if value, ok := _c.mutation.IPWhitelist(); ok { + _spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value) + _node.IPWhitelist = value + } + if value, ok := _c.mutation.IPBlacklist(); ok { + _spec.SetField(apikey.FieldIPBlacklist, field.TypeJSON, value) + _node.IPBlacklist = value + } if nodes := _c.mutation.UserIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -483,6 +503,42 @@ func (u *APIKeyUpsert) UpdateStatus() *APIKeyUpsert { return u } +// SetIPWhitelist sets the "ip_whitelist" field. +func (u *APIKeyUpsert) SetIPWhitelist(v []string) *APIKeyUpsert { + u.Set(apikey.FieldIPWhitelist, v) + return u +} + +// UpdateIPWhitelist sets the "ip_whitelist" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateIPWhitelist() *APIKeyUpsert { + u.SetExcluded(apikey.FieldIPWhitelist) + return u +} + +// ClearIPWhitelist clears the value of the "ip_whitelist" field. +func (u *APIKeyUpsert) ClearIPWhitelist() *APIKeyUpsert { + u.SetNull(apikey.FieldIPWhitelist) + return u +} + +// SetIPBlacklist sets the "ip_blacklist" field. +func (u *APIKeyUpsert) SetIPBlacklist(v []string) *APIKeyUpsert { + u.Set(apikey.FieldIPBlacklist, v) + return u +} + +// UpdateIPBlacklist sets the "ip_blacklist" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateIPBlacklist() *APIKeyUpsert { + u.SetExcluded(apikey.FieldIPBlacklist) + return u +} + +// ClearIPBlacklist clears the value of the "ip_blacklist" field. +func (u *APIKeyUpsert) ClearIPBlacklist() *APIKeyUpsert { + u.SetNull(apikey.FieldIPBlacklist) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -640,6 +696,48 @@ func (u *APIKeyUpsertOne) UpdateStatus() *APIKeyUpsertOne { }) } +// SetIPWhitelist sets the "ip_whitelist" field. +func (u *APIKeyUpsertOne) SetIPWhitelist(v []string) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetIPWhitelist(v) + }) +} + +// UpdateIPWhitelist sets the "ip_whitelist" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateIPWhitelist() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateIPWhitelist() + }) +} + +// ClearIPWhitelist clears the value of the "ip_whitelist" field. +func (u *APIKeyUpsertOne) ClearIPWhitelist() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.ClearIPWhitelist() + }) +} + +// SetIPBlacklist sets the "ip_blacklist" field. +func (u *APIKeyUpsertOne) SetIPBlacklist(v []string) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetIPBlacklist(v) + }) +} + +// UpdateIPBlacklist sets the "ip_blacklist" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateIPBlacklist() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateIPBlacklist() + }) +} + +// ClearIPBlacklist clears the value of the "ip_blacklist" field. +func (u *APIKeyUpsertOne) ClearIPBlacklist() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.ClearIPBlacklist() + }) +} + // Exec executes the query. func (u *APIKeyUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -963,6 +1061,48 @@ func (u *APIKeyUpsertBulk) UpdateStatus() *APIKeyUpsertBulk { }) } +// SetIPWhitelist sets the "ip_whitelist" field. +func (u *APIKeyUpsertBulk) SetIPWhitelist(v []string) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetIPWhitelist(v) + }) +} + +// UpdateIPWhitelist sets the "ip_whitelist" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateIPWhitelist() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateIPWhitelist() + }) +} + +// ClearIPWhitelist clears the value of the "ip_whitelist" field. +func (u *APIKeyUpsertBulk) ClearIPWhitelist() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.ClearIPWhitelist() + }) +} + +// SetIPBlacklist sets the "ip_blacklist" field. +func (u *APIKeyUpsertBulk) SetIPBlacklist(v []string) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetIPBlacklist(v) + }) +} + +// UpdateIPBlacklist sets the "ip_blacklist" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateIPBlacklist() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateIPBlacklist() + }) +} + +// ClearIPBlacklist clears the value of the "ip_blacklist" field. +func (u *APIKeyUpsertBulk) ClearIPBlacklist() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.ClearIPBlacklist() + }) +} + // Exec executes the query. func (u *APIKeyUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/apikey_update.go b/backend/ent/apikey_update.go index 4a16369b..9ae332a8 100644 --- a/backend/ent/apikey_update.go +++ b/backend/ent/apikey_update.go @@ -10,6 +10,7 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/dialect/sql/sqljson" "entgo.io/ent/schema/field" "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/group" @@ -133,6 +134,42 @@ func (_u *APIKeyUpdate) SetNillableStatus(v *string) *APIKeyUpdate { return _u } +// SetIPWhitelist sets the "ip_whitelist" field. +func (_u *APIKeyUpdate) SetIPWhitelist(v []string) *APIKeyUpdate { + _u.mutation.SetIPWhitelist(v) + return _u +} + +// AppendIPWhitelist appends value to the "ip_whitelist" field. +func (_u *APIKeyUpdate) AppendIPWhitelist(v []string) *APIKeyUpdate { + _u.mutation.AppendIPWhitelist(v) + return _u +} + +// ClearIPWhitelist clears the value of the "ip_whitelist" field. +func (_u *APIKeyUpdate) ClearIPWhitelist() *APIKeyUpdate { + _u.mutation.ClearIPWhitelist() + return _u +} + +// SetIPBlacklist sets the "ip_blacklist" field. +func (_u *APIKeyUpdate) SetIPBlacklist(v []string) *APIKeyUpdate { + _u.mutation.SetIPBlacklist(v) + return _u +} + +// AppendIPBlacklist appends value to the "ip_blacklist" field. +func (_u *APIKeyUpdate) AppendIPBlacklist(v []string) *APIKeyUpdate { + _u.mutation.AppendIPBlacklist(v) + return _u +} + +// ClearIPBlacklist clears the value of the "ip_blacklist" field. +func (_u *APIKeyUpdate) ClearIPBlacklist() *APIKeyUpdate { + _u.mutation.ClearIPBlacklist() + return _u +} + // SetUser sets the "user" edge to the User entity. func (_u *APIKeyUpdate) SetUser(v *User) *APIKeyUpdate { return _u.SetUserID(v.ID) @@ -291,6 +328,28 @@ func (_u *APIKeyUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.Status(); ok { _spec.SetField(apikey.FieldStatus, field.TypeString, value) } + if value, ok := _u.mutation.IPWhitelist(); ok { + _spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedIPWhitelist(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, apikey.FieldIPWhitelist, value) + }) + } + if _u.mutation.IPWhitelistCleared() { + _spec.ClearField(apikey.FieldIPWhitelist, field.TypeJSON) + } + if value, ok := _u.mutation.IPBlacklist(); ok { + _spec.SetField(apikey.FieldIPBlacklist, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedIPBlacklist(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, apikey.FieldIPBlacklist, value) + }) + } + if _u.mutation.IPBlacklistCleared() { + _spec.ClearField(apikey.FieldIPBlacklist, field.TypeJSON) + } if _u.mutation.UserCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -516,6 +575,42 @@ func (_u *APIKeyUpdateOne) SetNillableStatus(v *string) *APIKeyUpdateOne { return _u } +// SetIPWhitelist sets the "ip_whitelist" field. +func (_u *APIKeyUpdateOne) SetIPWhitelist(v []string) *APIKeyUpdateOne { + _u.mutation.SetIPWhitelist(v) + return _u +} + +// AppendIPWhitelist appends value to the "ip_whitelist" field. +func (_u *APIKeyUpdateOne) AppendIPWhitelist(v []string) *APIKeyUpdateOne { + _u.mutation.AppendIPWhitelist(v) + return _u +} + +// ClearIPWhitelist clears the value of the "ip_whitelist" field. +func (_u *APIKeyUpdateOne) ClearIPWhitelist() *APIKeyUpdateOne { + _u.mutation.ClearIPWhitelist() + return _u +} + +// SetIPBlacklist sets the "ip_blacklist" field. +func (_u *APIKeyUpdateOne) SetIPBlacklist(v []string) *APIKeyUpdateOne { + _u.mutation.SetIPBlacklist(v) + return _u +} + +// AppendIPBlacklist appends value to the "ip_blacklist" field. +func (_u *APIKeyUpdateOne) AppendIPBlacklist(v []string) *APIKeyUpdateOne { + _u.mutation.AppendIPBlacklist(v) + return _u +} + +// ClearIPBlacklist clears the value of the "ip_blacklist" field. +func (_u *APIKeyUpdateOne) ClearIPBlacklist() *APIKeyUpdateOne { + _u.mutation.ClearIPBlacklist() + return _u +} + // SetUser sets the "user" edge to the User entity. func (_u *APIKeyUpdateOne) SetUser(v *User) *APIKeyUpdateOne { return _u.SetUserID(v.ID) @@ -704,6 +799,28 @@ func (_u *APIKeyUpdateOne) sqlSave(ctx context.Context) (_node *APIKey, err erro if value, ok := _u.mutation.Status(); ok { _spec.SetField(apikey.FieldStatus, field.TypeString, value) } + if value, ok := _u.mutation.IPWhitelist(); ok { + _spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedIPWhitelist(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, apikey.FieldIPWhitelist, value) + }) + } + if _u.mutation.IPWhitelistCleared() { + _spec.ClearField(apikey.FieldIPWhitelist, field.TypeJSON) + } + if value, ok := _u.mutation.IPBlacklist(); ok { + _spec.SetField(apikey.FieldIPBlacklist, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedIPBlacklist(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, apikey.FieldIPBlacklist, value) + }) + } + if _u.mutation.IPBlacklistCleared() { + _spec.ClearField(apikey.FieldIPBlacklist, field.TypeJSON) + } if _u.mutation.UserCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 13081e31..fdde0cd1 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -18,6 +18,8 @@ var ( {Name: "key", Type: field.TypeString, Unique: true, Size: 128}, {Name: "name", Type: field.TypeString, Size: 100}, {Name: "status", Type: field.TypeString, Size: 20, Default: "active"}, + {Name: "ip_whitelist", Type: field.TypeJSON, Nullable: true}, + {Name: "ip_blacklist", Type: field.TypeJSON, Nullable: true}, {Name: "group_id", Type: field.TypeInt64, Nullable: true}, {Name: "user_id", Type: field.TypeInt64}, } @@ -29,13 +31,13 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "api_keys_groups_api_keys", - Columns: []*schema.Column{APIKeysColumns[7]}, + Columns: []*schema.Column{APIKeysColumns[9]}, RefColumns: []*schema.Column{GroupsColumns[0]}, OnDelete: schema.SetNull, }, { Symbol: "api_keys_users_api_keys", - Columns: []*schema.Column{APIKeysColumns[8]}, + Columns: []*schema.Column{APIKeysColumns[10]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, @@ -44,12 +46,12 @@ var ( { Name: "apikey_user_id", Unique: false, - Columns: []*schema.Column{APIKeysColumns[8]}, + Columns: []*schema.Column{APIKeysColumns[10]}, }, { Name: "apikey_group_id", Unique: false, - Columns: []*schema.Column{APIKeysColumns[7]}, + Columns: []*schema.Column{APIKeysColumns[9]}, }, { Name: "apikey_status", @@ -376,6 +378,7 @@ var ( {Name: "duration_ms", Type: field.TypeInt, Nullable: true}, {Name: "first_token_ms", Type: field.TypeInt, Nullable: true}, {Name: "user_agent", Type: field.TypeString, Nullable: true, Size: 512}, + {Name: "ip_address", Type: field.TypeString, Nullable: true, Size: 45}, {Name: "image_count", Type: field.TypeInt, Default: 0}, {Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10}, {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, @@ -393,31 +396,31 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "usage_logs_api_keys_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[24]}, + Columns: []*schema.Column{UsageLogsColumns[25]}, RefColumns: []*schema.Column{APIKeysColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_accounts_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[25]}, + Columns: []*schema.Column{UsageLogsColumns[26]}, RefColumns: []*schema.Column{AccountsColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_groups_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[26]}, + Columns: []*schema.Column{UsageLogsColumns[27]}, RefColumns: []*schema.Column{GroupsColumns[0]}, OnDelete: schema.SetNull, }, { Symbol: "usage_logs_users_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[28]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_user_subscriptions_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[28]}, + Columns: []*schema.Column{UsageLogsColumns[29]}, RefColumns: []*schema.Column{UserSubscriptionsColumns[0]}, OnDelete: schema.SetNull, }, @@ -426,32 +429,32 @@ var ( { Name: "usagelog_user_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[28]}, }, { Name: "usagelog_api_key_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[24]}, + Columns: []*schema.Column{UsageLogsColumns[25]}, }, { Name: "usagelog_account_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[25]}, + Columns: []*schema.Column{UsageLogsColumns[26]}, }, { Name: "usagelog_group_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[26]}, + Columns: []*schema.Column{UsageLogsColumns[27]}, }, { Name: "usagelog_subscription_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[28]}, + Columns: []*schema.Column{UsageLogsColumns[29]}, }, { Name: "usagelog_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[23]}, + Columns: []*schema.Column{UsageLogsColumns[24]}, }, { Name: "usagelog_model", @@ -466,12 +469,12 @@ var ( { Name: "usagelog_user_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[27], UsageLogsColumns[23]}, + Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[24]}, }, { Name: "usagelog_api_key_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[24], UsageLogsColumns[23]}, + Columns: []*schema.Column{UsageLogsColumns[25], UsageLogsColumns[24]}, }, }, } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 4e01e12b..09801d4b 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -54,26 +54,30 @@ const ( // APIKeyMutation represents an operation that mutates the APIKey nodes in the graph. type APIKeyMutation struct { config - op Op - typ string - id *int64 - created_at *time.Time - updated_at *time.Time - deleted_at *time.Time - key *string - name *string - status *string - clearedFields map[string]struct{} - user *int64 - cleareduser bool - group *int64 - clearedgroup bool - usage_logs map[int64]struct{} - removedusage_logs map[int64]struct{} - clearedusage_logs bool - done bool - oldValue func(context.Context) (*APIKey, error) - predicates []predicate.APIKey + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + key *string + name *string + status *string + ip_whitelist *[]string + appendip_whitelist []string + ip_blacklist *[]string + appendip_blacklist []string + clearedFields map[string]struct{} + user *int64 + cleareduser bool + group *int64 + clearedgroup bool + usage_logs map[int64]struct{} + removedusage_logs map[int64]struct{} + clearedusage_logs bool + done bool + oldValue func(context.Context) (*APIKey, error) + predicates []predicate.APIKey } var _ ent.Mutation = (*APIKeyMutation)(nil) @@ -488,6 +492,136 @@ func (m *APIKeyMutation) ResetStatus() { m.status = nil } +// SetIPWhitelist sets the "ip_whitelist" field. +func (m *APIKeyMutation) SetIPWhitelist(s []string) { + m.ip_whitelist = &s + m.appendip_whitelist = nil +} + +// IPWhitelist returns the value of the "ip_whitelist" field in the mutation. +func (m *APIKeyMutation) IPWhitelist() (r []string, exists bool) { + v := m.ip_whitelist + if v == nil { + return + } + return *v, true +} + +// OldIPWhitelist returns the old "ip_whitelist" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *APIKeyMutation) OldIPWhitelist(ctx context.Context) (v []string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIPWhitelist is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIPWhitelist requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIPWhitelist: %w", err) + } + return oldValue.IPWhitelist, nil +} + +// AppendIPWhitelist adds s to the "ip_whitelist" field. +func (m *APIKeyMutation) AppendIPWhitelist(s []string) { + m.appendip_whitelist = append(m.appendip_whitelist, s...) +} + +// AppendedIPWhitelist returns the list of values that were appended to the "ip_whitelist" field in this mutation. +func (m *APIKeyMutation) AppendedIPWhitelist() ([]string, bool) { + if len(m.appendip_whitelist) == 0 { + return nil, false + } + return m.appendip_whitelist, true +} + +// ClearIPWhitelist clears the value of the "ip_whitelist" field. +func (m *APIKeyMutation) ClearIPWhitelist() { + m.ip_whitelist = nil + m.appendip_whitelist = nil + m.clearedFields[apikey.FieldIPWhitelist] = struct{}{} +} + +// IPWhitelistCleared returns if the "ip_whitelist" field was cleared in this mutation. +func (m *APIKeyMutation) IPWhitelistCleared() bool { + _, ok := m.clearedFields[apikey.FieldIPWhitelist] + return ok +} + +// ResetIPWhitelist resets all changes to the "ip_whitelist" field. +func (m *APIKeyMutation) ResetIPWhitelist() { + m.ip_whitelist = nil + m.appendip_whitelist = nil + delete(m.clearedFields, apikey.FieldIPWhitelist) +} + +// SetIPBlacklist sets the "ip_blacklist" field. +func (m *APIKeyMutation) SetIPBlacklist(s []string) { + m.ip_blacklist = &s + m.appendip_blacklist = nil +} + +// IPBlacklist returns the value of the "ip_blacklist" field in the mutation. +func (m *APIKeyMutation) IPBlacklist() (r []string, exists bool) { + v := m.ip_blacklist + if v == nil { + return + } + return *v, true +} + +// OldIPBlacklist returns the old "ip_blacklist" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *APIKeyMutation) OldIPBlacklist(ctx context.Context) (v []string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIPBlacklist is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIPBlacklist requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIPBlacklist: %w", err) + } + return oldValue.IPBlacklist, nil +} + +// AppendIPBlacklist adds s to the "ip_blacklist" field. +func (m *APIKeyMutation) AppendIPBlacklist(s []string) { + m.appendip_blacklist = append(m.appendip_blacklist, s...) +} + +// AppendedIPBlacklist returns the list of values that were appended to the "ip_blacklist" field in this mutation. +func (m *APIKeyMutation) AppendedIPBlacklist() ([]string, bool) { + if len(m.appendip_blacklist) == 0 { + return nil, false + } + return m.appendip_blacklist, true +} + +// ClearIPBlacklist clears the value of the "ip_blacklist" field. +func (m *APIKeyMutation) ClearIPBlacklist() { + m.ip_blacklist = nil + m.appendip_blacklist = nil + m.clearedFields[apikey.FieldIPBlacklist] = struct{}{} +} + +// IPBlacklistCleared returns if the "ip_blacklist" field was cleared in this mutation. +func (m *APIKeyMutation) IPBlacklistCleared() bool { + _, ok := m.clearedFields[apikey.FieldIPBlacklist] + return ok +} + +// ResetIPBlacklist resets all changes to the "ip_blacklist" field. +func (m *APIKeyMutation) ResetIPBlacklist() { + m.ip_blacklist = nil + m.appendip_blacklist = nil + delete(m.clearedFields, apikey.FieldIPBlacklist) +} + // ClearUser clears the "user" edge to the User entity. func (m *APIKeyMutation) ClearUser() { m.cleareduser = true @@ -630,7 +764,7 @@ func (m *APIKeyMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *APIKeyMutation) Fields() []string { - fields := make([]string, 0, 8) + fields := make([]string, 0, 10) if m.created_at != nil { fields = append(fields, apikey.FieldCreatedAt) } @@ -655,6 +789,12 @@ func (m *APIKeyMutation) Fields() []string { if m.status != nil { fields = append(fields, apikey.FieldStatus) } + if m.ip_whitelist != nil { + fields = append(fields, apikey.FieldIPWhitelist) + } + if m.ip_blacklist != nil { + fields = append(fields, apikey.FieldIPBlacklist) + } return fields } @@ -679,6 +819,10 @@ func (m *APIKeyMutation) Field(name string) (ent.Value, bool) { return m.GroupID() case apikey.FieldStatus: return m.Status() + case apikey.FieldIPWhitelist: + return m.IPWhitelist() + case apikey.FieldIPBlacklist: + return m.IPBlacklist() } return nil, false } @@ -704,6 +848,10 @@ func (m *APIKeyMutation) OldField(ctx context.Context, name string) (ent.Value, return m.OldGroupID(ctx) case apikey.FieldStatus: return m.OldStatus(ctx) + case apikey.FieldIPWhitelist: + return m.OldIPWhitelist(ctx) + case apikey.FieldIPBlacklist: + return m.OldIPBlacklist(ctx) } return nil, fmt.Errorf("unknown APIKey field %s", name) } @@ -769,6 +917,20 @@ func (m *APIKeyMutation) SetField(name string, value ent.Value) error { } m.SetStatus(v) return nil + case apikey.FieldIPWhitelist: + v, ok := value.([]string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIPWhitelist(v) + return nil + case apikey.FieldIPBlacklist: + v, ok := value.([]string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIPBlacklist(v) + return nil } return fmt.Errorf("unknown APIKey field %s", name) } @@ -808,6 +970,12 @@ func (m *APIKeyMutation) ClearedFields() []string { if m.FieldCleared(apikey.FieldGroupID) { fields = append(fields, apikey.FieldGroupID) } + if m.FieldCleared(apikey.FieldIPWhitelist) { + fields = append(fields, apikey.FieldIPWhitelist) + } + if m.FieldCleared(apikey.FieldIPBlacklist) { + fields = append(fields, apikey.FieldIPBlacklist) + } return fields } @@ -828,6 +996,12 @@ func (m *APIKeyMutation) ClearField(name string) error { case apikey.FieldGroupID: m.ClearGroupID() return nil + case apikey.FieldIPWhitelist: + m.ClearIPWhitelist() + return nil + case apikey.FieldIPBlacklist: + m.ClearIPBlacklist() + return nil } return fmt.Errorf("unknown APIKey nullable field %s", name) } @@ -860,6 +1034,12 @@ func (m *APIKeyMutation) ResetField(name string) error { case apikey.FieldStatus: m.ResetStatus() return nil + case apikey.FieldIPWhitelist: + m.ResetIPWhitelist() + return nil + case apikey.FieldIPBlacklist: + m.ResetIPBlacklist() + return nil } return fmt.Errorf("unknown APIKey field %s", name) } @@ -8396,6 +8576,7 @@ type UsageLogMutation struct { first_token_ms *int addfirst_token_ms *int user_agent *string + ip_address *string image_count *int addimage_count *int image_size *string @@ -9801,6 +9982,55 @@ func (m *UsageLogMutation) ResetUserAgent() { delete(m.clearedFields, usagelog.FieldUserAgent) } +// SetIPAddress sets the "ip_address" field. +func (m *UsageLogMutation) SetIPAddress(s string) { + m.ip_address = &s +} + +// IPAddress returns the value of the "ip_address" field in the mutation. +func (m *UsageLogMutation) IPAddress() (r string, exists bool) { + v := m.ip_address + if v == nil { + return + } + return *v, true +} + +// OldIPAddress returns the old "ip_address" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldIPAddress(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIPAddress is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIPAddress requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIPAddress: %w", err) + } + return oldValue.IPAddress, nil +} + +// ClearIPAddress clears the value of the "ip_address" field. +func (m *UsageLogMutation) ClearIPAddress() { + m.ip_address = nil + m.clearedFields[usagelog.FieldIPAddress] = struct{}{} +} + +// IPAddressCleared returns if the "ip_address" field was cleared in this mutation. +func (m *UsageLogMutation) IPAddressCleared() bool { + _, ok := m.clearedFields[usagelog.FieldIPAddress] + return ok +} + +// ResetIPAddress resets all changes to the "ip_address" field. +func (m *UsageLogMutation) ResetIPAddress() { + m.ip_address = nil + delete(m.clearedFields, usagelog.FieldIPAddress) +} + // SetImageCount sets the "image_count" field. func (m *UsageLogMutation) SetImageCount(i int) { m.image_count = &i @@ -10111,7 +10341,7 @@ func (m *UsageLogMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UsageLogMutation) Fields() []string { - fields := make([]string, 0, 28) + fields := make([]string, 0, 29) if m.user != nil { fields = append(fields, usagelog.FieldUserID) } @@ -10187,6 +10417,9 @@ func (m *UsageLogMutation) Fields() []string { if m.user_agent != nil { fields = append(fields, usagelog.FieldUserAgent) } + if m.ip_address != nil { + fields = append(fields, usagelog.FieldIPAddress) + } if m.image_count != nil { fields = append(fields, usagelog.FieldImageCount) } @@ -10254,6 +10487,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) { return m.FirstTokenMs() case usagelog.FieldUserAgent: return m.UserAgent() + case usagelog.FieldIPAddress: + return m.IPAddress() case usagelog.FieldImageCount: return m.ImageCount() case usagelog.FieldImageSize: @@ -10319,6 +10554,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value return m.OldFirstTokenMs(ctx) case usagelog.FieldUserAgent: return m.OldUserAgent(ctx) + case usagelog.FieldIPAddress: + return m.OldIPAddress(ctx) case usagelog.FieldImageCount: return m.OldImageCount(ctx) case usagelog.FieldImageSize: @@ -10509,6 +10746,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error { } m.SetUserAgent(v) return nil + case usagelog.FieldIPAddress: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIPAddress(v) + return nil case usagelog.FieldImageCount: v, ok := value.(int) if !ok { @@ -10782,6 +11026,9 @@ func (m *UsageLogMutation) ClearedFields() []string { if m.FieldCleared(usagelog.FieldUserAgent) { fields = append(fields, usagelog.FieldUserAgent) } + if m.FieldCleared(usagelog.FieldIPAddress) { + fields = append(fields, usagelog.FieldIPAddress) + } if m.FieldCleared(usagelog.FieldImageSize) { fields = append(fields, usagelog.FieldImageSize) } @@ -10814,6 +11061,9 @@ func (m *UsageLogMutation) ClearField(name string) error { case usagelog.FieldUserAgent: m.ClearUserAgent() return nil + case usagelog.FieldIPAddress: + m.ClearIPAddress() + return nil case usagelog.FieldImageSize: m.ClearImageSize() return nil @@ -10900,6 +11150,9 @@ func (m *UsageLogMutation) ResetField(name string) error { case usagelog.FieldUserAgent: m.ResetUserAgent() return nil + case usagelog.FieldIPAddress: + m.ResetIPAddress() + return nil case usagelog.FieldImageCount: m.ResetImageCount() return nil diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index fb1c948c..b82f2e6c 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -533,16 +533,20 @@ func init() { usagelogDescUserAgent := usagelogFields[24].Descriptor() // usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save. usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error) + // usagelogDescIPAddress is the schema descriptor for ip_address field. + usagelogDescIPAddress := usagelogFields[25].Descriptor() + // usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save. + usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error) // usagelogDescImageCount is the schema descriptor for image_count field. - usagelogDescImageCount := usagelogFields[25].Descriptor() + usagelogDescImageCount := usagelogFields[26].Descriptor() // usagelog.DefaultImageCount holds the default value on creation for the image_count field. usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int) // usagelogDescImageSize is the schema descriptor for image_size field. - usagelogDescImageSize := usagelogFields[26].Descriptor() + usagelogDescImageSize := usagelogFields[27].Descriptor() // usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error) // usagelogDescCreatedAt is the schema descriptor for created_at field. - usagelogDescCreatedAt := usagelogFields[27].Descriptor() + usagelogDescCreatedAt := usagelogFields[28].Descriptor() // usagelog.DefaultCreatedAt holds the default value on creation for the created_at field. usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time) userMixin := schema.User{}.Mixin() diff --git a/backend/ent/schema/api_key.go b/backend/ent/schema/api_key.go index 94e572c5..1b206089 100644 --- a/backend/ent/schema/api_key.go +++ b/backend/ent/schema/api_key.go @@ -46,6 +46,12 @@ func (APIKey) Fields() []ent.Field { field.String("status"). MaxLen(20). Default(service.StatusActive), + field.JSON("ip_whitelist", []string{}). + Optional(). + Comment("Allowed IPs/CIDRs, e.g. [\"192.168.1.100\", \"10.0.0.0/8\"]"), + field.JSON("ip_blacklist", []string{}). + Optional(). + Comment("Blocked IPs/CIDRs"), } } diff --git a/backend/ent/schema/usage_log.go b/backend/ent/schema/usage_log.go index df955181..264a4087 100644 --- a/backend/ent/schema/usage_log.go +++ b/backend/ent/schema/usage_log.go @@ -100,6 +100,10 @@ func (UsageLog) Fields() []ent.Field { MaxLen(512). Optional(). Nillable(), + field.String("ip_address"). + MaxLen(45). // 支持 IPv6 + Optional(). + Nillable(), // 图片生成字段(仅 gemini-3-pro-image 等图片模型使用) field.Int("image_count"). diff --git a/backend/ent/usagelog.go b/backend/ent/usagelog.go index 798f3a9f..cd576466 100644 --- a/backend/ent/usagelog.go +++ b/backend/ent/usagelog.go @@ -72,6 +72,8 @@ type UsageLog struct { FirstTokenMs *int `json:"first_token_ms,omitempty"` // UserAgent holds the value of the "user_agent" field. UserAgent *string `json:"user_agent,omitempty"` + // IPAddress holds the value of the "ip_address" field. + IPAddress *string `json:"ip_address,omitempty"` // ImageCount holds the value of the "image_count" field. ImageCount int `json:"image_count,omitempty"` // ImageSize holds the value of the "image_size" field. @@ -167,7 +169,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullFloat64) case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount: values[i] = new(sql.NullInt64) - case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldImageSize: + case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize: values[i] = new(sql.NullString) case usagelog.FieldCreatedAt: values[i] = new(sql.NullTime) @@ -347,6 +349,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error { _m.UserAgent = new(string) *_m.UserAgent = value.String } + case usagelog.FieldIPAddress: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field ip_address", values[i]) + } else if value.Valid { + _m.IPAddress = new(string) + *_m.IPAddress = value.String + } case usagelog.FieldImageCount: if value, ok := values[i].(*sql.NullInt64); !ok { return fmt.Errorf("unexpected type %T for field image_count", values[i]) @@ -512,6 +521,11 @@ func (_m *UsageLog) String() string { builder.WriteString(*v) } builder.WriteString(", ") + if v := _m.IPAddress; v != nil { + builder.WriteString("ip_address=") + builder.WriteString(*v) + } + builder.WriteString(", ") builder.WriteString("image_count=") builder.WriteString(fmt.Sprintf("%v", _m.ImageCount)) builder.WriteString(", ") diff --git a/backend/ent/usagelog/usagelog.go b/backend/ent/usagelog/usagelog.go index d3edfb4d..c06925c4 100644 --- a/backend/ent/usagelog/usagelog.go +++ b/backend/ent/usagelog/usagelog.go @@ -64,6 +64,8 @@ const ( FieldFirstTokenMs = "first_token_ms" // FieldUserAgent holds the string denoting the user_agent field in the database. FieldUserAgent = "user_agent" + // FieldIPAddress holds the string denoting the ip_address field in the database. + FieldIPAddress = "ip_address" // FieldImageCount holds the string denoting the image_count field in the database. FieldImageCount = "image_count" // FieldImageSize holds the string denoting the image_size field in the database. @@ -147,6 +149,7 @@ var Columns = []string{ FieldDurationMs, FieldFirstTokenMs, FieldUserAgent, + FieldIPAddress, FieldImageCount, FieldImageSize, FieldCreatedAt, @@ -199,6 +202,8 @@ var ( DefaultStream bool // UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save. UserAgentValidator func(string) error + // IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save. + IPAddressValidator func(string) error // DefaultImageCount holds the default value on creation for the "image_count" field. DefaultImageCount int // ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. @@ -340,6 +345,11 @@ func ByUserAgent(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldUserAgent, opts...).ToFunc() } +// ByIPAddress orders the results by the ip_address field. +func ByIPAddress(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIPAddress, opts...).ToFunc() +} + // ByImageCount orders the results by the image_count field. func ByImageCount(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldImageCount, opts...).ToFunc() diff --git a/backend/ent/usagelog/where.go b/backend/ent/usagelog/where.go index c7acd59d..96b7a19c 100644 --- a/backend/ent/usagelog/where.go +++ b/backend/ent/usagelog/where.go @@ -180,6 +180,11 @@ func UserAgent(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldUserAgent, v)) } +// IPAddress applies equality check predicate on the "ip_address" field. It's identical to IPAddressEQ. +func IPAddress(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldIPAddress, v)) +} + // ImageCount applies equality check predicate on the "image_count" field. It's identical to ImageCountEQ. func ImageCount(v int) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldImageCount, v)) @@ -1190,6 +1195,81 @@ func UserAgentContainsFold(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldContainsFold(FieldUserAgent, v)) } +// IPAddressEQ applies the EQ predicate on the "ip_address" field. +func IPAddressEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldIPAddress, v)) +} + +// IPAddressNEQ applies the NEQ predicate on the "ip_address" field. +func IPAddressNEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldIPAddress, v)) +} + +// IPAddressIn applies the In predicate on the "ip_address" field. +func IPAddressIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldIPAddress, vs...)) +} + +// IPAddressNotIn applies the NotIn predicate on the "ip_address" field. +func IPAddressNotIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldIPAddress, vs...)) +} + +// IPAddressGT applies the GT predicate on the "ip_address" field. +func IPAddressGT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldIPAddress, v)) +} + +// IPAddressGTE applies the GTE predicate on the "ip_address" field. +func IPAddressGTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldIPAddress, v)) +} + +// IPAddressLT applies the LT predicate on the "ip_address" field. +func IPAddressLT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldIPAddress, v)) +} + +// IPAddressLTE applies the LTE predicate on the "ip_address" field. +func IPAddressLTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldIPAddress, v)) +} + +// IPAddressContains applies the Contains predicate on the "ip_address" field. +func IPAddressContains(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContains(FieldIPAddress, v)) +} + +// IPAddressHasPrefix applies the HasPrefix predicate on the "ip_address" field. +func IPAddressHasPrefix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasPrefix(FieldIPAddress, v)) +} + +// IPAddressHasSuffix applies the HasSuffix predicate on the "ip_address" field. +func IPAddressHasSuffix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasSuffix(FieldIPAddress, v)) +} + +// IPAddressIsNil applies the IsNil predicate on the "ip_address" field. +func IPAddressIsNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldIsNull(FieldIPAddress)) +} + +// IPAddressNotNil applies the NotNil predicate on the "ip_address" field. +func IPAddressNotNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotNull(FieldIPAddress)) +} + +// IPAddressEqualFold applies the EqualFold predicate on the "ip_address" field. +func IPAddressEqualFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEqualFold(FieldIPAddress, v)) +} + +// IPAddressContainsFold applies the ContainsFold predicate on the "ip_address" field. +func IPAddressContainsFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContainsFold(FieldIPAddress, v)) +} + // ImageCountEQ applies the EQ predicate on the "image_count" field. func ImageCountEQ(v int) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldImageCount, v)) diff --git a/backend/ent/usagelog_create.go b/backend/ent/usagelog_create.go index f77650ab..e63fab05 100644 --- a/backend/ent/usagelog_create.go +++ b/backend/ent/usagelog_create.go @@ -337,6 +337,20 @@ func (_c *UsageLogCreate) SetNillableUserAgent(v *string) *UsageLogCreate { return _c } +// SetIPAddress sets the "ip_address" field. +func (_c *UsageLogCreate) SetIPAddress(v string) *UsageLogCreate { + _c.mutation.SetIPAddress(v) + return _c +} + +// SetNillableIPAddress sets the "ip_address" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableIPAddress(v *string) *UsageLogCreate { + if v != nil { + _c.SetIPAddress(*v) + } + return _c +} + // SetImageCount sets the "image_count" field. func (_c *UsageLogCreate) SetImageCount(v int) *UsageLogCreate { _c.mutation.SetImageCount(v) @@ -586,6 +600,11 @@ func (_c *UsageLogCreate) check() error { return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)} } } + if v, ok := _c.mutation.IPAddress(); ok { + if err := usagelog.IPAddressValidator(v); err != nil { + return &ValidationError{Name: "ip_address", err: fmt.Errorf(`ent: validator failed for field "UsageLog.ip_address": %w`, err)} + } + } if _, ok := _c.mutation.ImageCount(); !ok { return &ValidationError{Name: "image_count", err: errors.New(`ent: missing required field "UsageLog.image_count"`)} } @@ -713,6 +732,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) { _spec.SetField(usagelog.FieldUserAgent, field.TypeString, value) _node.UserAgent = &value } + if value, ok := _c.mutation.IPAddress(); ok { + _spec.SetField(usagelog.FieldIPAddress, field.TypeString, value) + _node.IPAddress = &value + } if value, ok := _c.mutation.ImageCount(); ok { _spec.SetField(usagelog.FieldImageCount, field.TypeInt, value) _node.ImageCount = value @@ -1288,6 +1311,24 @@ func (u *UsageLogUpsert) ClearUserAgent() *UsageLogUpsert { return u } +// SetIPAddress sets the "ip_address" field. +func (u *UsageLogUpsert) SetIPAddress(v string) *UsageLogUpsert { + u.Set(usagelog.FieldIPAddress, v) + return u +} + +// UpdateIPAddress sets the "ip_address" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateIPAddress() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldIPAddress) + return u +} + +// ClearIPAddress clears the value of the "ip_address" field. +func (u *UsageLogUpsert) ClearIPAddress() *UsageLogUpsert { + u.SetNull(usagelog.FieldIPAddress) + return u +} + // SetImageCount sets the "image_count" field. func (u *UsageLogUpsert) SetImageCount(v int) *UsageLogUpsert { u.Set(usagelog.FieldImageCount, v) @@ -1866,6 +1907,27 @@ func (u *UsageLogUpsertOne) ClearUserAgent() *UsageLogUpsertOne { }) } +// SetIPAddress sets the "ip_address" field. +func (u *UsageLogUpsertOne) SetIPAddress(v string) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetIPAddress(v) + }) +} + +// UpdateIPAddress sets the "ip_address" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateIPAddress() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateIPAddress() + }) +} + +// ClearIPAddress clears the value of the "ip_address" field. +func (u *UsageLogUpsertOne) ClearIPAddress() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.ClearIPAddress() + }) +} + // SetImageCount sets the "image_count" field. func (u *UsageLogUpsertOne) SetImageCount(v int) *UsageLogUpsertOne { return u.Update(func(s *UsageLogUpsert) { @@ -2616,6 +2678,27 @@ func (u *UsageLogUpsertBulk) ClearUserAgent() *UsageLogUpsertBulk { }) } +// SetIPAddress sets the "ip_address" field. +func (u *UsageLogUpsertBulk) SetIPAddress(v string) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetIPAddress(v) + }) +} + +// UpdateIPAddress sets the "ip_address" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateIPAddress() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateIPAddress() + }) +} + +// ClearIPAddress clears the value of the "ip_address" field. +func (u *UsageLogUpsertBulk) ClearIPAddress() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.ClearIPAddress() + }) +} + // SetImageCount sets the "image_count" field. func (u *UsageLogUpsertBulk) SetImageCount(v int) *UsageLogUpsertBulk { return u.Update(func(s *UsageLogUpsert) { diff --git a/backend/ent/usagelog_update.go b/backend/ent/usagelog_update.go index 2e77eef7..ec2acbbb 100644 --- a/backend/ent/usagelog_update.go +++ b/backend/ent/usagelog_update.go @@ -524,6 +524,26 @@ func (_u *UsageLogUpdate) ClearUserAgent() *UsageLogUpdate { return _u } +// SetIPAddress sets the "ip_address" field. +func (_u *UsageLogUpdate) SetIPAddress(v string) *UsageLogUpdate { + _u.mutation.SetIPAddress(v) + return _u +} + +// SetNillableIPAddress sets the "ip_address" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableIPAddress(v *string) *UsageLogUpdate { + if v != nil { + _u.SetIPAddress(*v) + } + return _u +} + +// ClearIPAddress clears the value of the "ip_address" field. +func (_u *UsageLogUpdate) ClearIPAddress() *UsageLogUpdate { + _u.mutation.ClearIPAddress() + return _u +} + // SetImageCount sets the "image_count" field. func (_u *UsageLogUpdate) SetImageCount(v int) *UsageLogUpdate { _u.mutation.ResetImageCount() @@ -669,6 +689,11 @@ func (_u *UsageLogUpdate) check() error { return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)} } } + if v, ok := _u.mutation.IPAddress(); ok { + if err := usagelog.IPAddressValidator(v); err != nil { + return &ValidationError{Name: "ip_address", err: fmt.Errorf(`ent: validator failed for field "UsageLog.ip_address": %w`, err)} + } + } if v, ok := _u.mutation.ImageSize(); ok { if err := usagelog.ImageSizeValidator(v); err != nil { return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} @@ -815,6 +840,12 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.UserAgentCleared() { _spec.ClearField(usagelog.FieldUserAgent, field.TypeString) } + if value, ok := _u.mutation.IPAddress(); ok { + _spec.SetField(usagelog.FieldIPAddress, field.TypeString, value) + } + if _u.mutation.IPAddressCleared() { + _spec.ClearField(usagelog.FieldIPAddress, field.TypeString) + } if value, ok := _u.mutation.ImageCount(); ok { _spec.SetField(usagelog.FieldImageCount, field.TypeInt, value) } @@ -1484,6 +1515,26 @@ func (_u *UsageLogUpdateOne) ClearUserAgent() *UsageLogUpdateOne { return _u } +// SetIPAddress sets the "ip_address" field. +func (_u *UsageLogUpdateOne) SetIPAddress(v string) *UsageLogUpdateOne { + _u.mutation.SetIPAddress(v) + return _u +} + +// SetNillableIPAddress sets the "ip_address" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableIPAddress(v *string) *UsageLogUpdateOne { + if v != nil { + _u.SetIPAddress(*v) + } + return _u +} + +// ClearIPAddress clears the value of the "ip_address" field. +func (_u *UsageLogUpdateOne) ClearIPAddress() *UsageLogUpdateOne { + _u.mutation.ClearIPAddress() + return _u +} + // SetImageCount sets the "image_count" field. func (_u *UsageLogUpdateOne) SetImageCount(v int) *UsageLogUpdateOne { _u.mutation.ResetImageCount() @@ -1642,6 +1693,11 @@ func (_u *UsageLogUpdateOne) check() error { return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)} } } + if v, ok := _u.mutation.IPAddress(); ok { + if err := usagelog.IPAddressValidator(v); err != nil { + return &ValidationError{Name: "ip_address", err: fmt.Errorf(`ent: validator failed for field "UsageLog.ip_address": %w`, err)} + } + } if v, ok := _u.mutation.ImageSize(); ok { if err := usagelog.ImageSizeValidator(v); err != nil { return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} @@ -1805,6 +1861,12 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err if _u.mutation.UserAgentCleared() { _spec.ClearField(usagelog.FieldUserAgent, field.TypeString) } + if value, ok := _u.mutation.IPAddress(); ok { + _spec.SetField(usagelog.FieldIPAddress, field.TypeString, value) + } + if _u.mutation.IPAddressCleared() { + _spec.ClearField(usagelog.FieldIPAddress, field.TypeString) + } if value, ok := _u.mutation.ImageCount(); ok { _spec.SetField(usagelog.FieldImageCount, field.TypeInt, value) } diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index c1e15290..2cc11967 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "fmt" "log" + "net/url" "os" "strings" "time" @@ -35,24 +36,25 @@ const ( ) type Config struct { - Server ServerConfig `mapstructure:"server"` - CORS CORSConfig `mapstructure:"cors"` - Security SecurityConfig `mapstructure:"security"` - Billing BillingConfig `mapstructure:"billing"` - Turnstile TurnstileConfig `mapstructure:"turnstile"` - Database DatabaseConfig `mapstructure:"database"` - Redis RedisConfig `mapstructure:"redis"` - JWT JWTConfig `mapstructure:"jwt"` - Default DefaultConfig `mapstructure:"default"` - RateLimit RateLimitConfig `mapstructure:"rate_limit"` - Pricing PricingConfig `mapstructure:"pricing"` - Gateway GatewayConfig `mapstructure:"gateway"` - Concurrency ConcurrencyConfig `mapstructure:"concurrency"` - TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` - RunMode string `mapstructure:"run_mode" yaml:"run_mode"` - Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" - Gemini GeminiConfig `mapstructure:"gemini"` - Update UpdateConfig `mapstructure:"update"` + Server ServerConfig `mapstructure:"server"` + CORS CORSConfig `mapstructure:"cors"` + Security SecurityConfig `mapstructure:"security"` + Billing BillingConfig `mapstructure:"billing"` + Turnstile TurnstileConfig `mapstructure:"turnstile"` + Database DatabaseConfig `mapstructure:"database"` + Redis RedisConfig `mapstructure:"redis"` + JWT JWTConfig `mapstructure:"jwt"` + LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` + Default DefaultConfig `mapstructure:"default"` + RateLimit RateLimitConfig `mapstructure:"rate_limit"` + Pricing PricingConfig `mapstructure:"pricing"` + Gateway GatewayConfig `mapstructure:"gateway"` + Concurrency ConcurrencyConfig `mapstructure:"concurrency"` + TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` + RunMode string `mapstructure:"run_mode" yaml:"run_mode"` + Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" + Gemini GeminiConfig `mapstructure:"gemini"` + Update UpdateConfig `mapstructure:"update"` } // UpdateConfig 在线更新相关配置 @@ -322,6 +324,30 @@ type TurnstileConfig struct { Required bool `mapstructure:"required"` } +// LinuxDoConnectConfig 用于 LinuxDo Connect OAuth 登录(终端用户 SSO)。 +// +// 注意:这与上游账号的 OAuth(例如 OpenAI/Gemini 账号接入)不是一回事。 +// 这里是用于登录 Sub2API 本身的用户体系。 +type LinuxDoConnectConfig struct { + Enabled bool `mapstructure:"enabled"` + ClientID string `mapstructure:"client_id"` + ClientSecret string `mapstructure:"client_secret"` + AuthorizeURL string `mapstructure:"authorize_url"` + TokenURL string `mapstructure:"token_url"` + UserInfoURL string `mapstructure:"userinfo_url"` + Scopes string `mapstructure:"scopes"` + RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记) + FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/linuxdo/callback) + TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none + UsePKCE bool `mapstructure:"use_pkce"` + + // 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。 + // 为空时,服务端会尝试一组常见字段名。 + UserInfoEmailPath string `mapstructure:"userinfo_email_path"` + UserInfoIDPath string `mapstructure:"userinfo_id_path"` + UserInfoUsernamePath string `mapstructure:"userinfo_username_path"` +} + type DefaultConfig struct { AdminEmail string `mapstructure:"admin_email"` AdminPassword string `mapstructure:"admin_password"` @@ -388,6 +414,18 @@ func Load() (*Config, error) { cfg.Server.Mode = "debug" } cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret) + cfg.LinuxDo.ClientID = strings.TrimSpace(cfg.LinuxDo.ClientID) + cfg.LinuxDo.ClientSecret = strings.TrimSpace(cfg.LinuxDo.ClientSecret) + cfg.LinuxDo.AuthorizeURL = strings.TrimSpace(cfg.LinuxDo.AuthorizeURL) + cfg.LinuxDo.TokenURL = strings.TrimSpace(cfg.LinuxDo.TokenURL) + cfg.LinuxDo.UserInfoURL = strings.TrimSpace(cfg.LinuxDo.UserInfoURL) + cfg.LinuxDo.Scopes = strings.TrimSpace(cfg.LinuxDo.Scopes) + cfg.LinuxDo.RedirectURL = strings.TrimSpace(cfg.LinuxDo.RedirectURL) + cfg.LinuxDo.FrontendRedirectURL = strings.TrimSpace(cfg.LinuxDo.FrontendRedirectURL) + cfg.LinuxDo.TokenAuthMethod = strings.ToLower(strings.TrimSpace(cfg.LinuxDo.TokenAuthMethod)) + cfg.LinuxDo.UserInfoEmailPath = strings.TrimSpace(cfg.LinuxDo.UserInfoEmailPath) + cfg.LinuxDo.UserInfoIDPath = strings.TrimSpace(cfg.LinuxDo.UserInfoIDPath) + cfg.LinuxDo.UserInfoUsernamePath = strings.TrimSpace(cfg.LinuxDo.UserInfoUsernamePath) cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins) cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed) cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove) @@ -426,6 +464,81 @@ func Load() (*Config, error) { return &cfg, nil } +// ValidateAbsoluteHTTPURL 校验一个绝对 http(s) URL(禁止 fragment)。 +func ValidateAbsoluteHTTPURL(raw string) error { + raw = strings.TrimSpace(raw) + if raw == "" { + return fmt.Errorf("empty url") + } + u, err := url.Parse(raw) + if err != nil { + return err + } + if !u.IsAbs() { + return fmt.Errorf("must be absolute") + } + if !isHTTPScheme(u.Scheme) { + return fmt.Errorf("unsupported scheme: %s", u.Scheme) + } + if strings.TrimSpace(u.Host) == "" { + return fmt.Errorf("missing host") + } + if u.Fragment != "" { + return fmt.Errorf("must not include fragment") + } + return nil +} + +// ValidateFrontendRedirectURL 校验前端回调地址: +// - 允许同源相对路径(以 / 开头) +// - 或绝对 http(s) URL(禁止 fragment) +func ValidateFrontendRedirectURL(raw string) error { + raw = strings.TrimSpace(raw) + if raw == "" { + return fmt.Errorf("empty url") + } + if strings.ContainsAny(raw, "\r\n") { + return fmt.Errorf("contains invalid characters") + } + if strings.HasPrefix(raw, "/") { + if strings.HasPrefix(raw, "//") { + return fmt.Errorf("must not start with //") + } + return nil + } + u, err := url.Parse(raw) + if err != nil { + return err + } + if !u.IsAbs() { + return fmt.Errorf("must be absolute http(s) url or relative path") + } + if !isHTTPScheme(u.Scheme) { + return fmt.Errorf("unsupported scheme: %s", u.Scheme) + } + if strings.TrimSpace(u.Host) == "" { + return fmt.Errorf("missing host") + } + if u.Fragment != "" { + return fmt.Errorf("must not include fragment") + } + return nil +} + +func isHTTPScheme(scheme string) bool { + return strings.EqualFold(scheme, "http") || strings.EqualFold(scheme, "https") +} + +func warnIfInsecureURL(field, raw string) { + u, err := url.Parse(strings.TrimSpace(raw)) + if err != nil { + return + } + if strings.EqualFold(u.Scheme, "http") { + log.Printf("Warning: %s uses http scheme; use https in production to avoid token leakage.", field) + } +} + func setDefaults() { viper.SetDefault("run_mode", RunModeStandard) @@ -475,6 +588,22 @@ func setDefaults() { // Turnstile viper.SetDefault("turnstile.required", false) + // LinuxDo Connect OAuth 登录(终端用户 SSO) + viper.SetDefault("linuxdo_connect.enabled", false) + viper.SetDefault("linuxdo_connect.client_id", "") + viper.SetDefault("linuxdo_connect.client_secret", "") + viper.SetDefault("linuxdo_connect.authorize_url", "https://connect.linux.do/oauth2/authorize") + viper.SetDefault("linuxdo_connect.token_url", "https://connect.linux.do/oauth2/token") + viper.SetDefault("linuxdo_connect.userinfo_url", "https://connect.linux.do/api/user") + viper.SetDefault("linuxdo_connect.scopes", "user") + viper.SetDefault("linuxdo_connect.redirect_url", "") + viper.SetDefault("linuxdo_connect.frontend_redirect_url", "/auth/linuxdo/callback") + viper.SetDefault("linuxdo_connect.token_auth_method", "client_secret_post") + viper.SetDefault("linuxdo_connect.use_pkce", false) + viper.SetDefault("linuxdo_connect.userinfo_email_path", "") + viper.SetDefault("linuxdo_connect.userinfo_id_path", "") + viper.SetDefault("linuxdo_connect.userinfo_username_path", "") + // Database viper.SetDefault("database.host", "localhost") viper.SetDefault("database.port", 5432) @@ -544,7 +673,7 @@ func setDefaults() { viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求) viper.SetDefault("gateway.stream_data_interval_timeout", 180) viper.SetDefault("gateway.stream_keepalive_interval", 10) - viper.SetDefault("gateway.max_line_size", 10*1024*1024) + viper.SetDefault("gateway.max_line_size", 40*1024*1024) viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3) viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 45*time.Second) viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second) @@ -586,6 +715,60 @@ func (c *Config) Validate() error { if c.Security.CSP.Enabled && strings.TrimSpace(c.Security.CSP.Policy) == "" { return fmt.Errorf("security.csp.policy is required when CSP is enabled") } + if c.LinuxDo.Enabled { + if strings.TrimSpace(c.LinuxDo.ClientID) == "" { + return fmt.Errorf("linuxdo_connect.client_id is required when linuxdo_connect.enabled=true") + } + if strings.TrimSpace(c.LinuxDo.AuthorizeURL) == "" { + return fmt.Errorf("linuxdo_connect.authorize_url is required when linuxdo_connect.enabled=true") + } + if strings.TrimSpace(c.LinuxDo.TokenURL) == "" { + return fmt.Errorf("linuxdo_connect.token_url is required when linuxdo_connect.enabled=true") + } + if strings.TrimSpace(c.LinuxDo.UserInfoURL) == "" { + return fmt.Errorf("linuxdo_connect.userinfo_url is required when linuxdo_connect.enabled=true") + } + if strings.TrimSpace(c.LinuxDo.RedirectURL) == "" { + return fmt.Errorf("linuxdo_connect.redirect_url is required when linuxdo_connect.enabled=true") + } + method := strings.ToLower(strings.TrimSpace(c.LinuxDo.TokenAuthMethod)) + switch method { + case "", "client_secret_post", "client_secret_basic", "none": + default: + return fmt.Errorf("linuxdo_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none") + } + if method == "none" && !c.LinuxDo.UsePKCE { + return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.token_auth_method=none") + } + if (method == "" || method == "client_secret_post" || method == "client_secret_basic") && strings.TrimSpace(c.LinuxDo.ClientSecret) == "" { + return fmt.Errorf("linuxdo_connect.client_secret is required when linuxdo_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic") + } + if strings.TrimSpace(c.LinuxDo.FrontendRedirectURL) == "" { + return fmt.Errorf("linuxdo_connect.frontend_redirect_url is required when linuxdo_connect.enabled=true") + } + + if err := ValidateAbsoluteHTTPURL(c.LinuxDo.AuthorizeURL); err != nil { + return fmt.Errorf("linuxdo_connect.authorize_url invalid: %w", err) + } + if err := ValidateAbsoluteHTTPURL(c.LinuxDo.TokenURL); err != nil { + return fmt.Errorf("linuxdo_connect.token_url invalid: %w", err) + } + if err := ValidateAbsoluteHTTPURL(c.LinuxDo.UserInfoURL); err != nil { + return fmt.Errorf("linuxdo_connect.userinfo_url invalid: %w", err) + } + if err := ValidateAbsoluteHTTPURL(c.LinuxDo.RedirectURL); err != nil { + return fmt.Errorf("linuxdo_connect.redirect_url invalid: %w", err) + } + if err := ValidateFrontendRedirectURL(c.LinuxDo.FrontendRedirectURL); err != nil { + return fmt.Errorf("linuxdo_connect.frontend_redirect_url invalid: %w", err) + } + + warnIfInsecureURL("linuxdo_connect.authorize_url", c.LinuxDo.AuthorizeURL) + warnIfInsecureURL("linuxdo_connect.token_url", c.LinuxDo.TokenURL) + warnIfInsecureURL("linuxdo_connect.userinfo_url", c.LinuxDo.UserInfoURL) + warnIfInsecureURL("linuxdo_connect.redirect_url", c.LinuxDo.RedirectURL) + warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL) + } if c.Billing.CircuitBreaker.Enabled { if c.Billing.CircuitBreaker.FailureThreshold <= 0 { return fmt.Errorf("billing.circuit_breaker.failure_threshold must be positive") diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index f28680c6..a39d41f9 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -1,6 +1,7 @@ package config import ( + "strings" "testing" "time" @@ -90,3 +91,53 @@ func TestLoadDefaultSecurityToggles(t *testing.T) { t.Fatalf("ResponseHeaders.Enabled = true, want false") } } + +func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.LinuxDo.Enabled = true + cfg.LinuxDo.ClientID = "test-client" + cfg.LinuxDo.ClientSecret = "test-secret" + cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback" + cfg.LinuxDo.TokenAuthMethod = "client_secret_post" + cfg.LinuxDo.UsePKCE = false + + cfg.LinuxDo.FrontendRedirectURL = "javascript:alert(1)" + err = cfg.Validate() + if err == nil { + t.Fatalf("Validate() expected error for javascript scheme, got nil") + } + if !strings.Contains(err.Error(), "linuxdo_connect.frontend_redirect_url") { + t.Fatalf("Validate() expected frontend_redirect_url error, got: %v", err) + } +} + +func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.LinuxDo.Enabled = true + cfg.LinuxDo.ClientID = "test-client" + cfg.LinuxDo.ClientSecret = "" + cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback" + cfg.LinuxDo.FrontendRedirectURL = "/auth/linuxdo/callback" + cfg.LinuxDo.TokenAuthMethod = "none" + cfg.LinuxDo.UsePKCE = false + + err = cfg.Validate() + if err == nil { + t.Fatalf("Validate() expected error when token_auth_method=none and use_pkce=false, got nil") + } + if !strings.Contains(err.Error(), "linuxdo_connect.use_pkce") { + t.Fatalf("Validate() expected use_pkce error, got: %v", err) + } +} diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index da9f6990..8a7270e5 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -116,6 +116,7 @@ type BulkUpdateAccountsRequest struct { Concurrency *int `json:"concurrency"` Priority *int `json:"priority"` Status string `json:"status" binding:"omitempty,oneof=active inactive error"` + Schedulable *bool `json:"schedulable"` GroupIDs *[]int64 `json:"group_ids"` Credentials map[string]any `json:"credentials"` Extra map[string]any `json:"extra"` @@ -136,6 +137,11 @@ func (h *AccountHandler) List(c *gin.Context) { accountType := c.Query("type") status := c.Query("status") search := c.Query("search") + // 标准化和验证 search 参数 + search = strings.TrimSpace(search) + if len(search) > 100 { + search = search[:100] + } accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search) if err != nil { @@ -655,6 +661,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) { req.Concurrency != nil || req.Priority != nil || req.Status != "" || + req.Schedulable != nil || req.GroupIDs != nil || len(req.Credentials) > 0 || len(req.Extra) > 0 @@ -671,6 +678,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) { Concurrency: req.Concurrency, Priority: req.Priority, Status: req.Status, + Schedulable: req.Schedulable, GroupIDs: req.GroupIDs, Credentials: req.Credentials, Extra: req.Extra, diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index acb9462c..a8bae35e 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -2,6 +2,7 @@ package admin import ( "strconv" + "strings" "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" @@ -67,6 +68,12 @@ func (h *GroupHandler) List(c *gin.Context) { page, pageSize := response.ParsePagination(c) platform := c.Query("platform") status := c.Query("status") + search := c.Query("search") + // 标准化和验证 search 参数 + search = strings.TrimSpace(search) + if len(search) > 100 { + search = search[:100] + } isExclusiveStr := c.Query("is_exclusive") var isExclusive *bool @@ -75,7 +82,7 @@ func (h *GroupHandler) List(c *gin.Context) { isExclusive = &val } - groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, isExclusive) + groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, search, isExclusive) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/admin/proxy_handler.go b/backend/internal/handler/admin/proxy_handler.go index 4fabd8ec..437e9300 100644 --- a/backend/internal/handler/admin/proxy_handler.go +++ b/backend/internal/handler/admin/proxy_handler.go @@ -51,6 +51,11 @@ func (h *ProxyHandler) List(c *gin.Context) { protocol := c.Query("protocol") status := c.Query("status") search := c.Query("search") + // 标准化和验证 search 参数 + search = strings.TrimSpace(search) + if len(search) > 100 { + search = search[:100] + } proxies, total, err := h.adminService.ListProxiesWithAccountCount(c.Request.Context(), page, pageSize, protocol, status, search) if err != nil { diff --git a/backend/internal/handler/admin/redeem_handler.go b/backend/internal/handler/admin/redeem_handler.go index 45fae43a..5b3229b6 100644 --- a/backend/internal/handler/admin/redeem_handler.go +++ b/backend/internal/handler/admin/redeem_handler.go @@ -5,6 +5,7 @@ import ( "encoding/csv" "fmt" "strconv" + "strings" "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" @@ -41,6 +42,11 @@ func (h *RedeemHandler) List(c *gin.Context) { codeType := c.Query("type") status := c.Query("status") search := c.Query("search") + // 标准化和验证 search 参数 + search = strings.TrimSpace(search) + if len(search) > 100 { + search = search[:100] + } codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search) if err != nil { diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 743c4268..d95a8980 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -2,8 +2,10 @@ package admin import ( "log" + "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/server/middleware" @@ -38,33 +40,37 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { } response.Success(c, dto.SystemSettings{ - RegistrationEnabled: settings.RegistrationEnabled, - EmailVerifyEnabled: settings.EmailVerifyEnabled, - SMTPHost: settings.SMTPHost, - SMTPPort: settings.SMTPPort, - SMTPUsername: settings.SMTPUsername, - SMTPPasswordConfigured: settings.SMTPPasswordConfigured, - SMTPFrom: settings.SMTPFrom, - SMTPFromName: settings.SMTPFromName, - SMTPUseTLS: settings.SMTPUseTLS, - TurnstileEnabled: settings.TurnstileEnabled, - TurnstileSiteKey: settings.TurnstileSiteKey, - TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured, - SiteName: settings.SiteName, - SiteLogo: settings.SiteLogo, - SiteSubtitle: settings.SiteSubtitle, - APIBaseURL: settings.APIBaseURL, - ContactInfo: settings.ContactInfo, - DocURL: settings.DocURL, - DefaultConcurrency: settings.DefaultConcurrency, - DefaultBalance: settings.DefaultBalance, - EnableModelFallback: settings.EnableModelFallback, - FallbackModelAnthropic: settings.FallbackModelAnthropic, - FallbackModelOpenAI: settings.FallbackModelOpenAI, - FallbackModelGemini: settings.FallbackModelGemini, - FallbackModelAntigravity: settings.FallbackModelAntigravity, - EnableIdentityPatch: settings.EnableIdentityPatch, - IdentityPatchPrompt: settings.IdentityPatchPrompt, + RegistrationEnabled: settings.RegistrationEnabled, + EmailVerifyEnabled: settings.EmailVerifyEnabled, + SMTPHost: settings.SMTPHost, + SMTPPort: settings.SMTPPort, + SMTPUsername: settings.SMTPUsername, + SMTPPasswordConfigured: settings.SMTPPasswordConfigured, + SMTPFrom: settings.SMTPFrom, + SMTPFromName: settings.SMTPFromName, + SMTPUseTLS: settings.SMTPUseTLS, + TurnstileEnabled: settings.TurnstileEnabled, + TurnstileSiteKey: settings.TurnstileSiteKey, + TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured, + LinuxDoConnectEnabled: settings.LinuxDoConnectEnabled, + LinuxDoConnectClientID: settings.LinuxDoConnectClientID, + LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured, + LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL, + SiteName: settings.SiteName, + SiteLogo: settings.SiteLogo, + SiteSubtitle: settings.SiteSubtitle, + APIBaseURL: settings.APIBaseURL, + ContactInfo: settings.ContactInfo, + DocURL: settings.DocURL, + DefaultConcurrency: settings.DefaultConcurrency, + DefaultBalance: settings.DefaultBalance, + EnableModelFallback: settings.EnableModelFallback, + FallbackModelAnthropic: settings.FallbackModelAnthropic, + FallbackModelOpenAI: settings.FallbackModelOpenAI, + FallbackModelGemini: settings.FallbackModelGemini, + FallbackModelAntigravity: settings.FallbackModelAntigravity, + EnableIdentityPatch: settings.EnableIdentityPatch, + IdentityPatchPrompt: settings.IdentityPatchPrompt, }) } @@ -88,6 +94,12 @@ type UpdateSettingsRequest struct { TurnstileSiteKey string `json:"turnstile_site_key"` TurnstileSecretKey string `json:"turnstile_secret_key"` + // LinuxDo Connect OAuth 登录(终端用户 SSO) + LinuxDoConnectEnabled bool `json:"linuxdo_connect_enabled"` + LinuxDoConnectClientID string `json:"linuxdo_connect_client_id"` + LinuxDoConnectClientSecret string `json:"linuxdo_connect_client_secret"` + LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"` + // OEM设置 SiteName string `json:"site_name"` SiteLogo string `json:"site_logo"` @@ -165,34 +177,67 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } } + // LinuxDo Connect 参数验证 + if req.LinuxDoConnectEnabled { + req.LinuxDoConnectClientID = strings.TrimSpace(req.LinuxDoConnectClientID) + req.LinuxDoConnectClientSecret = strings.TrimSpace(req.LinuxDoConnectClientSecret) + req.LinuxDoConnectRedirectURL = strings.TrimSpace(req.LinuxDoConnectRedirectURL) + + if req.LinuxDoConnectClientID == "" { + response.BadRequest(c, "LinuxDo Client ID is required when enabled") + return + } + if req.LinuxDoConnectRedirectURL == "" { + response.BadRequest(c, "LinuxDo Redirect URL is required when enabled") + return + } + if err := config.ValidateAbsoluteHTTPURL(req.LinuxDoConnectRedirectURL); err != nil { + response.BadRequest(c, "LinuxDo Redirect URL must be an absolute http(s) URL") + return + } + + // 如果未提供 client_secret,则保留现有值(如有)。 + if req.LinuxDoConnectClientSecret == "" { + if previousSettings.LinuxDoConnectClientSecret == "" { + response.BadRequest(c, "LinuxDo Client Secret is required when enabled") + return + } + req.LinuxDoConnectClientSecret = previousSettings.LinuxDoConnectClientSecret + } + } + settings := &service.SystemSettings{ - RegistrationEnabled: req.RegistrationEnabled, - EmailVerifyEnabled: req.EmailVerifyEnabled, - SMTPHost: req.SMTPHost, - SMTPPort: req.SMTPPort, - SMTPUsername: req.SMTPUsername, - SMTPPassword: req.SMTPPassword, - SMTPFrom: req.SMTPFrom, - SMTPFromName: req.SMTPFromName, - SMTPUseTLS: req.SMTPUseTLS, - TurnstileEnabled: req.TurnstileEnabled, - TurnstileSiteKey: req.TurnstileSiteKey, - TurnstileSecretKey: req.TurnstileSecretKey, - SiteName: req.SiteName, - SiteLogo: req.SiteLogo, - SiteSubtitle: req.SiteSubtitle, - APIBaseURL: req.APIBaseURL, - ContactInfo: req.ContactInfo, - DocURL: req.DocURL, - DefaultConcurrency: req.DefaultConcurrency, - DefaultBalance: req.DefaultBalance, - EnableModelFallback: req.EnableModelFallback, - FallbackModelAnthropic: req.FallbackModelAnthropic, - FallbackModelOpenAI: req.FallbackModelOpenAI, - FallbackModelGemini: req.FallbackModelGemini, - FallbackModelAntigravity: req.FallbackModelAntigravity, - EnableIdentityPatch: req.EnableIdentityPatch, - IdentityPatchPrompt: req.IdentityPatchPrompt, + RegistrationEnabled: req.RegistrationEnabled, + EmailVerifyEnabled: req.EmailVerifyEnabled, + SMTPHost: req.SMTPHost, + SMTPPort: req.SMTPPort, + SMTPUsername: req.SMTPUsername, + SMTPPassword: req.SMTPPassword, + SMTPFrom: req.SMTPFrom, + SMTPFromName: req.SMTPFromName, + SMTPUseTLS: req.SMTPUseTLS, + TurnstileEnabled: req.TurnstileEnabled, + TurnstileSiteKey: req.TurnstileSiteKey, + TurnstileSecretKey: req.TurnstileSecretKey, + LinuxDoConnectEnabled: req.LinuxDoConnectEnabled, + LinuxDoConnectClientID: req.LinuxDoConnectClientID, + LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret, + LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL, + SiteName: req.SiteName, + SiteLogo: req.SiteLogo, + SiteSubtitle: req.SiteSubtitle, + APIBaseURL: req.APIBaseURL, + ContactInfo: req.ContactInfo, + DocURL: req.DocURL, + DefaultConcurrency: req.DefaultConcurrency, + DefaultBalance: req.DefaultBalance, + EnableModelFallback: req.EnableModelFallback, + FallbackModelAnthropic: req.FallbackModelAnthropic, + FallbackModelOpenAI: req.FallbackModelOpenAI, + FallbackModelGemini: req.FallbackModelGemini, + FallbackModelAntigravity: req.FallbackModelAntigravity, + EnableIdentityPatch: req.EnableIdentityPatch, + IdentityPatchPrompt: req.IdentityPatchPrompt, } if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil { @@ -210,33 +255,37 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } response.Success(c, dto.SystemSettings{ - RegistrationEnabled: updatedSettings.RegistrationEnabled, - EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled, - SMTPHost: updatedSettings.SMTPHost, - SMTPPort: updatedSettings.SMTPPort, - SMTPUsername: updatedSettings.SMTPUsername, - SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured, - SMTPFrom: updatedSettings.SMTPFrom, - SMTPFromName: updatedSettings.SMTPFromName, - SMTPUseTLS: updatedSettings.SMTPUseTLS, - TurnstileEnabled: updatedSettings.TurnstileEnabled, - TurnstileSiteKey: updatedSettings.TurnstileSiteKey, - TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured, - SiteName: updatedSettings.SiteName, - SiteLogo: updatedSettings.SiteLogo, - SiteSubtitle: updatedSettings.SiteSubtitle, - APIBaseURL: updatedSettings.APIBaseURL, - ContactInfo: updatedSettings.ContactInfo, - DocURL: updatedSettings.DocURL, - DefaultConcurrency: updatedSettings.DefaultConcurrency, - DefaultBalance: updatedSettings.DefaultBalance, - EnableModelFallback: updatedSettings.EnableModelFallback, - FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic, - FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI, - FallbackModelGemini: updatedSettings.FallbackModelGemini, - FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity, - EnableIdentityPatch: updatedSettings.EnableIdentityPatch, - IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt, + RegistrationEnabled: updatedSettings.RegistrationEnabled, + EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled, + SMTPHost: updatedSettings.SMTPHost, + SMTPPort: updatedSettings.SMTPPort, + SMTPUsername: updatedSettings.SMTPUsername, + SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured, + SMTPFrom: updatedSettings.SMTPFrom, + SMTPFromName: updatedSettings.SMTPFromName, + SMTPUseTLS: updatedSettings.SMTPUseTLS, + TurnstileEnabled: updatedSettings.TurnstileEnabled, + TurnstileSiteKey: updatedSettings.TurnstileSiteKey, + TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured, + LinuxDoConnectEnabled: updatedSettings.LinuxDoConnectEnabled, + LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID, + LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured, + LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL, + SiteName: updatedSettings.SiteName, + SiteLogo: updatedSettings.SiteLogo, + SiteSubtitle: updatedSettings.SiteSubtitle, + APIBaseURL: updatedSettings.APIBaseURL, + ContactInfo: updatedSettings.ContactInfo, + DocURL: updatedSettings.DocURL, + DefaultConcurrency: updatedSettings.DefaultConcurrency, + DefaultBalance: updatedSettings.DefaultBalance, + EnableModelFallback: updatedSettings.EnableModelFallback, + FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic, + FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI, + FallbackModelGemini: updatedSettings.FallbackModelGemini, + FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity, + EnableIdentityPatch: updatedSettings.EnableIdentityPatch, + IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt, }) } @@ -298,6 +347,18 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if req.TurnstileSecretKey != "" { changed = append(changed, "turnstile_secret_key") } + if before.LinuxDoConnectEnabled != after.LinuxDoConnectEnabled { + changed = append(changed, "linuxdo_connect_enabled") + } + if before.LinuxDoConnectClientID != after.LinuxDoConnectClientID { + changed = append(changed, "linuxdo_connect_client_id") + } + if req.LinuxDoConnectClientSecret != "" { + changed = append(changed, "linuxdo_connect_client_secret") + } + if before.LinuxDoConnectRedirectURL != after.LinuxDoConnectRedirectURL { + changed = append(changed, "linuxdo_connect_redirect_url") + } if before.SiteName != after.SiteName { changed = append(changed, "site_name") } @@ -337,6 +398,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.FallbackModelAntigravity != after.FallbackModelAntigravity { changed = append(changed, "fallback_model_antigravity") } + if before.EnableIdentityPatch != after.EnableIdentityPatch { + changed = append(changed, "enable_identity_patch") + } + if before.IdentityPatchPrompt != after.IdentityPatchPrompt { + changed = append(changed, "identity_patch_prompt") + } return changed } diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index f8cd1d5a..38cc8acd 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -2,6 +2,7 @@ package admin import ( "strconv" + "strings" "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" @@ -63,10 +64,17 @@ type UpdateBalanceRequest struct { func (h *UserHandler) List(c *gin.Context) { page, pageSize := response.ParsePagination(c) + search := c.Query("search") + // 标准化和验证 search 参数 + search = strings.TrimSpace(search) + if len(search) > 100 { + search = search[:100] + } + filters := service.UserListFilters{ Status: c.Query("status"), Role: c.Query("role"), - Search: c.Query("search"), + Search: search, Attributes: parseAttributeFilters(c), } diff --git a/backend/internal/handler/api_key_handler.go b/backend/internal/handler/api_key_handler.go index 09772f22..52dc6911 100644 --- a/backend/internal/handler/api_key_handler.go +++ b/backend/internal/handler/api_key_handler.go @@ -27,16 +27,20 @@ func NewAPIKeyHandler(apiKeyService *service.APIKeyService) *APIKeyHandler { // CreateAPIKeyRequest represents the create API key request payload type CreateAPIKeyRequest struct { - Name string `json:"name" binding:"required"` - GroupID *int64 `json:"group_id"` // nullable - CustomKey *string `json:"custom_key"` // 可选的自定义key + Name string `json:"name" binding:"required"` + GroupID *int64 `json:"group_id"` // nullable + CustomKey *string `json:"custom_key"` // 可选的自定义key + IPWhitelist []string `json:"ip_whitelist"` // IP 白名单 + IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单 } // UpdateAPIKeyRequest represents the update API key request payload type UpdateAPIKeyRequest struct { - Name string `json:"name"` - GroupID *int64 `json:"group_id"` - Status string `json:"status" binding:"omitempty,oneof=active inactive"` + Name string `json:"name"` + GroupID *int64 `json:"group_id"` + Status string `json:"status" binding:"omitempty,oneof=active inactive"` + IPWhitelist []string `json:"ip_whitelist"` // IP 白名单 + IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单 } // List handles listing user's API keys with pagination @@ -110,9 +114,11 @@ func (h *APIKeyHandler) Create(c *gin.Context) { } svcReq := service.CreateAPIKeyRequest{ - Name: req.Name, - GroupID: req.GroupID, - CustomKey: req.CustomKey, + Name: req.Name, + GroupID: req.GroupID, + CustomKey: req.CustomKey, + IPWhitelist: req.IPWhitelist, + IPBlacklist: req.IPBlacklist, } key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq) if err != nil { @@ -144,7 +150,10 @@ func (h *APIKeyHandler) Update(c *gin.Context) { return } - svcReq := service.UpdateAPIKeyRequest{} + svcReq := service.UpdateAPIKeyRequest{ + IPWhitelist: req.IPWhitelist, + IPBlacklist: req.IPBlacklist, + } if req.Name != "" { svcReq.Name = &req.Name } diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index 8466f131..8463367e 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -15,14 +15,16 @@ type AuthHandler struct { cfg *config.Config authService *service.AuthService userService *service.UserService + settingSvc *service.SettingService } // NewAuthHandler creates a new AuthHandler -func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService) *AuthHandler { +func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService) *AuthHandler { return &AuthHandler{ cfg: cfg, authService: authService, userService: userService, + settingSvc: settingService, } } diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go new file mode 100644 index 00000000..a16c4cc7 --- /dev/null +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -0,0 +1,679 @@ +package handler + +import ( + "context" + "encoding/base64" + "errors" + "fmt" + "log" + "net/http" + "net/url" + "strconv" + "strings" + "time" + "unicode/utf8" + + "github.com/Wei-Shaw/sub2api/internal/config" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" + "github.com/imroc/req/v3" + "github.com/tidwall/gjson" +) + +const ( + linuxDoOAuthCookiePath = "/api/v1/auth/oauth/linuxdo" + linuxDoOAuthStateCookieName = "linuxdo_oauth_state" + linuxDoOAuthVerifierCookie = "linuxdo_oauth_verifier" + linuxDoOAuthRedirectCookie = "linuxdo_oauth_redirect" + linuxDoOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes + linuxDoOAuthDefaultRedirectTo = "/dashboard" + linuxDoOAuthDefaultFrontendCB = "/auth/linuxdo/callback" + + linuxDoOAuthMaxRedirectLen = 2048 + linuxDoOAuthMaxFragmentValueLen = 512 + linuxDoOAuthMaxSubjectLen = 64 - len("linuxdo-") +) + +type linuxDoTokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int64 `json:"expires_in"` + RefreshToken string `json:"refresh_token,omitempty"` + Scope string `json:"scope,omitempty"` +} + +type linuxDoTokenExchangeError struct { + StatusCode int + ProviderError string + ProviderDescription string + Body string +} + +func (e *linuxDoTokenExchangeError) Error() string { + if e == nil { + return "" + } + parts := []string{fmt.Sprintf("token exchange status=%d", e.StatusCode)} + if strings.TrimSpace(e.ProviderError) != "" { + parts = append(parts, "error="+strings.TrimSpace(e.ProviderError)) + } + if strings.TrimSpace(e.ProviderDescription) != "" { + parts = append(parts, "error_description="+strings.TrimSpace(e.ProviderDescription)) + } + return strings.Join(parts, " ") +} + +// LinuxDoOAuthStart 启动 LinuxDo Connect OAuth 登录流程。 +// GET /api/v1/auth/oauth/linuxdo/start?redirect=/dashboard +func (h *AuthHandler) LinuxDoOAuthStart(c *gin.Context) { + cfg, err := h.getLinuxDoOAuthConfig(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + + state, err := oauth.GenerateState() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err)) + return + } + + redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect")) + if redirectTo == "" { + redirectTo = linuxDoOAuthDefaultRedirectTo + } + + secureCookie := isRequestHTTPS(c) + setCookie(c, linuxDoOAuthStateCookieName, encodeCookieValue(state), linuxDoOAuthCookieMaxAgeSec, secureCookie) + setCookie(c, linuxDoOAuthRedirectCookie, encodeCookieValue(redirectTo), linuxDoOAuthCookieMaxAgeSec, secureCookie) + + codeChallenge := "" + if cfg.UsePKCE { + verifier, err := oauth.GenerateCodeVerifier() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err)) + return + } + codeChallenge = oauth.GenerateCodeChallenge(verifier) + setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie) + } + + redirectURI := strings.TrimSpace(cfg.RedirectURL) + if redirectURI == "" { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url not configured")) + return + } + + authURL, err := buildLinuxDoAuthorizeURL(cfg, state, codeChallenge, redirectURI) + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err)) + return + } + + c.Redirect(http.StatusFound, authURL) +} + +// LinuxDoOAuthCallback 处理 OAuth 回调:创建/登录用户,然后重定向到前端。 +// GET /api/v1/auth/oauth/linuxdo/callback?code=...&state=... +func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { + cfg, cfgErr := h.getLinuxDoOAuthConfig(c.Request.Context()) + if cfgErr != nil { + response.ErrorFrom(c, cfgErr) + return + } + + frontendCallback := strings.TrimSpace(cfg.FrontendRedirectURL) + if frontendCallback == "" { + frontendCallback = linuxDoOAuthDefaultFrontendCB + } + + if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" { + redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description")) + return + } + + code := strings.TrimSpace(c.Query("code")) + state := strings.TrimSpace(c.Query("state")) + if code == "" || state == "" { + redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "") + return + } + + secureCookie := isRequestHTTPS(c) + defer func() { + clearCookie(c, linuxDoOAuthStateCookieName, secureCookie) + clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie) + clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie) + }() + + expectedState, err := readCookieDecoded(c, linuxDoOAuthStateCookieName) + if err != nil || expectedState == "" || state != expectedState { + redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "") + return + } + + redirectTo, _ := readCookieDecoded(c, linuxDoOAuthRedirectCookie) + redirectTo = sanitizeFrontendRedirectPath(redirectTo) + if redirectTo == "" { + redirectTo = linuxDoOAuthDefaultRedirectTo + } + + codeVerifier := "" + if cfg.UsePKCE { + codeVerifier, _ = readCookieDecoded(c, linuxDoOAuthVerifierCookie) + if codeVerifier == "" { + redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "") + return + } + } + + redirectURI := strings.TrimSpace(cfg.RedirectURL) + if redirectURI == "" { + redirectOAuthError(c, frontendCallback, "config_error", "oauth redirect url not configured", "") + return + } + + tokenResp, err := linuxDoExchangeCode(c.Request.Context(), cfg, code, redirectURI, codeVerifier) + if err != nil { + description := "" + var exchangeErr *linuxDoTokenExchangeError + if errors.As(err, &exchangeErr) && exchangeErr != nil { + log.Printf( + "[LinuxDo OAuth] token exchange failed: status=%d provider_error=%q provider_description=%q body=%s", + exchangeErr.StatusCode, + exchangeErr.ProviderError, + exchangeErr.ProviderDescription, + truncateLogValue(exchangeErr.Body, 2048), + ) + description = exchangeErr.Error() + } else { + log.Printf("[LinuxDo OAuth] token exchange failed: %v", err) + description = err.Error() + } + redirectOAuthError(c, frontendCallback, "token_exchange_failed", "failed to exchange oauth code", singleLine(description)) + return + } + + email, username, subject, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp) + if err != nil { + log.Printf("[LinuxDo OAuth] userinfo fetch failed: %v", err) + redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch user info", "") + return + } + + // 安全考虑:不要把第三方返回的 email 直接映射到本地账号(可能与本地邮箱用户冲突导致账号被接管)。 + // 统一使用基于 subject 的稳定合成邮箱来做账号绑定。 + if subject != "" { + email = linuxDoSyntheticEmail(subject) + } + + jwtToken, _, err := h.authService.LoginOrRegisterOAuth(c.Request.Context(), email, username) + if err != nil { + // 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。 + redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + + fragment := url.Values{} + fragment.Set("access_token", jwtToken) + fragment.Set("token_type", "Bearer") + fragment.Set("redirect", redirectTo) + redirectWithFragment(c, frontendCallback, fragment) +} + +func (h *AuthHandler) getLinuxDoOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) { + if h != nil && h.settingSvc != nil { + return h.settingSvc.GetLinuxDoConnectOAuthConfig(ctx) + } + if h == nil || h.cfg == nil { + return config.LinuxDoConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded") + } + if !h.cfg.LinuxDo.Enabled { + return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled") + } + return h.cfg.LinuxDo, nil +} + +func linuxDoExchangeCode( + ctx context.Context, + cfg config.LinuxDoConnectConfig, + code string, + redirectURI string, + codeVerifier string, +) (*linuxDoTokenResponse, error) { + client := req.C().SetTimeout(30 * time.Second) + + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("client_id", cfg.ClientID) + form.Set("code", code) + form.Set("redirect_uri", redirectURI) + if cfg.UsePKCE { + form.Set("code_verifier", codeVerifier) + } + + r := client.R(). + SetContext(ctx). + SetHeader("Accept", "application/json") + + switch strings.ToLower(strings.TrimSpace(cfg.TokenAuthMethod)) { + case "", "client_secret_post": + form.Set("client_secret", cfg.ClientSecret) + case "client_secret_basic": + r.SetBasicAuth(cfg.ClientID, cfg.ClientSecret) + case "none": + default: + return nil, fmt.Errorf("unsupported token_auth_method: %s", cfg.TokenAuthMethod) + } + + resp, err := r.SetFormDataFromValues(form).Post(cfg.TokenURL) + if err != nil { + return nil, fmt.Errorf("request token: %w", err) + } + body := strings.TrimSpace(resp.String()) + if !resp.IsSuccessState() { + providerErr, providerDesc := parseOAuthProviderError(body) + return nil, &linuxDoTokenExchangeError{ + StatusCode: resp.StatusCode, + ProviderError: providerErr, + ProviderDescription: providerDesc, + Body: body, + } + } + + tokenResp, ok := parseLinuxDoTokenResponse(body) + if !ok || strings.TrimSpace(tokenResp.AccessToken) == "" { + return nil, &linuxDoTokenExchangeError{ + StatusCode: resp.StatusCode, + Body: body, + } + } + if strings.TrimSpace(tokenResp.TokenType) == "" { + tokenResp.TokenType = "Bearer" + } + return tokenResp, nil +} + +func linuxDoFetchUserInfo( + ctx context.Context, + cfg config.LinuxDoConnectConfig, + token *linuxDoTokenResponse, +) (email string, username string, subject string, err error) { + client := req.C().SetTimeout(30 * time.Second) + authorization, err := buildBearerAuthorization(token.TokenType, token.AccessToken) + if err != nil { + return "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err) + } + + resp, err := client.R(). + SetContext(ctx). + SetHeader("Accept", "application/json"). + SetHeader("Authorization", authorization). + Get(cfg.UserInfoURL) + if err != nil { + return "", "", "", fmt.Errorf("request userinfo: %w", err) + } + if !resp.IsSuccessState() { + return "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode) + } + + return linuxDoParseUserInfo(resp.String(), cfg) +} + +func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, err error) { + email = firstNonEmpty( + getGJSON(body, cfg.UserInfoEmailPath), + getGJSON(body, "email"), + getGJSON(body, "user.email"), + getGJSON(body, "data.email"), + getGJSON(body, "attributes.email"), + ) + username = firstNonEmpty( + getGJSON(body, cfg.UserInfoUsernamePath), + getGJSON(body, "username"), + getGJSON(body, "preferred_username"), + getGJSON(body, "name"), + getGJSON(body, "user.username"), + getGJSON(body, "user.name"), + ) + subject = firstNonEmpty( + getGJSON(body, cfg.UserInfoIDPath), + getGJSON(body, "sub"), + getGJSON(body, "id"), + getGJSON(body, "user_id"), + getGJSON(body, "uid"), + getGJSON(body, "user.id"), + ) + + subject = strings.TrimSpace(subject) + if subject == "" { + return "", "", "", errors.New("userinfo missing id field") + } + if !isSafeLinuxDoSubject(subject) { + return "", "", "", errors.New("userinfo returned invalid id field") + } + + email = strings.TrimSpace(email) + if email == "" { + // LinuxDo Connect 的 userinfo 可能不提供 email。为兼容现有用户模型(email 必填且唯一),使用稳定的合成邮箱。 + email = linuxDoSyntheticEmail(subject) + } + + username = strings.TrimSpace(username) + if username == "" { + username = "linuxdo_" + subject + } + + return email, username, subject, nil +} + +func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, codeChallenge string, redirectURI string) (string, error) { + u, err := url.Parse(cfg.AuthorizeURL) + if err != nil { + return "", fmt.Errorf("parse authorize_url: %w", err) + } + + q := u.Query() + q.Set("response_type", "code") + q.Set("client_id", cfg.ClientID) + q.Set("redirect_uri", redirectURI) + if strings.TrimSpace(cfg.Scopes) != "" { + q.Set("scope", cfg.Scopes) + } + q.Set("state", state) + if cfg.UsePKCE { + q.Set("code_challenge", codeChallenge) + q.Set("code_challenge_method", "S256") + } + + u.RawQuery = q.Encode() + return u.String(), nil +} + +func redirectOAuthError(c *gin.Context, frontendCallback string, code string, message string, description string) { + fragment := url.Values{} + fragment.Set("error", truncateFragmentValue(code)) + if strings.TrimSpace(message) != "" { + fragment.Set("error_message", truncateFragmentValue(message)) + } + if strings.TrimSpace(description) != "" { + fragment.Set("error_description", truncateFragmentValue(description)) + } + redirectWithFragment(c, frontendCallback, fragment) +} + +func redirectWithFragment(c *gin.Context, frontendCallback string, fragment url.Values) { + u, err := url.Parse(frontendCallback) + if err != nil { + // 兜底:尽力跳转到默认页面,避免卡死在回调页。 + c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo) + return + } + if u.Scheme != "" && !strings.EqualFold(u.Scheme, "http") && !strings.EqualFold(u.Scheme, "https") { + c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo) + return + } + u.Fragment = fragment.Encode() + c.Header("Cache-Control", "no-store") + c.Header("Pragma", "no-cache") + c.Redirect(http.StatusFound, u.String()) +} + +func firstNonEmpty(values ...string) string { + for _, v := range values { + v = strings.TrimSpace(v) + if v != "" { + return v + } + } + return "" +} + +func parseOAuthProviderError(body string) (providerErr string, providerDesc string) { + body = strings.TrimSpace(body) + if body == "" { + return "", "" + } + + providerErr = firstNonEmpty( + getGJSON(body, "error"), + getGJSON(body, "code"), + getGJSON(body, "error.code"), + ) + providerDesc = firstNonEmpty( + getGJSON(body, "error_description"), + getGJSON(body, "error.message"), + getGJSON(body, "message"), + getGJSON(body, "detail"), + ) + + if providerErr != "" || providerDesc != "" { + return providerErr, providerDesc + } + + values, err := url.ParseQuery(body) + if err != nil { + return "", "" + } + providerErr = firstNonEmpty(values.Get("error"), values.Get("code")) + providerDesc = firstNonEmpty(values.Get("error_description"), values.Get("error_message"), values.Get("message")) + return providerErr, providerDesc +} + +func parseLinuxDoTokenResponse(body string) (*linuxDoTokenResponse, bool) { + body = strings.TrimSpace(body) + if body == "" { + return nil, false + } + + accessToken := strings.TrimSpace(getGJSON(body, "access_token")) + if accessToken != "" { + tokenType := strings.TrimSpace(getGJSON(body, "token_type")) + refreshToken := strings.TrimSpace(getGJSON(body, "refresh_token")) + scope := strings.TrimSpace(getGJSON(body, "scope")) + expiresIn := gjson.Get(body, "expires_in").Int() + return &linuxDoTokenResponse{ + AccessToken: accessToken, + TokenType: tokenType, + ExpiresIn: expiresIn, + RefreshToken: refreshToken, + Scope: scope, + }, true + } + + values, err := url.ParseQuery(body) + if err != nil { + return nil, false + } + accessToken = strings.TrimSpace(values.Get("access_token")) + if accessToken == "" { + return nil, false + } + expiresIn := int64(0) + if raw := strings.TrimSpace(values.Get("expires_in")); raw != "" { + if v, err := strconv.ParseInt(raw, 10, 64); err == nil { + expiresIn = v + } + } + return &linuxDoTokenResponse{ + AccessToken: accessToken, + TokenType: strings.TrimSpace(values.Get("token_type")), + ExpiresIn: expiresIn, + RefreshToken: strings.TrimSpace(values.Get("refresh_token")), + Scope: strings.TrimSpace(values.Get("scope")), + }, true +} + +func getGJSON(body string, path string) string { + path = strings.TrimSpace(path) + if path == "" { + return "" + } + res := gjson.Get(body, path) + if !res.Exists() { + return "" + } + return res.String() +} + +func truncateLogValue(value string, maxLen int) string { + value = strings.TrimSpace(value) + if value == "" || maxLen <= 0 { + return "" + } + if len(value) <= maxLen { + return value + } + value = value[:maxLen] + for !utf8.ValidString(value) { + value = value[:len(value)-1] + } + return value +} + +func singleLine(value string) string { + value = strings.TrimSpace(value) + if value == "" { + return "" + } + return strings.Join(strings.Fields(value), " ") +} + +func sanitizeFrontendRedirectPath(path string) string { + path = strings.TrimSpace(path) + if path == "" { + return "" + } + if len(path) > linuxDoOAuthMaxRedirectLen { + return "" + } + // 只允许同源相对路径(避免开放重定向)。 + if !strings.HasPrefix(path, "/") { + return "" + } + if strings.HasPrefix(path, "//") { + return "" + } + if strings.Contains(path, "://") { + return "" + } + if strings.ContainsAny(path, "\r\n") { + return "" + } + return path +} + +func isRequestHTTPS(c *gin.Context) bool { + if c.Request.TLS != nil { + return true + } + proto := strings.ToLower(strings.TrimSpace(c.GetHeader("X-Forwarded-Proto"))) + return proto == "https" +} + +func encodeCookieValue(value string) string { + return base64.RawURLEncoding.EncodeToString([]byte(value)) +} + +func decodeCookieValue(value string) (string, error) { + raw, err := base64.RawURLEncoding.DecodeString(value) + if err != nil { + return "", err + } + return string(raw), nil +} + +func readCookieDecoded(c *gin.Context, name string) (string, error) { + ck, err := c.Request.Cookie(name) + if err != nil { + return "", err + } + return decodeCookieValue(ck.Value) +} + +func setCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: name, + Value: value, + Path: linuxDoOAuthCookiePath, + MaxAge: maxAgeSec, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func clearCookie(c *gin.Context, name string, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: name, + Value: "", + Path: linuxDoOAuthCookiePath, + MaxAge: -1, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func truncateFragmentValue(value string) string { + value = strings.TrimSpace(value) + if value == "" { + return "" + } + if len(value) > linuxDoOAuthMaxFragmentValueLen { + value = value[:linuxDoOAuthMaxFragmentValueLen] + for !utf8.ValidString(value) { + value = value[:len(value)-1] + } + } + return value +} + +func buildBearerAuthorization(tokenType, accessToken string) (string, error) { + tokenType = strings.TrimSpace(tokenType) + if tokenType == "" { + tokenType = "Bearer" + } + if !strings.EqualFold(tokenType, "Bearer") { + return "", fmt.Errorf("unsupported token_type: %s", tokenType) + } + + accessToken = strings.TrimSpace(accessToken) + if accessToken == "" { + return "", errors.New("missing access_token") + } + if strings.ContainsAny(accessToken, " \t\r\n") { + return "", errors.New("access_token contains whitespace") + } + return "Bearer " + accessToken, nil +} + +func isSafeLinuxDoSubject(subject string) bool { + subject = strings.TrimSpace(subject) + if subject == "" || len(subject) > linuxDoOAuthMaxSubjectLen { + return false + } + for _, r := range subject { + switch { + case r >= '0' && r <= '9': + case r >= 'a' && r <= 'z': + case r >= 'A' && r <= 'Z': + case r == '_' || r == '-': + default: + return false + } + } + return true +} + +func linuxDoSyntheticEmail(subject string) string { + subject = strings.TrimSpace(subject) + if subject == "" { + return "" + } + return "linuxdo-" + subject + service.LinuxDoConnectSyntheticEmailDomain +} diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go new file mode 100644 index 00000000..ff169c52 --- /dev/null +++ b/backend/internal/handler/auth_linuxdo_oauth_test.go @@ -0,0 +1,108 @@ +package handler + +import ( + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestSanitizeFrontendRedirectPath(t *testing.T) { + require.Equal(t, "/dashboard", sanitizeFrontendRedirectPath("/dashboard")) + require.Equal(t, "/dashboard", sanitizeFrontendRedirectPath(" /dashboard ")) + require.Equal(t, "", sanitizeFrontendRedirectPath("dashboard")) + require.Equal(t, "", sanitizeFrontendRedirectPath("//evil.com")) + require.Equal(t, "", sanitizeFrontendRedirectPath("https://evil.com")) + require.Equal(t, "", sanitizeFrontendRedirectPath("/\nfoo")) + + long := "/" + strings.Repeat("a", linuxDoOAuthMaxRedirectLen) + require.Equal(t, "", sanitizeFrontendRedirectPath(long)) +} + +func TestBuildBearerAuthorization(t *testing.T) { + auth, err := buildBearerAuthorization("", "token123") + require.NoError(t, err) + require.Equal(t, "Bearer token123", auth) + + auth, err = buildBearerAuthorization("bearer", "token123") + require.NoError(t, err) + require.Equal(t, "Bearer token123", auth) + + _, err = buildBearerAuthorization("MAC", "token123") + require.Error(t, err) + + _, err = buildBearerAuthorization("Bearer", "token 123") + require.Error(t, err) +} + +func TestLinuxDoParseUserInfoParsesIDAndUsername(t *testing.T) { + cfg := config.LinuxDoConnectConfig{ + UserInfoURL: "https://connect.linux.do/api/user", + } + + email, username, subject, err := linuxDoParseUserInfo(`{"id":123,"username":"alice"}`, cfg) + require.NoError(t, err) + require.Equal(t, "123", subject) + require.Equal(t, "alice", username) + require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email) +} + +func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) { + cfg := config.LinuxDoConnectConfig{ + UserInfoURL: "https://connect.linux.do/api/user", + } + + email, username, subject, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg) + require.NoError(t, err) + require.Equal(t, "123", subject) + require.Equal(t, "linuxdo_123", username) + require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email) +} + +func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) { + cfg := config.LinuxDoConnectConfig{ + UserInfoURL: "https://connect.linux.do/api/user", + } + + _, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg) + require.Error(t, err) + + tooLong := strings.Repeat("a", linuxDoOAuthMaxSubjectLen+1) + _, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg) + require.Error(t, err) +} + +func TestParseOAuthProviderErrorJSON(t *testing.T) { + code, desc := parseOAuthProviderError(`{"error":"invalid_client","error_description":"bad secret"}`) + require.Equal(t, "invalid_client", code) + require.Equal(t, "bad secret", desc) +} + +func TestParseOAuthProviderErrorForm(t *testing.T) { + code, desc := parseOAuthProviderError("error=invalid_request&error_description=Missing+code_verifier") + require.Equal(t, "invalid_request", code) + require.Equal(t, "Missing code_verifier", desc) +} + +func TestParseLinuxDoTokenResponseJSON(t *testing.T) { + token, ok := parseLinuxDoTokenResponse(`{"access_token":"t1","token_type":"Bearer","expires_in":3600,"scope":"user"}`) + require.True(t, ok) + require.Equal(t, "t1", token.AccessToken) + require.Equal(t, "Bearer", token.TokenType) + require.Equal(t, int64(3600), token.ExpiresIn) + require.Equal(t, "user", token.Scope) +} + +func TestParseLinuxDoTokenResponseForm(t *testing.T) { + token, ok := parseLinuxDoTokenResponse("access_token=t2&token_type=bearer&expires_in=60") + require.True(t, ok) + require.Equal(t, "t2", token.AccessToken) + require.Equal(t, "bearer", token.TokenType) + require.Equal(t, int64(60), token.ExpiresIn) +} + +func TestSingleLineStripsWhitespace(t *testing.T) { + require.Equal(t, "hello world", singleLine("hello\r\nworld")) + require.Equal(t, "", singleLine("\n\t\r")) +} diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 9a672064..85dbe6f5 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -53,16 +53,18 @@ func APIKeyFromService(k *service.APIKey) *APIKey { return nil } return &APIKey{ - ID: k.ID, - UserID: k.UserID, - Key: k.Key, - Name: k.Name, - GroupID: k.GroupID, - Status: k.Status, - CreatedAt: k.CreatedAt, - UpdatedAt: k.UpdatedAt, - User: UserFromServiceShallow(k.User), - Group: GroupFromServiceShallow(k.Group), + ID: k.ID, + UserID: k.UserID, + Key: k.Key, + Name: k.Name, + GroupID: k.GroupID, + Status: k.Status, + IPWhitelist: k.IPWhitelist, + IPBlacklist: k.IPBlacklist, + CreatedAt: k.CreatedAt, + UpdatedAt: k.UpdatedAt, + User: UserFromServiceShallow(k.User), + Group: GroupFromServiceShallow(k.Group), } } @@ -250,11 +252,12 @@ func AccountSummaryFromService(a *service.Account) *AccountSummary { // usageLogFromServiceBase is a helper that converts service UsageLog to DTO. // The account parameter allows caller to control what Account info is included. -func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary) *UsageLog { +// The includeIPAddress parameter controls whether to include the IP address (admin-only). +func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, includeIPAddress bool) *UsageLog { if l == nil { return nil } - return &UsageLog{ + result := &UsageLog{ ID: l.ID, UserID: l.UserID, APIKeyID: l.APIKeyID, @@ -290,21 +293,26 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary) *Usag Group: GroupFromServiceShallow(l.Group), Subscription: UserSubscriptionFromService(l.Subscription), } + // IP 地址仅对管理员可见 + if includeIPAddress { + result.IPAddress = l.IPAddress + } + return result } // UsageLogFromService converts a service UsageLog to DTO for regular users. -// It excludes Account details - users should not see account information. +// It excludes Account details and IP address - users should not see these. func UsageLogFromService(l *service.UsageLog) *UsageLog { - return usageLogFromServiceBase(l, nil) + return usageLogFromServiceBase(l, nil, false) } // UsageLogFromServiceAdmin converts a service UsageLog to DTO for admin users. -// It includes minimal Account info (ID, Name only). +// It includes minimal Account info (ID, Name only) and IP address. func UsageLogFromServiceAdmin(l *service.UsageLog) *UsageLog { if l == nil { return nil } - return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account)) + return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account), true) } func SettingFromService(s *service.Setting) *Setting { diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 4c50cedf..dab5eb75 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -17,6 +17,11 @@ type SystemSettings struct { TurnstileSiteKey string `json:"turnstile_site_key"` TurnstileSecretKeyConfigured bool `json:"turnstile_secret_key_configured"` + LinuxDoConnectEnabled bool `json:"linuxdo_connect_enabled"` + LinuxDoConnectClientID string `json:"linuxdo_connect_client_id"` + LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"` + LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"` + SiteName string `json:"site_name"` SiteLogo string `json:"site_logo"` SiteSubtitle string `json:"site_subtitle"` @@ -50,5 +55,6 @@ type PublicSettings struct { APIBaseURL string `json:"api_base_url"` ContactInfo string `json:"contact_info"` DocURL string `json:"doc_url"` + LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` Version string `json:"version"` } diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 03f7080b..ad583ad0 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -20,14 +20,16 @@ type User struct { } type APIKey struct { - ID int64 `json:"id"` - UserID int64 `json:"user_id"` - Key string `json:"key"` - Name string `json:"name"` - GroupID *int64 `json:"group_id"` - Status string `json:"status"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID int64 `json:"id"` + UserID int64 `json:"user_id"` + Key string `json:"key"` + Name string `json:"name"` + GroupID *int64 `json:"group_id"` + Status string `json:"status"` + IPWhitelist []string `json:"ip_whitelist"` + IPBlacklist []string `json:"ip_blacklist"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` User *User `json:"user,omitempty"` Group *Group `json:"group,omitempty"` @@ -187,6 +189,9 @@ type UsageLog struct { // User-Agent UserAgent *string `json:"user_agent"` + // IP 地址(仅管理员可见) + IPAddress *string `json:"ip_address,omitempty"` + CreatedAt time.Time `json:"created_at"` User *User `json:"user,omitempty"` diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 48a827f3..0d38db17 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -15,6 +15,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" @@ -114,6 +115,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 获取 User-Agent userAgent := c.Request.UserAgent() + // 获取客户端 IP + clientIP := ip.GetClientIP(c) + // 0. 检查wait队列是否已满 maxWait := service.CalculateMaxWait(subject.Concurrency) canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) @@ -273,7 +277,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } // 异步记录使用量(subscription已在函数开头获取) - go func(result *service.ForwardResult, usedAccount *service.Account, ua string) { + go func(result *service.ForwardResult, usedAccount *service.Account, ua string, cip string) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ @@ -283,10 +287,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { Account: usedAccount, Subscription: subscription, UserAgent: ua, + IPAddress: cip, }); err != nil { log.Printf("Record usage failed: %v", err) } - }(result, account, userAgent) + }(result, account, userAgent, clientIP) return } } @@ -401,7 +406,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } // 异步记录使用量(subscription已在函数开头获取) - go func(result *service.ForwardResult, usedAccount *service.Account, ua string) { + go func(result *service.ForwardResult, usedAccount *service.Account, ua string, cip string) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ @@ -411,10 +416,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { Account: usedAccount, Subscription: subscription, UserAgent: ua, + IPAddress: cip, }); err != nil { log.Printf("Record usage failed: %v", err) } - }(result, account, userAgent) + }(result, account, userAgent, clientIP) return } } diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 0cbe44f2..986b174b 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -12,6 +12,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/gemini" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" @@ -167,6 +168,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 获取 User-Agent userAgent := c.Request.UserAgent() + // 获取客户端 IP + clientIP := ip.GetClientIP(c) + // For Gemini native API, do not send Claude-style ping frames. geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone, 0) @@ -307,7 +311,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { } // 6) record usage async - go func(result *service.ForwardResult, usedAccount *service.Account, ua string) { + go func(result *service.ForwardResult, usedAccount *service.Account, ua string, cip string) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ @@ -317,10 +321,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { Account: usedAccount, Subscription: subscription, UserAgent: ua, + IPAddress: cip, }); err != nil { log.Printf("Record usage failed: %v", err) } - }(result, account, userAgent) + }(result, account, userAgent, clientIP) return } } diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 70131417..068e80ea 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -11,6 +11,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" @@ -94,6 +95,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { // For non-Codex CLI requests, set default instructions userAgent := c.GetHeader("User-Agent") + + // 获取客户端 IP + clientIP := ip.GetClientIP(c) + if !openai.IsCodexCLIRequest(userAgent) { reqBody["instructions"] = openai.DefaultInstructions // Re-serialize body @@ -242,7 +247,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { } // Async record usage - go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua string) { + go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua string, cip string) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ @@ -252,10 +257,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { Account: usedAccount, Subscription: subscription, UserAgent: ua, + IPAddress: cip, }); err != nil { log.Printf("Record usage failed: %v", err) } - }(result, account, userAgent) + }(result, account, userAgent, clientIP) return } } diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 3cae7a7f..e1b20c8c 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -42,6 +42,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { APIBaseURL: settings.APIBaseURL, ContactInfo: settings.ContactInfo, DocURL: settings.DocURL, + LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, Version: h.version, }) } diff --git a/backend/internal/pkg/ip/ip.go b/backend/internal/pkg/ip/ip.go new file mode 100644 index 00000000..97109c0c --- /dev/null +++ b/backend/internal/pkg/ip/ip.go @@ -0,0 +1,168 @@ +// Package ip 提供客户端 IP 地址提取工具。 +package ip + +import ( + "net" + "strings" + + "github.com/gin-gonic/gin" +) + +// GetClientIP 从 Gin Context 中提取客户端真实 IP 地址。 +// 按以下优先级检查 Header: +// 1. CF-Connecting-IP (Cloudflare) +// 2. X-Real-IP (Nginx) +// 3. X-Forwarded-For (取第一个非私有 IP) +// 4. c.ClientIP() (Gin 内置方法) +func GetClientIP(c *gin.Context) string { + // 1. Cloudflare + if ip := c.GetHeader("CF-Connecting-IP"); ip != "" { + return normalizeIP(ip) + } + + // 2. Nginx X-Real-IP + if ip := c.GetHeader("X-Real-IP"); ip != "" { + return normalizeIP(ip) + } + + // 3. X-Forwarded-For (多个 IP 时取第一个公网 IP) + if xff := c.GetHeader("X-Forwarded-For"); xff != "" { + ips := strings.Split(xff, ",") + for _, ip := range ips { + ip = strings.TrimSpace(ip) + if ip != "" && !isPrivateIP(ip) { + return normalizeIP(ip) + } + } + // 如果都是私有 IP,返回第一个 + if len(ips) > 0 { + return normalizeIP(strings.TrimSpace(ips[0])) + } + } + + // 4. Gin 内置方法 + return normalizeIP(c.ClientIP()) +} + +// normalizeIP 规范化 IP 地址,去除端口号和空格。 +func normalizeIP(ip string) string { + ip = strings.TrimSpace(ip) + // 移除端口号(如 "192.168.1.1:8080" -> "192.168.1.1") + if host, _, err := net.SplitHostPort(ip); err == nil { + return host + } + return ip +} + +// isPrivateIP 检查 IP 是否为私有地址。 +func isPrivateIP(ipStr string) bool { + ip := net.ParseIP(ipStr) + if ip == nil { + return false + } + + // 私有 IP 范围 + privateBlocks := []string{ + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + "127.0.0.0/8", + "::1/128", + "fc00::/7", + } + + for _, block := range privateBlocks { + _, cidr, err := net.ParseCIDR(block) + if err != nil { + continue + } + if cidr.Contains(ip) { + return true + } + } + return false +} + +// MatchesPattern 检查 IP 是否匹配指定的模式(支持单个 IP 或 CIDR)。 +// pattern 可以是: +// - 单个 IP: "192.168.1.100" +// - CIDR 范围: "192.168.1.0/24" +func MatchesPattern(clientIP, pattern string) bool { + ip := net.ParseIP(clientIP) + if ip == nil { + return false + } + + // 尝试解析为 CIDR + if strings.Contains(pattern, "/") { + _, cidr, err := net.ParseCIDR(pattern) + if err != nil { + return false + } + return cidr.Contains(ip) + } + + // 作为单个 IP 处理 + patternIP := net.ParseIP(pattern) + if patternIP == nil { + return false + } + return ip.Equal(patternIP) +} + +// MatchesAnyPattern 检查 IP 是否匹配任意一个模式。 +func MatchesAnyPattern(clientIP string, patterns []string) bool { + for _, pattern := range patterns { + if MatchesPattern(clientIP, pattern) { + return true + } + } + return false +} + +// CheckIPRestriction 检查 IP 是否被 API Key 的 IP 限制允许。 +// 返回值:(是否允许, 拒绝原因) +// 逻辑: +// 1. 先检查黑名单,如果在黑名单中则直接拒绝 +// 2. 如果白名单不为空,IP 必须在白名单中 +// 3. 如果白名单为空,允许访问(除非被黑名单拒绝) +func CheckIPRestriction(clientIP string, whitelist, blacklist []string) (bool, string) { + // 规范化 IP + clientIP = normalizeIP(clientIP) + if clientIP == "" { + return false, "access denied" + } + + // 1. 检查黑名单 + if len(blacklist) > 0 && MatchesAnyPattern(clientIP, blacklist) { + return false, "access denied" + } + + // 2. 检查白名单(如果设置了白名单,IP 必须在其中) + if len(whitelist) > 0 && !MatchesAnyPattern(clientIP, whitelist) { + return false, "access denied" + } + + return true, "" +} + +// ValidateIPPattern 验证 IP 或 CIDR 格式是否有效。 +func ValidateIPPattern(pattern string) bool { + if strings.Contains(pattern, "/") { + _, _, err := net.ParseCIDR(pattern) + return err == nil + } + return net.ParseIP(pattern) != nil +} + +// ValidateIPPatterns 验证多个 IP 或 CIDR 格式。 +// 返回无效的模式列表。 +func ValidateIPPatterns(patterns []string) []string { + var invalid []string + for _, p := range patterns { + if !ValidateIPPattern(p) { + invalid = append(invalid, p) + } + } + return invalid +} diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 83f02608..04ca7052 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -675,6 +675,40 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA return err } +func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error { + now := time.Now().UTC() + payload := map[string]string{ + "rate_limited_at": now.Format(time.RFC3339), + "rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339), + } + raw, err := json.Marshal(payload) + if err != nil { + return err + } + + path := "{antigravity_quota_scopes," + string(scope) + "}" + client := clientFromContext(ctx, r.client) + result, err := client.ExecContext( + ctx, + "UPDATE accounts SET extra = jsonb_set(COALESCE(extra, '{}'::jsonb), $1::text[], $2::jsonb, true), updated_at = NOW() WHERE id = $3 AND deleted_at IS NULL", + path, + raw, + id, + ) + if err != nil { + return err + } + + affected, err := result.RowsAffected() + if err != nil { + return err + } + if affected == 0 { + return service.ErrAccountNotFound + } + return nil +} + func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error { _, err := r.client.Account.Update(). Where(dbaccount.IDEQ(id)). @@ -718,6 +752,27 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error return err } +func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error { + client := clientFromContext(ctx, r.client) + result, err := client.ExecContext( + ctx, + "UPDATE accounts SET extra = COALESCE(extra, '{}'::jsonb) - 'antigravity_quota_scopes', updated_at = NOW() WHERE id = $1 AND deleted_at IS NULL", + id, + ) + if err != nil { + return err + } + + affected, err := result.RowsAffected() + if err != nil { + return err + } + if affected == 0 { + return service.ErrAccountNotFound + } + return nil +} + func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error { builder := r.client.Account.Update(). Where(dbaccount.IDEQ(id)). @@ -831,6 +886,11 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates args = append(args, *updates.Status) idx++ } + if updates.Schedulable != nil { + setClauses = append(setClauses, "schedulable = $"+itoa(idx)) + args = append(args, *updates.Schedulable) + idx++ + } // JSONB 需要合并而非覆盖,使用 raw SQL 保持旧行为。 if len(updates.Credentials) > 0 { payload, err := json.Marshal(updates.Credentials) diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index f3b07616..6da551da 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -26,13 +26,21 @@ func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery { } func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) error { - created, err := r.client.APIKey.Create(). + builder := r.client.APIKey.Create(). SetUserID(key.UserID). SetKey(key.Key). SetName(key.Name). SetStatus(key.Status). - SetNillableGroupID(key.GroupID). - Save(ctx) + SetNillableGroupID(key.GroupID) + + if len(key.IPWhitelist) > 0 { + builder.SetIPWhitelist(key.IPWhitelist) + } + if len(key.IPBlacklist) > 0 { + builder.SetIPBlacklist(key.IPBlacklist) + } + + created, err := builder.Save(ctx) if err == nil { key.ID = created.ID key.CreatedAt = created.CreatedAt @@ -108,6 +116,18 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro builder.ClearGroupID() } + // IP 限制字段 + if len(key.IPWhitelist) > 0 { + builder.SetIPWhitelist(key.IPWhitelist) + } else { + builder.ClearIPWhitelist() + } + if len(key.IPBlacklist) > 0 { + builder.SetIPBlacklist(key.IPBlacklist) + } else { + builder.ClearIPBlacklist() + } + affected, err := builder.Save(ctx) if err != nil { return err @@ -268,14 +288,16 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey { return nil } out := &service.APIKey{ - ID: m.ID, - UserID: m.UserID, - Key: m.Key, - Name: m.Name, - Status: m.Status, - CreatedAt: m.CreatedAt, - UpdatedAt: m.UpdatedAt, - GroupID: m.GroupID, + ID: m.ID, + UserID: m.UserID, + Key: m.Key, + Name: m.Name, + Status: m.Status, + IPWhitelist: m.IPWhitelist, + IPBlacklist: m.IPBlacklist, + CreatedAt: m.CreatedAt, + UpdatedAt: m.UpdatedAt, + GroupID: m.GroupID, } if m.Edges.User != nil { out.User = userEntityToService(m.Edges.User) diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 1fb4ae90..a54f3116 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -112,10 +112,10 @@ func (r *groupRepository) Delete(ctx context.Context, id int64) error { } func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) { - return r.ListWithFilters(ctx, params, "", "", nil) + return r.ListWithFilters(ctx, params, "", "", "", nil) } -func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) { +func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) { q := r.client.Group.Query() if platform != "" { @@ -124,6 +124,12 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination if status != "" { q = q.Where(group.StatusEQ(status)) } + if search != "" { + q = q.Where(group.Or( + group.NameContainsFold(search), + group.DescriptionContainsFold(search), + )) + } if isExclusive != nil { q = q.Where(group.IsExclusiveEQ(*isExclusive)) } diff --git a/backend/internal/repository/group_repo_integration_test.go b/backend/internal/repository/group_repo_integration_test.go index b9079d7a..660618a6 100644 --- a/backend/internal/repository/group_repo_integration_test.go +++ b/backend/internal/repository/group_repo_integration_test.go @@ -131,6 +131,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() { pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", + "", nil, ) s.Require().NoError(err, "ListWithFilters base") @@ -152,7 +153,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() { SubscriptionType: service.SubscriptionTypeStandard, })) - groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", nil) + groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", "", nil) s.Require().NoError(err) s.Require().Len(groups, len(baseGroups)+1) // Verify all groups are OpenAI platform @@ -179,7 +180,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Status() { SubscriptionType: service.SubscriptionTypeStandard, })) - groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, nil) + groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, "", nil) s.Require().NoError(err) s.Require().Len(groups, 1) s.Require().Equal(service.StatusDisabled, groups[0].Status) @@ -204,12 +205,117 @@ func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() { })) isExclusive := true - groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", &isExclusive) + groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", &isExclusive) s.Require().NoError(err) s.Require().Len(groups, 1) s.Require().True(groups[0].IsExclusive) } +func (s *GroupRepoSuite) TestListWithFilters_Search() { + newRepo := func() (*groupRepository, context.Context) { + tx := testEntTx(s.T()) + return newGroupRepositoryWithSQL(tx.Client(), tx), context.Background() + } + + containsID := func(groups []service.Group, id int64) bool { + for i := range groups { + if groups[i].ID == id { + return true + } + } + return false + } + + mustCreate := func(repo *groupRepository, ctx context.Context, g *service.Group) *service.Group { + s.Require().NoError(repo.Create(ctx, g)) + s.Require().NotZero(g.ID) + return g + } + + newGroup := func(name string) *service.Group { + return &service.Group{ + Name: name, + Platform: service.PlatformAnthropic, + RateMultiplier: 1.0, + IsExclusive: false, + Status: service.StatusActive, + SubscriptionType: service.SubscriptionTypeStandard, + } + } + + s.Run("search_name_should_match", func() { + repo, ctx := newRepo() + + target := mustCreate(repo, ctx, newGroup("it-group-search-name-target")) + other := mustCreate(repo, ctx, newGroup("it-group-search-name-other")) + + groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "name-target", nil) + s.Require().NoError(err) + s.Require().True(containsID(groups, target.ID), "expected target group to match by name") + s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out") + }) + + s.Run("search_description_should_match", func() { + repo, ctx := newRepo() + + target := newGroup("it-group-search-desc-target") + target.Description = "something about desc-needle in here" + target = mustCreate(repo, ctx, target) + + other := newGroup("it-group-search-desc-other") + other.Description = "nothing to see here" + other = mustCreate(repo, ctx, other) + + groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "desc-needle", nil) + s.Require().NoError(err) + s.Require().True(containsID(groups, target.ID), "expected target group to match by description") + s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out") + }) + + s.Run("search_nonexistent_should_return_empty", func() { + repo, ctx := newRepo() + + _ = mustCreate(repo, ctx, newGroup("it-group-search-nonexistent-baseline")) + + search := s.T().Name() + "__no_such_group__" + groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", search, nil) + s.Require().NoError(err) + s.Require().Empty(groups) + }) + + s.Run("search_should_be_case_insensitive", func() { + repo, ctx := newRepo() + + target := mustCreate(repo, ctx, newGroup("MiXeDCaSe-Needle")) + other := mustCreate(repo, ctx, newGroup("it-group-search-case-other")) + + groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "mixedcase-needle", nil) + s.Require().NoError(err) + s.Require().True(containsID(groups, target.ID), "expected case-insensitive match") + s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out") + }) + + s.Run("search_should_escape_like_wildcards", func() { + repo, ctx := newRepo() + + percentTarget := mustCreate(repo, ctx, newGroup("it-group-search-100%-target")) + percentOther := mustCreate(repo, ctx, newGroup("it-group-search-100X-other")) + + groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "100%", nil) + s.Require().NoError(err) + s.Require().True(containsID(groups, percentTarget.ID), "expected literal %% match") + s.Require().False(containsID(groups, percentOther.ID), "expected %% not to act as wildcard") + + underscoreTarget := mustCreate(repo, ctx, newGroup("it-group-search-ab_cd-target")) + underscoreOther := mustCreate(repo, ctx, newGroup("it-group-search-abXcd-other")) + + groups, _, err = repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "ab_cd", nil) + s.Require().NoError(err) + s.Require().True(containsID(groups, underscoreTarget.ID), "expected literal _ match") + s.Require().False(containsID(groups, underscoreOther.ID), "expected _ not to act as wildcard") + }) +} + func (s *GroupRepoSuite) TestListWithFilters_AccountCount() { g1 := &service.Group{ Name: "g1", @@ -244,7 +350,7 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() { s.Require().NoError(err) isExclusive := true - groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, &isExclusive) + groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, "", &isExclusive) s.Require().NoError(err, "ListWithFilters") s.Require().Equal(int64(1), page.Total) s.Require().Len(groups, 1) diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index bd5c8b4f..6ed8910e 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -22,7 +22,7 @@ import ( "github.com/lib/pq" ) -const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, image_count, image_size, created_at" +const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at" type usageLogRepository struct { client *dbent.Client @@ -110,6 +110,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) duration_ms, first_token_ms, user_agent, + ip_address, image_count, image_size, created_at @@ -119,7 +120,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, - $20, $21, $22, $23, $24, $25, $26, $27, $28 + $20, $21, $22, $23, $24, $25, $26, $27, $28, $29 ) ON CONFLICT (request_id, api_key_id) DO NOTHING RETURNING id, created_at @@ -130,6 +131,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) duration := nullInt(log.DurationMs) firstToken := nullInt(log.FirstTokenMs) userAgent := nullString(log.UserAgent) + ipAddress := nullString(log.IPAddress) imageSize := nullString(log.ImageSize) var requestIDArg any @@ -163,6 +165,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) duration, firstToken, userAgent, + ipAddress, log.ImageCount, imageSize, createdAt, @@ -1873,6 +1876,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e durationMs sql.NullInt64 firstTokenMs sql.NullInt64 userAgent sql.NullString + ipAddress sql.NullString imageCount int imageSize sql.NullString createdAt time.Time @@ -1905,6 +1909,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e &durationMs, &firstTokenMs, &userAgent, + &ipAddress, &imageCount, &imageSize, &createdAt, @@ -1959,6 +1964,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e if userAgent.Valid { log.UserAgent = &userAgent.String } + if ipAddress.Valid { + log.IPAddress = &ipAddress.String + } if imageSize.Valid { log.ImageSize = &imageSize.String } diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 502d74b3..6e52c5bc 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -82,6 +82,8 @@ func TestAPIContracts(t *testing.T) { "name": "Key One", "group_id": null, "status": "active", + "ip_whitelist": null, + "ip_blacklist": null, "created_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z" } @@ -116,6 +118,8 @@ func TestAPIContracts(t *testing.T) { "name": "Key One", "group_id": null, "status": "active", + "ip_whitelist": null, + "ip_blacklist": null, "created_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z" } @@ -304,6 +308,10 @@ func TestAPIContracts(t *testing.T) { "turnstile_enabled": true, "turnstile_site_key": "site-key", "turnstile_secret_key_configured": true, + "linuxdo_connect_enabled": false, + "linuxdo_connect_client_id": "", + "linuxdo_connect_client_secret_configured": false, + "linuxdo_connect_redirect_url": "", "site_name": "Sub2API", "site_logo": "", "site_subtitle": "Subtitle", @@ -390,7 +398,7 @@ func newContractDeps(t *testing.T) *contractDeps { settingRepo := newStubSettingRepo() settingService := service.NewSettingService(settingRepo, cfg) - authHandler := handler.NewAuthHandler(cfg, nil, userService) + authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil) @@ -583,7 +591,7 @@ func (stubGroupRepo) List(ctx context.Context, params pagination.PaginationParam return nil, nil, errors.New("not implemented") } -func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) { +func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } diff --git a/backend/internal/server/middleware/api_key_auth.go b/backend/internal/server/middleware/api_key_auth.go index 74ff8af3..d93724f2 100644 --- a/backend/internal/server/middleware/api_key_auth.go +++ b/backend/internal/server/middleware/api_key_auth.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -71,6 +72,17 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti return } + // 检查 IP 限制(白名单/黑名单) + // 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制 + if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 { + clientIP := ip.GetClientIP(c) + allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist) + if !allowed { + AbortWithError(c, 403, "ACCESS_DENIED", "Access denied") + return + } + } + // 检查关联的用户 if apiKey.User == nil { AbortWithError(c, 401, "USER_NOT_FOUND", "User associated with API key not found") diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index 196d8bdb..e61d3939 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -19,6 +19,8 @@ func RegisterAuthRoutes( auth.POST("/register", h.Auth.Register) auth.POST("/login", h.Auth.Login) auth.POST("/send-verify-code", h.Auth.SendVerifyCode) + auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart) + auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback) } // 公开设置(无需认证) diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index e1b93fcb..2f138b81 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -49,10 +49,12 @@ type AccountRepository interface { ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error + SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error SetOverloaded(ctx context.Context, id int64, until time.Time) error SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error ClearTempUnschedulable(ctx context.Context, id int64) error ClearRateLimit(ctx context.Context, id int64) error + ClearAntigravityQuotaScopes(ctx context.Context, id int64) error UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error UpdateExtra(ctx context.Context, id int64, updates map[string]any) error BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error) @@ -66,6 +68,7 @@ type AccountBulkUpdate struct { Concurrency *int Priority *int Status *string + Schedulable *bool Credentials map[string]any Extra map[string]any } diff --git a/backend/internal/service/account_service_delete_test.go b/backend/internal/service/account_service_delete_test.go index edad8672..6923067d 100644 --- a/backend/internal/service/account_service_delete_test.go +++ b/backend/internal/service/account_service_delete_test.go @@ -139,6 +139,10 @@ func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt panic("unexpected SetRateLimited call") } +func (s *accountRepoStub) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error { + panic("unexpected SetAntigravityQuotaScopeLimit call") +} + func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until time.Time) error { panic("unexpected SetOverloaded call") } @@ -155,6 +159,10 @@ func (s *accountRepoStub) ClearRateLimit(ctx context.Context, id int64) error { panic("unexpected ClearRateLimit call") } +func (s *accountRepoStub) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error { + panic("unexpected ClearAntigravityQuotaScopes call") +} + func (s *accountRepoStub) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error { panic("unexpected UpdateSessionWindow call") } diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index e29bbdb4..4288381c 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -24,7 +24,7 @@ type AdminService interface { GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) // Group management - ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error) + ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) GetAllGroups(ctx context.Context) ([]Group, error) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) GetGroup(ctx context.Context, id int64) (*Group, error) @@ -168,6 +168,7 @@ type BulkUpdateAccountsInput struct { Concurrency *int Priority *int Status string + Schedulable *bool GroupIDs *[]int64 Credentials map[string]any Extra map[string]any @@ -478,9 +479,9 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, } // Group management implementations -func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error) { +func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} - groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive) + groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, search, isExclusive) if err != nil { return nil, 0, err } @@ -910,6 +911,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp if input.Status != "" { repoUpdates.Status = &input.Status } + if input.Schedulable != nil { + repoUpdates.Schedulable = input.Schedulable + } // Run bulk update for column/jsonb fields first. if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil { diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go index c1d2e4c9..351f64e8 100644 --- a/backend/internal/service/admin_service_delete_test.go +++ b/backend/internal/service/admin_service_delete_test.go @@ -124,7 +124,7 @@ func (s *groupRepoStub) List(ctx context.Context, params pagination.PaginationPa panic("unexpected List call") } -func (s *groupRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) { +func (s *groupRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) { panic("unexpected ListWithFilters call") } diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go index 3171de11..26d6eedf 100644 --- a/backend/internal/service/admin_service_group_test.go +++ b/backend/internal/service/admin_service_group_test.go @@ -16,6 +16,16 @@ type groupRepoStubForAdmin struct { updated *Group // 记录 Update 调用的参数 getByID *Group // GetByID 返回值 getErr error // GetByID 返回的错误 + + listWithFiltersCalls int + listWithFiltersParams pagination.PaginationParams + listWithFiltersPlatform string + listWithFiltersStatus string + listWithFiltersSearch string + listWithFiltersIsExclusive *bool + listWithFiltersGroups []Group + listWithFiltersResult *pagination.PaginationResult + listWithFiltersErr error } func (s *groupRepoStubForAdmin) Create(_ context.Context, g *Group) error { @@ -47,8 +57,28 @@ func (s *groupRepoStubForAdmin) List(_ context.Context, _ pagination.PaginationP panic("unexpected List call") } -func (s *groupRepoStubForAdmin) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) { - panic("unexpected ListWithFilters call") +func (s *groupRepoStubForAdmin) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) { + s.listWithFiltersCalls++ + s.listWithFiltersParams = params + s.listWithFiltersPlatform = platform + s.listWithFiltersStatus = status + s.listWithFiltersSearch = search + s.listWithFiltersIsExclusive = isExclusive + + if s.listWithFiltersErr != nil { + return nil, nil, s.listWithFiltersErr + } + + result := s.listWithFiltersResult + if result == nil { + result = &pagination.PaginationResult{ + Total: int64(len(s.listWithFiltersGroups)), + Page: params.Page, + PageSize: params.PageSize, + } + } + + return s.listWithFiltersGroups, result, nil } func (s *groupRepoStubForAdmin) ListActive(_ context.Context) ([]Group, error) { @@ -195,3 +225,68 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) { require.InDelta(t, 0.15, *repo.updated.ImagePrice2K, 0.0001) // 原值保持 require.Nil(t, repo.updated.ImagePrice4K) } + +func TestAdminService_ListGroups_WithSearch(t *testing.T) { + // 测试: + // 1. search 参数正常传递到 repository 层 + // 2. search 为空字符串时的行为 + // 3. search 与其他过滤条件组合使用 + + t.Run("search 参数正常传递到 repository 层", func(t *testing.T) { + repo := &groupRepoStubForAdmin{ + listWithFiltersGroups: []Group{{ID: 1, Name: "alpha"}}, + listWithFiltersResult: &pagination.PaginationResult{Total: 1}, + } + svc := &adminServiceImpl{groupRepo: repo} + + groups, total, err := svc.ListGroups(context.Background(), 1, 20, "", "", "alpha", nil) + require.NoError(t, err) + require.Equal(t, int64(1), total) + require.Equal(t, []Group{{ID: 1, Name: "alpha"}}, groups) + + require.Equal(t, 1, repo.listWithFiltersCalls) + require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams) + require.Equal(t, "alpha", repo.listWithFiltersSearch) + require.Nil(t, repo.listWithFiltersIsExclusive) + }) + + t.Run("search 为空字符串时传递空字符串", func(t *testing.T) { + repo := &groupRepoStubForAdmin{ + listWithFiltersGroups: []Group{}, + listWithFiltersResult: &pagination.PaginationResult{Total: 0}, + } + svc := &adminServiceImpl{groupRepo: repo} + + groups, total, err := svc.ListGroups(context.Background(), 2, 10, "", "", "", nil) + require.NoError(t, err) + require.Empty(t, groups) + require.Equal(t, int64(0), total) + + require.Equal(t, 1, repo.listWithFiltersCalls) + require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersParams) + require.Equal(t, "", repo.listWithFiltersSearch) + require.Nil(t, repo.listWithFiltersIsExclusive) + }) + + t.Run("search 与其他过滤条件组合使用", func(t *testing.T) { + isExclusive := true + repo := &groupRepoStubForAdmin{ + listWithFiltersGroups: []Group{{ID: 2, Name: "beta"}}, + listWithFiltersResult: &pagination.PaginationResult{Total: 42}, + } + svc := &adminServiceImpl{groupRepo: repo} + + groups, total, err := svc.ListGroups(context.Background(), 3, 50, PlatformAntigravity, StatusActive, "beta", &isExclusive) + require.NoError(t, err) + require.Equal(t, int64(42), total) + require.Equal(t, []Group{{ID: 2, Name: "beta"}}, groups) + + require.Equal(t, 1, repo.listWithFiltersCalls) + require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams) + require.Equal(t, PlatformAntigravity, repo.listWithFiltersPlatform) + require.Equal(t, StatusActive, repo.listWithFiltersStatus) + require.Equal(t, "beta", repo.listWithFiltersSearch) + require.NotNil(t, repo.listWithFiltersIsExclusive) + require.True(t, *repo.listWithFiltersIsExclusive) + }) +} diff --git a/backend/internal/service/admin_service_search_test.go b/backend/internal/service/admin_service_search_test.go new file mode 100644 index 00000000..7506c6db --- /dev/null +++ b/backend/internal/service/admin_service_search_test.go @@ -0,0 +1,238 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +type accountRepoStubForAdminList struct { + accountRepoStub + + listWithFiltersCalls int + listWithFiltersParams pagination.PaginationParams + listWithFiltersPlatform string + listWithFiltersType string + listWithFiltersStatus string + listWithFiltersSearch string + listWithFiltersAccounts []Account + listWithFiltersResult *pagination.PaginationResult + listWithFiltersErr error +} + +func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) { + s.listWithFiltersCalls++ + s.listWithFiltersParams = params + s.listWithFiltersPlatform = platform + s.listWithFiltersType = accountType + s.listWithFiltersStatus = status + s.listWithFiltersSearch = search + + if s.listWithFiltersErr != nil { + return nil, nil, s.listWithFiltersErr + } + + result := s.listWithFiltersResult + if result == nil { + result = &pagination.PaginationResult{ + Total: int64(len(s.listWithFiltersAccounts)), + Page: params.Page, + PageSize: params.PageSize, + } + } + + return s.listWithFiltersAccounts, result, nil +} + +type proxyRepoStubForAdminList struct { + proxyRepoStub + + listWithFiltersCalls int + listWithFiltersParams pagination.PaginationParams + listWithFiltersProtocol string + listWithFiltersStatus string + listWithFiltersSearch string + listWithFiltersProxies []Proxy + listWithFiltersResult *pagination.PaginationResult + listWithFiltersErr error + + listWithFiltersAndAccountCountCalls int + listWithFiltersAndAccountCountParams pagination.PaginationParams + listWithFiltersAndAccountCountProtocol string + listWithFiltersAndAccountCountStatus string + listWithFiltersAndAccountCountSearch string + listWithFiltersAndAccountCountProxies []ProxyWithAccountCount + listWithFiltersAndAccountCountResult *pagination.PaginationResult + listWithFiltersAndAccountCountErr error +} + +func (s *proxyRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error) { + s.listWithFiltersCalls++ + s.listWithFiltersParams = params + s.listWithFiltersProtocol = protocol + s.listWithFiltersStatus = status + s.listWithFiltersSearch = search + + if s.listWithFiltersErr != nil { + return nil, nil, s.listWithFiltersErr + } + + result := s.listWithFiltersResult + if result == nil { + result = &pagination.PaginationResult{ + Total: int64(len(s.listWithFiltersProxies)), + Page: params.Page, + PageSize: params.PageSize, + } + } + + return s.listWithFiltersProxies, result, nil +} + +func (s *proxyRepoStubForAdminList) ListWithFiltersAndAccountCount(_ context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) { + s.listWithFiltersAndAccountCountCalls++ + s.listWithFiltersAndAccountCountParams = params + s.listWithFiltersAndAccountCountProtocol = protocol + s.listWithFiltersAndAccountCountStatus = status + s.listWithFiltersAndAccountCountSearch = search + + if s.listWithFiltersAndAccountCountErr != nil { + return nil, nil, s.listWithFiltersAndAccountCountErr + } + + result := s.listWithFiltersAndAccountCountResult + if result == nil { + result = &pagination.PaginationResult{ + Total: int64(len(s.listWithFiltersAndAccountCountProxies)), + Page: params.Page, + PageSize: params.PageSize, + } + } + + return s.listWithFiltersAndAccountCountProxies, result, nil +} + +type redeemRepoStubForAdminList struct { + redeemRepoStub + + listWithFiltersCalls int + listWithFiltersParams pagination.PaginationParams + listWithFiltersType string + listWithFiltersStatus string + listWithFiltersSearch string + listWithFiltersCodes []RedeemCode + listWithFiltersResult *pagination.PaginationResult + listWithFiltersErr error +} + +func (s *redeemRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, codeType, status, search string) ([]RedeemCode, *pagination.PaginationResult, error) { + s.listWithFiltersCalls++ + s.listWithFiltersParams = params + s.listWithFiltersType = codeType + s.listWithFiltersStatus = status + s.listWithFiltersSearch = search + + if s.listWithFiltersErr != nil { + return nil, nil, s.listWithFiltersErr + } + + result := s.listWithFiltersResult + if result == nil { + result = &pagination.PaginationResult{ + Total: int64(len(s.listWithFiltersCodes)), + Page: params.Page, + PageSize: params.PageSize, + } + } + + return s.listWithFiltersCodes, result, nil +} + +func TestAdminService_ListAccounts_WithSearch(t *testing.T) { + t.Run("search 参数正常传递到 repository 层", func(t *testing.T) { + repo := &accountRepoStubForAdminList{ + listWithFiltersAccounts: []Account{{ID: 1, Name: "acc"}}, + listWithFiltersResult: &pagination.PaginationResult{Total: 10}, + } + svc := &adminServiceImpl{accountRepo: repo} + + accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc") + require.NoError(t, err) + require.Equal(t, int64(10), total) + require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts) + + require.Equal(t, 1, repo.listWithFiltersCalls) + require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams) + require.Equal(t, PlatformGemini, repo.listWithFiltersPlatform) + require.Equal(t, AccountTypeOAuth, repo.listWithFiltersType) + require.Equal(t, StatusActive, repo.listWithFiltersStatus) + require.Equal(t, "acc", repo.listWithFiltersSearch) + }) +} + +func TestAdminService_ListProxies_WithSearch(t *testing.T) { + t.Run("search 参数正常传递到 repository 层", func(t *testing.T) { + repo := &proxyRepoStubForAdminList{ + listWithFiltersProxies: []Proxy{{ID: 2, Name: "p1"}}, + listWithFiltersResult: &pagination.PaginationResult{Total: 7}, + } + svc := &adminServiceImpl{proxyRepo: repo} + + proxies, total, err := svc.ListProxies(context.Background(), 3, 50, "http", StatusActive, "p1") + require.NoError(t, err) + require.Equal(t, int64(7), total) + require.Equal(t, []Proxy{{ID: 2, Name: "p1"}}, proxies) + + require.Equal(t, 1, repo.listWithFiltersCalls) + require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams) + require.Equal(t, "http", repo.listWithFiltersProtocol) + require.Equal(t, StatusActive, repo.listWithFiltersStatus) + require.Equal(t, "p1", repo.listWithFiltersSearch) + }) +} + +func TestAdminService_ListProxiesWithAccountCount_WithSearch(t *testing.T) { + t.Run("search 参数正常传递到 repository 层", func(t *testing.T) { + repo := &proxyRepoStubForAdminList{ + listWithFiltersAndAccountCountProxies: []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}}, + listWithFiltersAndAccountCountResult: &pagination.PaginationResult{Total: 9}, + } + svc := &adminServiceImpl{proxyRepo: repo} + + proxies, total, err := svc.ListProxiesWithAccountCount(context.Background(), 2, 10, "socks5", StatusDisabled, "p2") + require.NoError(t, err) + require.Equal(t, int64(9), total) + require.Equal(t, []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}}, proxies) + + require.Equal(t, 1, repo.listWithFiltersAndAccountCountCalls) + require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersAndAccountCountParams) + require.Equal(t, "socks5", repo.listWithFiltersAndAccountCountProtocol) + require.Equal(t, StatusDisabled, repo.listWithFiltersAndAccountCountStatus) + require.Equal(t, "p2", repo.listWithFiltersAndAccountCountSearch) + }) +} + +func TestAdminService_ListRedeemCodes_WithSearch(t *testing.T) { + t.Run("search 参数正常传递到 repository 层", func(t *testing.T) { + repo := &redeemRepoStubForAdminList{ + listWithFiltersCodes: []RedeemCode{{ID: 4, Code: "ABC"}}, + listWithFiltersResult: &pagination.PaginationResult{Total: 3}, + } + svc := &adminServiceImpl{redeemCodeRepo: repo} + + codes, total, err := svc.ListRedeemCodes(context.Background(), 1, 20, RedeemTypeBalance, StatusUnused, "ABC") + require.NoError(t, err) + require.Equal(t, int64(3), total) + require.Equal(t, []RedeemCode{{ID: 4, Code: "ABC"}}, codes) + + require.Equal(t, 1, repo.listWithFiltersCalls) + require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams) + require.Equal(t, RedeemTypeBalance, repo.listWithFiltersType) + require.Equal(t, StatusUnused, repo.listWithFiltersStatus) + require.Equal(t, "ABC", repo.listWithFiltersSearch) + }) +} diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 573017cd..4fd55757 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -93,6 +93,7 @@ var antigravityPrefixMapping = []struct { // 长前缀优先 {"gemini-2.5-flash-image", "gemini-3-pro-image"}, // gemini-2.5-flash-image → 3-pro-image {"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等 + {"gemini-3-flash", "gemini-3-flash"}, // gemini-3-flash-preview 等 → gemini-3-flash {"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx {"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx {"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet @@ -502,6 +503,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, originalModel := claudeReq.Model mappedModel := s.getMappedModel(account, claudeReq.Model) + quotaScope, _ := resolveAntigravityQuotaScope(originalModel) // 获取 access_token if s.tokenProvider == nil { @@ -603,7 +605,7 @@ urlFallbackLoop: } // 所有重试都失败,标记限流状态 if resp.StatusCode == 429 { - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody) + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope) } // 最后一次尝试也失败 resp = &http.Response{ @@ -696,7 +698,7 @@ urlFallbackLoop: // 处理错误响应(重试后仍失败或不触发重试) if resp.StatusCode >= 400 { - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody) + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope) if s.shouldFailoverUpstreamError(resp.StatusCode) { return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} @@ -1021,6 +1023,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co if len(body) == 0 { return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty") } + quotaScope, _ := resolveAntigravityQuotaScope(originalModel) // 解析请求以获取 image_size(用于图片计费) imageSize := s.extractImageSize(body) @@ -1146,7 +1149,7 @@ urlFallbackLoop: } // 所有重试都失败,标记限流状态 if resp.StatusCode == 429 { - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody) + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope) } resp = &http.Response{ StatusCode: resp.StatusCode, @@ -1200,7 +1203,7 @@ urlFallbackLoop: goto handleSuccess } - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody) + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope) if s.shouldFailoverUpstreamError(resp.StatusCode) { return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} @@ -1314,7 +1317,7 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool { } } -func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte) { +func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) { // 429 使用 Gemini 格式解析(从 body 解析重置时间) if statusCode == 429 { resetAt := ParseGeminiRateLimitResetTime(body) @@ -1325,13 +1328,23 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, pre defaultDur = 5 * time.Minute } ra := time.Now().Add(defaultDur) - log.Printf("%s status=429 rate_limited reset_in=%v (fallback)", prefix, defaultDur) - _ = s.accountRepo.SetRateLimited(ctx, account.ID, ra) + log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur) + if quotaScope == "" { + return + } + if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, ra); err != nil { + log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err) + } return } resetTime := time.Unix(*resetAt, 0) - log.Printf("%s status=429 rate_limited reset_at=%v reset_in=%v", prefix, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second)) - _ = s.accountRepo.SetRateLimited(ctx, account.ID, resetTime) + log.Printf("%s status=429 rate_limited scope=%s reset_at=%v reset_in=%v", prefix, quotaScope, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second)) + if quotaScope == "" { + return + } + if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, resetTime); err != nil { + log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err) + } return } // 其他错误码继续使用 rateLimitService diff --git a/backend/internal/service/antigravity_quota_scope.go b/backend/internal/service/antigravity_quota_scope.go new file mode 100644 index 00000000..e9f7184b --- /dev/null +++ b/backend/internal/service/antigravity_quota_scope.go @@ -0,0 +1,88 @@ +package service + +import ( + "strings" + "time" +) + +const antigravityQuotaScopesKey = "antigravity_quota_scopes" + +// AntigravityQuotaScope 表示 Antigravity 的配额域 +type AntigravityQuotaScope string + +const ( + AntigravityQuotaScopeClaude AntigravityQuotaScope = "claude" + AntigravityQuotaScopeGeminiText AntigravityQuotaScope = "gemini_text" + AntigravityQuotaScopeGeminiImage AntigravityQuotaScope = "gemini_image" +) + +// resolveAntigravityQuotaScope 根据模型名称解析配额域 +func resolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) { + model := normalizeAntigravityModelName(requestedModel) + if model == "" { + return "", false + } + switch { + case strings.HasPrefix(model, "claude-"): + return AntigravityQuotaScopeClaude, true + case strings.HasPrefix(model, "gemini-"): + if isImageGenerationModel(model) { + return AntigravityQuotaScopeGeminiImage, true + } + return AntigravityQuotaScopeGeminiText, true + default: + return "", false + } +} + +func normalizeAntigravityModelName(model string) string { + normalized := strings.ToLower(strings.TrimSpace(model)) + normalized = strings.TrimPrefix(normalized, "models/") + return normalized +} + +// IsSchedulableForModel 结合 Antigravity 配额域限流判断是否可调度 +func (a *Account) IsSchedulableForModel(requestedModel string) bool { + if a == nil { + return false + } + if !a.IsSchedulable() { + return false + } + if a.Platform != PlatformAntigravity { + return true + } + scope, ok := resolveAntigravityQuotaScope(requestedModel) + if !ok { + return true + } + resetAt := a.antigravityQuotaScopeResetAt(scope) + if resetAt == nil { + return true + } + now := time.Now() + return !now.Before(*resetAt) +} + +func (a *Account) antigravityQuotaScopeResetAt(scope AntigravityQuotaScope) *time.Time { + if a == nil || a.Extra == nil || scope == "" { + return nil + } + rawScopes, ok := a.Extra[antigravityQuotaScopesKey].(map[string]any) + if !ok { + return nil + } + rawScope, ok := rawScopes[string(scope)].(map[string]any) + if !ok { + return nil + } + resetAtRaw, ok := rawScope["rate_limit_reset_at"].(string) + if !ok || strings.TrimSpace(resetAtRaw) == "" { + return nil + } + resetAt, err := time.Parse(time.RFC3339, resetAtRaw) + if err != nil { + return nil + } + return &resetAt +} diff --git a/backend/internal/service/api_key.go b/backend/internal/service/api_key.go index 0cf0f4f9..8c692d09 100644 --- a/backend/internal/service/api_key.go +++ b/backend/internal/service/api_key.go @@ -3,16 +3,18 @@ package service import "time" type APIKey struct { - ID int64 - UserID int64 - Key string - Name string - GroupID *int64 - Status string - CreatedAt time.Time - UpdatedAt time.Time - User *User - Group *Group + ID int64 + UserID int64 + Key string + Name string + GroupID *int64 + Status string + IPWhitelist []string + IPBlacklist []string + CreatedAt time.Time + UpdatedAt time.Time + User *User + Group *Group } func (k *APIKey) IsActive() bool { diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index 0ffe8821..578afc1a 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -9,6 +9,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" ) @@ -20,6 +21,7 @@ var ( ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters") ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens") ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later") + ErrInvalidIPPattern = infraerrors.BadRequest("INVALID_IP_PATTERN", "invalid IP or CIDR pattern") ) const ( @@ -57,16 +59,20 @@ type APIKeyCache interface { // CreateAPIKeyRequest 创建API Key请求 type CreateAPIKeyRequest struct { - Name string `json:"name"` - GroupID *int64 `json:"group_id"` - CustomKey *string `json:"custom_key"` // 可选的自定义key + Name string `json:"name"` + GroupID *int64 `json:"group_id"` + CustomKey *string `json:"custom_key"` // 可选的自定义key + IPWhitelist []string `json:"ip_whitelist"` // IP 白名单 + IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单 } // UpdateAPIKeyRequest 更新API Key请求 type UpdateAPIKeyRequest struct { - Name *string `json:"name"` - GroupID *int64 `json:"group_id"` - Status *string `json:"status"` + Name *string `json:"name"` + GroupID *int64 `json:"group_id"` + Status *string `json:"status"` + IPWhitelist []string `json:"ip_whitelist"` // IP 白名单(空数组清空) + IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单(空数组清空) } // APIKeyService API Key服务 @@ -186,6 +192,20 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK return nil, fmt.Errorf("get user: %w", err) } + // 验证 IP 白名单格式 + if len(req.IPWhitelist) > 0 { + if invalid := ip.ValidateIPPatterns(req.IPWhitelist); len(invalid) > 0 { + return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid) + } + } + + // 验证 IP 黑名单格式 + if len(req.IPBlacklist) > 0 { + if invalid := ip.ValidateIPPatterns(req.IPBlacklist); len(invalid) > 0 { + return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid) + } + } + // 验证分组权限(如果指定了分组) if req.GroupID != nil { group, err := s.groupRepo.GetByID(ctx, *req.GroupID) @@ -236,11 +256,13 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK // 创建API Key记录 apiKey := &APIKey{ - UserID: userID, - Key: key, - Name: req.Name, - GroupID: req.GroupID, - Status: StatusActive, + UserID: userID, + Key: key, + Name: req.Name, + GroupID: req.GroupID, + Status: StatusActive, + IPWhitelist: req.IPWhitelist, + IPBlacklist: req.IPBlacklist, } if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil { @@ -312,6 +334,20 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req return nil, ErrInsufficientPerms } + // 验证 IP 白名单格式 + if len(req.IPWhitelist) > 0 { + if invalid := ip.ValidateIPPatterns(req.IPWhitelist); len(invalid) > 0 { + return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid) + } + } + + // 验证 IP 黑名单格式 + if len(req.IPBlacklist) > 0 { + if invalid := ip.ValidateIPPatterns(req.IPBlacklist); len(invalid) > 0 { + return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid) + } + } + // 更新字段 if req.Name != nil { apiKey.Name = *req.Name @@ -344,6 +380,10 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req } } + // 更新 IP 限制(空数组会清空设置) + apiKey.IPWhitelist = req.IPWhitelist + apiKey.IPBlacklist = req.IPBlacklist + if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil { return nil, fmt.Errorf("update api key: %w", err) } diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 6e685869..e232deb3 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -2,9 +2,13 @@ package service import ( "context" + "crypto/rand" + "encoding/hex" "errors" "fmt" "log" + "net/mail" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/config" @@ -18,6 +22,7 @@ var ( ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password") ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active") ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists") + ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved") ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token") ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired") ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large") @@ -80,6 +85,11 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw return "", nil, ErrRegDisabled } + // 防止用户注册 LinuxDo OAuth 合成邮箱,避免第三方登录与本地账号发生碰撞。 + if isReservedEmail(email) { + return "", nil, ErrEmailReserved + } + // 检查是否需要邮件验证 if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) { // 如果邮件验证已开启但邮件服务未配置,拒绝注册 @@ -161,6 +171,10 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error { return ErrRegDisabled } + if isReservedEmail(email) { + return ErrEmailReserved + } + // 检查邮箱是否已存在 existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) if err != nil { @@ -195,6 +209,10 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S return nil, ErrRegDisabled } + if isReservedEmail(email) { + return nil, ErrEmailReserved + } + // 检查邮箱是否已存在 existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) if err != nil { @@ -319,6 +337,102 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string return token, user, nil } +// LoginOrRegisterOAuth 用于第三方 OAuth/SSO 登录: +// - 如果邮箱已存在:直接登录(不需要本地密码) +// - 如果邮箱不存在:创建新用户并登录 +// +// 注意:该函数用于“终端用户登录 Sub2API 本身”的场景(不同于上游账号的 OAuth,例如 OpenAI/Gemini)。 +// 为了满足现有数据库约束(需要密码哈希),新用户会生成随机密码并进行哈希保存。 +func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username string) (string, *User, error) { + email = strings.TrimSpace(email) + if email == "" || len(email) > 255 { + return "", nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email") + } + if _, err := mail.ParseAddress(email); err != nil { + return "", nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email") + } + + username = strings.TrimSpace(username) + if len([]rune(username)) > 100 { + username = string([]rune(username)[:100]) + } + + user, err := s.userRepo.GetByEmail(ctx, email) + if err != nil { + if errors.Is(err, ErrUserNotFound) { + // OAuth 首次登录视为注册。 + if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { + return "", nil, ErrRegDisabled + } + + randomPassword, err := randomHexString(32) + if err != nil { + log.Printf("[Auth] Failed to generate random password for oauth signup: %v", err) + return "", nil, ErrServiceUnavailable + } + hashedPassword, err := s.HashPassword(randomPassword) + if err != nil { + return "", nil, fmt.Errorf("hash password: %w", err) + } + + // 新用户默认值。 + defaultBalance := s.cfg.Default.UserBalance + defaultConcurrency := s.cfg.Default.UserConcurrency + if s.settingService != nil { + defaultBalance = s.settingService.GetDefaultBalance(ctx) + defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx) + } + + newUser := &User{ + Email: email, + Username: username, + PasswordHash: hashedPassword, + Role: RoleUser, + Balance: defaultBalance, + Concurrency: defaultConcurrency, + Status: StatusActive, + } + + if err := s.userRepo.Create(ctx, newUser); err != nil { + if errors.Is(err, ErrEmailExists) { + // 并发场景:GetByEmail 与 Create 之间用户被创建。 + user, err = s.userRepo.GetByEmail(ctx, email) + if err != nil { + log.Printf("[Auth] Database error getting user after conflict: %v", err) + return "", nil, ErrServiceUnavailable + } + } else { + log.Printf("[Auth] Database error creating oauth user: %v", err) + return "", nil, ErrServiceUnavailable + } + } else { + user = newUser + } + } else { + log.Printf("[Auth] Database error during oauth login: %v", err) + return "", nil, ErrServiceUnavailable + } + } + + if !user.IsActive() { + return "", nil, ErrUserNotActive + } + + // 尽力补全:当用户名为空时,使用第三方返回的用户名回填。 + if user.Username == "" && username != "" { + user.Username = username + if err := s.userRepo.Update(ctx, user); err != nil { + log.Printf("[Auth] Failed to update username after oauth login: %v", err) + } + } + + token, err := s.GenerateToken(user) + if err != nil { + return "", nil, fmt.Errorf("generate token: %w", err) + } + return token, user, nil +} + // ValidateToken 验证JWT token并返回用户声明 func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { // 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。 @@ -361,6 +475,22 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { return nil, ErrInvalidToken } +func randomHexString(byteLength int) (string, error) { + if byteLength <= 0 { + byteLength = 16 + } + buf := make([]byte, byteLength) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return hex.EncodeToString(buf), nil +} + +func isReservedEmail(email string) bool { + normalized := strings.ToLower(strings.TrimSpace(email)) + return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) +} + // GenerateToken 生成JWT token func (s *AuthService) GenerateToken(user *User) (string, error) { now := time.Now() diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index bfd504a3..ab1f20a0 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -182,6 +182,16 @@ func TestAuthService_Register_CheckEmailError(t *testing.T) { require.ErrorIs(t, err, ErrServiceUnavailable) } +func TestAuthService_Register_ReservedEmail(t *testing.T) { + repo := &userRepoStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + }, nil) + + _, _, err := service.Register(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "password") + require.ErrorIs(t, err, ErrEmailReserved) +} + func TestAuthService_Register_CreateError(t *testing.T) { repo := &userRepoStub{createErr: errors.New("create failed")} service := newAuthService(repo, map[string]string{ diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 9c61ea2e..df34e167 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -105,7 +105,17 @@ const ( // Request identity patch (Claude -> Gemini systemInstruction injection) SettingKeyEnableIdentityPatch = "enable_identity_patch" SettingKeyIdentityPatchPrompt = "identity_patch_prompt" + + // LinuxDo Connect OAuth 登录(终端用户 SSO) + SettingKeyLinuxDoConnectEnabled = "linuxdo_connect_enabled" + SettingKeyLinuxDoConnectClientID = "linuxdo_connect_client_id" + SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret" + SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url" ) +// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。 +// 目的:避免第三方登录返回的用户标识与本地真实邮箱发生碰撞,进而造成账号被接管的风险。 +const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid" + // AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys). const AdminAPIKeyPrefix = "admin-" diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 66c40e25..da7c311c 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -136,6 +136,9 @@ func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx co func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { return nil } +func (m *mockAccountRepoForPlatform) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error { + return nil +} func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error { return nil } @@ -148,6 +151,9 @@ func (m *mockAccountRepoForPlatform) ClearTempUnschedulable(ctx context.Context, func (m *mockAccountRepoForPlatform) ClearRateLimit(ctx context.Context, id int64) error { return nil } +func (m *mockAccountRepoForPlatform) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error { + return nil +} func (m *mockAccountRepoForPlatform) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error { return nil } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index e73e9406..7623d025 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -33,7 +33,7 @@ const ( claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true" claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true" stickySessionTTL = time.Hour // 粘性会话TTL - defaultMaxLineSize = 10 * 1024 * 1024 + defaultMaxLineSize = 40 * 1024 * 1024 claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude." maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量 ) @@ -481,7 +481,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro account, err := s.accountRepo.GetByID(ctx, accountID) if err == nil && s.isAccountInGroup(account, groupID) && s.isAccountAllowedForPlatform(account, platform, useMixed) && - account.IsSchedulable() && + account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if err == nil && result.Acquired { @@ -519,6 +519,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if !s.isAccountAllowedForPlatform(acc, platform, useMixed) { continue } + if !acc.IsSchedulableForModel(requestedModel) { + continue + } if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { continue } @@ -812,7 +815,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if _, excluded := excludedIDs[accountID]; !excluded { account, err := s.accountRepo.GetByID(ctx, accountID) // 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台) - if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { + if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) } @@ -844,6 +847,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if _, excluded := excludedIDs[acc.ID]; excluded { continue } + if !acc.IsSchedulableForModel(requestedModel) { + continue + } if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { continue } @@ -901,7 +907,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if _, excluded := excludedIDs[accountID]; !excluded { account, err := s.accountRepo.GetByID(ctx, accountID) // 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度 - if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { + if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) @@ -936,6 +942,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { continue } + if !acc.IsSchedulableForModel(requestedModel) { + continue + } if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { continue } @@ -2247,6 +2256,7 @@ type RecordUsageInput struct { Account *Account Subscription *UserSubscription // 可选:订阅信息 UserAgent string // 请求的 User-Agent + IPAddress string // 请求的客户端 IP 地址 } // RecordUsage 记录使用量并扣费(或更新订阅用量) @@ -2337,6 +2347,11 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu usageLog.UserAgent = &input.UserAgent } + // 添加 IPAddress + if input.IPAddress != "" { + usageLog.IPAddress = &input.IPAddress + } + // 添加分组和订阅关联 if apiKey.GroupID != nil { usageLog.GroupID = apiKey.GroupID diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index f2b5bafd..2b500072 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -114,7 +114,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co if _, excluded := excludedIDs[accountID]; !excluded { account, err := s.accountRepo.GetByID(ctx, accountID) // 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度 - if err == nil && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { + if err == nil && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { valid := false if account.Platform == platform { valid = true @@ -172,6 +172,9 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co if useMixedScheduling && acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { continue } + if !acc.IsSchedulableForModel(requestedModel) { + continue + } if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { continue } diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index 6007bce8..d9df5f4c 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -121,6 +121,9 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx cont func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { return nil } +func (m *mockAccountRepoForGemini) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error { + return nil +} func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error { return nil } @@ -131,6 +134,9 @@ func (m *mockAccountRepoForGemini) ClearTempUnschedulable(ctx context.Context, i return nil } func (m *mockAccountRepoForGemini) ClearRateLimit(ctx context.Context, id int64) error { return nil } +func (m *mockAccountRepoForGemini) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error { + return nil +} func (m *mockAccountRepoForGemini) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error { return nil } @@ -166,7 +172,7 @@ func (m *mockGroupRepoForGemini) DeleteCascade(ctx context.Context, id int64) ([ func (m *mockGroupRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) { return nil, nil, nil } -func (m *mockGroupRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) { +func (m *mockGroupRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) { return nil, nil, nil } func (m *mockGroupRepoForGemini) ListActive(ctx context.Context) ([]Group, error) { return nil, nil } diff --git a/backend/internal/service/group_service.go b/backend/internal/service/group_service.go index 403636e8..a444556f 100644 --- a/backend/internal/service/group_service.go +++ b/backend/internal/service/group_service.go @@ -21,7 +21,7 @@ type GroupRepository interface { DeleteCascade(ctx context.Context, id int64) ([]int64, error) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) - ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) + ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) ListActive(ctx context.Context) ([]Group, error) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 42e98585..5bb7574a 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -1197,6 +1197,7 @@ type OpenAIRecordUsageInput struct { Account *Account Subscription *UserSubscription UserAgent string // 请求的 User-Agent + IPAddress string // 请求的客户端 IP 地址 } // RecordUsage records usage and deducts balance @@ -1271,6 +1272,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec usageLog.UserAgent = &input.UserAgent } + // 添加 IPAddress + if input.IPAddress != "" { + usageLog.IPAddress = &input.IPAddress + } + if apiKey.GroupID != nil { usageLog.GroupID = apiKey.GroupID } diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 196f1643..f1362646 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -345,7 +345,7 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc // 如果状态为allowed且之前有限流,说明窗口已重置,清除限流状态 if status == "allowed" && account.IsRateLimited() { - if err := s.accountRepo.ClearRateLimit(ctx, account.ID); err != nil { + if err := s.ClearRateLimit(ctx, account.ID); err != nil { log.Printf("ClearRateLimit failed for account %d: %v", account.ID, err) } } @@ -353,7 +353,10 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc // ClearRateLimit 清除账号的限流状态 func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) error { - return s.accountRepo.ClearRateLimit(ctx, accountID) + if err := s.accountRepo.ClearRateLimit(ctx, accountID); err != nil { + return err + } + return s.accountRepo.ClearAntigravityQuotaScopes(ctx, accountID) } func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID int64) error { diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 965253cf..d25698de 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "strconv" + "strings" "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" @@ -64,6 +65,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingKeyAPIBaseURL, SettingKeyContactInfo, SettingKeyDocURL, + SettingKeyLinuxDoConnectEnabled, } settings, err := s.settingRepo.GetMultiple(ctx, keys) @@ -71,6 +73,13 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings return nil, fmt.Errorf("get public settings: %w", err) } + linuxDoEnabled := false + if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok { + linuxDoEnabled = raw == "true" + } else { + linuxDoEnabled = s.cfg != nil && s.cfg.LinuxDo.Enabled + } + return &PublicSettings{ RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true", @@ -82,6 +91,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings APIBaseURL: settings[SettingKeyAPIBaseURL], ContactInfo: settings[SettingKeyContactInfo], DocURL: settings[SettingKeyDocURL], + LinuxDoOAuthEnabled: linuxDoEnabled, }, nil } @@ -111,6 +121,14 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyTurnstileSecretKey] = settings.TurnstileSecretKey } + // LinuxDo Connect OAuth 登录(终端用户 SSO) + updates[SettingKeyLinuxDoConnectEnabled] = strconv.FormatBool(settings.LinuxDoConnectEnabled) + updates[SettingKeyLinuxDoConnectClientID] = settings.LinuxDoConnectClientID + updates[SettingKeyLinuxDoConnectRedirectURL] = settings.LinuxDoConnectRedirectURL + if settings.LinuxDoConnectClientSecret != "" { + updates[SettingKeyLinuxDoConnectClientSecret] = settings.LinuxDoConnectClientSecret + } + // OEM设置 updates[SettingKeySiteName] = settings.SiteName updates[SettingKeySiteLogo] = settings.SiteLogo @@ -271,6 +289,38 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin result.SMTPPassword = settings[SettingKeySMTPPassword] result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey] + // LinuxDo Connect 设置: + // - 兼容 config.yaml/env(避免老部署因为未迁移到数据库设置而被意外关闭) + // - 支持在后台“系统设置”中覆盖并持久化(存储于 DB) + linuxDoBase := config.LinuxDoConnectConfig{} + if s.cfg != nil { + linuxDoBase = s.cfg.LinuxDo + } + + if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok { + result.LinuxDoConnectEnabled = raw == "true" + } else { + result.LinuxDoConnectEnabled = linuxDoBase.Enabled + } + + if v, ok := settings[SettingKeyLinuxDoConnectClientID]; ok && strings.TrimSpace(v) != "" { + result.LinuxDoConnectClientID = strings.TrimSpace(v) + } else { + result.LinuxDoConnectClientID = linuxDoBase.ClientID + } + + if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" { + result.LinuxDoConnectRedirectURL = strings.TrimSpace(v) + } else { + result.LinuxDoConnectRedirectURL = linuxDoBase.RedirectURL + } + + result.LinuxDoConnectClientSecret = strings.TrimSpace(settings[SettingKeyLinuxDoConnectClientSecret]) + if result.LinuxDoConnectClientSecret == "" { + result.LinuxDoConnectClientSecret = strings.TrimSpace(linuxDoBase.ClientSecret) + } + result.LinuxDoConnectClientSecretConfigured = result.LinuxDoConnectClientSecret != "" + // Model fallback settings result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true" result.FallbackModelAnthropic = s.getStringOrDefault(settings, SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022") @@ -289,6 +339,99 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin return result } +// GetLinuxDoConnectOAuthConfig 返回用于登录的“最终生效” LinuxDo Connect 配置。 +// +// 优先级: +// - 若对应系统设置键存在,则覆盖 config.yaml/env 的值 +// - 否则回退到 config.yaml/env 的值 +func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) { + if s == nil || s.cfg == nil { + return config.LinuxDoConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded") + } + + effective := s.cfg.LinuxDo + + keys := []string{ + SettingKeyLinuxDoConnectEnabled, + SettingKeyLinuxDoConnectClientID, + SettingKeyLinuxDoConnectClientSecret, + SettingKeyLinuxDoConnectRedirectURL, + } + settings, err := s.settingRepo.GetMultiple(ctx, keys) + if err != nil { + return config.LinuxDoConnectConfig{}, fmt.Errorf("get linuxdo connect settings: %w", err) + } + + if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok { + effective.Enabled = raw == "true" + } + if v, ok := settings[SettingKeyLinuxDoConnectClientID]; ok && strings.TrimSpace(v) != "" { + effective.ClientID = strings.TrimSpace(v) + } + if v, ok := settings[SettingKeyLinuxDoConnectClientSecret]; ok && strings.TrimSpace(v) != "" { + effective.ClientSecret = strings.TrimSpace(v) + } + if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" { + effective.RedirectURL = strings.TrimSpace(v) + } + + if !effective.Enabled { + return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled") + } + + // 基础健壮性校验(避免把用户重定向到一个必然失败或不安全的 OAuth 流程里)。 + if strings.TrimSpace(effective.ClientID) == "" { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client id not configured") + } + if strings.TrimSpace(effective.AuthorizeURL) == "" { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth authorize url not configured") + } + if strings.TrimSpace(effective.TokenURL) == "" { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token url not configured") + } + if strings.TrimSpace(effective.UserInfoURL) == "" { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth userinfo url not configured") + } + if strings.TrimSpace(effective.RedirectURL) == "" { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url not configured") + } + if strings.TrimSpace(effective.FrontendRedirectURL) == "" { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth frontend redirect url not configured") + } + + if err := config.ValidateAbsoluteHTTPURL(effective.AuthorizeURL); err != nil { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth authorize url invalid") + } + if err := config.ValidateAbsoluteHTTPURL(effective.TokenURL); err != nil { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token url invalid") + } + if err := config.ValidateAbsoluteHTTPURL(effective.UserInfoURL); err != nil { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth userinfo url invalid") + } + if err := config.ValidateAbsoluteHTTPURL(effective.RedirectURL); err != nil { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url invalid") + } + if err := config.ValidateFrontendRedirectURL(effective.FrontendRedirectURL); err != nil { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth frontend redirect url invalid") + } + + method := strings.ToLower(strings.TrimSpace(effective.TokenAuthMethod)) + switch method { + case "", "client_secret_post", "client_secret_basic": + if strings.TrimSpace(effective.ClientSecret) == "" { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured") + } + case "none": + if !effective.UsePKCE { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none") + } + default: + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid") + } + + return effective, nil +} + // getStringOrDefault 获取字符串值或默认值 func (s *SettingService) getStringOrDefault(settings map[string]string, key, defaultValue string) string { if value, ok := settings[key]; ok && value != "" { diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index de0331f7..26051418 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -18,6 +18,13 @@ type SystemSettings struct { TurnstileSecretKey string TurnstileSecretKeyConfigured bool + // LinuxDo Connect OAuth 登录(终端用户 SSO) + LinuxDoConnectEnabled bool + LinuxDoConnectClientID string + LinuxDoConnectClientSecret string + LinuxDoConnectClientSecretConfigured bool + LinuxDoConnectRedirectURL string + SiteName string SiteLogo string SiteSubtitle string @@ -51,5 +58,6 @@ type PublicSettings struct { APIBaseURL string ContactInfo string DocURL string + LinuxDoOAuthEnabled bool Version string } diff --git a/backend/internal/service/usage_log.go b/backend/internal/service/usage_log.go index 9ecb7098..62d7fae0 100644 --- a/backend/internal/service/usage_log.go +++ b/backend/internal/service/usage_log.go @@ -39,6 +39,7 @@ type UsageLog struct { DurationMs *int FirstTokenMs *int UserAgent *string + IPAddress *string // 图片生成字段 ImageCount int diff --git a/backend/migrations/031_add_ip_address.sql b/backend/migrations/031_add_ip_address.sql new file mode 100644 index 00000000..7f557830 --- /dev/null +++ b/backend/migrations/031_add_ip_address.sql @@ -0,0 +1,5 @@ +-- Add IP address field to usage_logs table for request tracking (admin-only visibility) +ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS ip_address VARCHAR(45); + +-- Create index for IP address queries +CREATE INDEX IF NOT EXISTS idx_usage_logs_ip_address ON usage_logs(ip_address); diff --git a/backend/migrations/032_add_api_key_ip_restriction.sql b/backend/migrations/032_add_api_key_ip_restriction.sql new file mode 100644 index 00000000..2dfe2c92 --- /dev/null +++ b/backend/migrations/032_add_api_key_ip_restriction.sql @@ -0,0 +1,9 @@ +-- Add IP restriction fields to api_keys table +-- ip_whitelist: JSON array of allowed IPs/CIDRs (if set, only these IPs can use the key) +-- ip_blacklist: JSON array of blocked IPs/CIDRs (these IPs are always blocked) + +ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS ip_whitelist JSONB DEFAULT NULL; +ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS ip_blacklist JSONB DEFAULT NULL; + +COMMENT ON COLUMN api_keys.ip_whitelist IS 'JSON array of allowed IPs/CIDRs, e.g. ["192.168.1.100", "10.0.0.0/8"]'; +COMMENT ON COLUMN api_keys.ip_blacklist IS 'JSON array of blocked IPs/CIDRs, e.g. ["1.2.3.4", "5.6.0.0/16"]'; diff --git a/backend/repository.test b/backend/repository.test new file mode 100755 index 00000000..9ecc014c Binary files /dev/null and b/backend/repository.test differ diff --git a/config.yaml b/config.yaml index f43c9c19..54b591f3 100644 --- a/config.yaml +++ b/config.yaml @@ -154,9 +154,9 @@ gateway: # Stream keepalive interval (seconds), 0=disable # 流式 keepalive 间隔(秒),0=禁用 stream_keepalive_interval: 10 - # SSE max line size in bytes (default: 10MB) - # SSE 单行最大字节数(默认 10MB) - max_line_size: 10485760 + # SSE max line size in bytes (default: 40MB) + # SSE 单行最大字节数(默认 40MB) + max_line_size: 41943040 # Log upstream error response body summary (safe/truncated; does not log request content) # 记录上游错误响应体摘要(安全/截断;不记录请求内容) log_upstream_error_body: false diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 49bf0afa..87ff3148 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -154,9 +154,9 @@ gateway: # Stream keepalive interval (seconds), 0=disable # 流式 keepalive 间隔(秒),0=禁用 stream_keepalive_interval: 10 - # SSE max line size in bytes (default: 10MB) - # SSE 单行最大字节数(默认 10MB) - max_line_size: 10485760 + # SSE max line size in bytes (default: 40MB) + # SSE 单行最大字节数(默认 40MB) + max_line_size: 41943040 # Log upstream error response body summary (safe/truncated; does not log request content) # 记录上游错误响应体摘要(安全/截断;不记录请求内容) log_upstream_error_body: false @@ -234,6 +234,31 @@ jwt: # 令牌过期时间(小时,最大 24) expire_hour: 24 +# ============================================================================= +# LinuxDo Connect OAuth Login (SSO) +# LinuxDo Connect OAuth 登录(用于 Sub2API 用户登录) +# ============================================================================= +linuxdo_connect: + enabled: false + client_id: "" + client_secret: "" + authorize_url: "https://connect.linux.do/oauth2/authorize" + token_url: "https://connect.linux.do/oauth2/token" + userinfo_url: "https://connect.linux.do/api/user" + scopes: "user" + # 示例: "https://your-domain.com/api/v1/auth/oauth/linuxdo/callback" + redirect_url: "" + # 安全提示: + # - 建议使用同源相对路径(以 / 开头),避免把 token 重定向到意外的第三方域名 + # - 该地址不应包含 #fragment(本实现使用 URL fragment 传递 access_token) + frontend_redirect_url: "/auth/linuxdo/callback" + token_auth_method: "client_secret_post" # client_secret_post | client_secret_basic | none + # 注意:当 token_auth_method=none(public client)时,必须启用 PKCE + use_pkce: false + userinfo_email_path: "" + userinfo_id_path: "" + userinfo_username_path: "" + # ============================================================================= # Default Settings # 默认设置 diff --git a/deploy/docker-compose.standalone.yml b/deploy/docker-compose.standalone.yml new file mode 100644 index 00000000..1bf247c7 --- /dev/null +++ b/deploy/docker-compose.standalone.yml @@ -0,0 +1,93 @@ +# ============================================================================= +# Sub2API Docker Compose - Standalone Configuration +# ============================================================================= +# This configuration runs only the Sub2API application. +# PostgreSQL and Redis must be provided externally. +# +# Usage: +# 1. Copy .env.example to .env and configure database/redis connection +# 2. docker-compose -f docker-compose.standalone.yml up -d +# 3. Access: http://localhost:8080 +# ============================================================================= + +services: + sub2api: + image: weishaw/sub2api:latest + container_name: sub2api + restart: unless-stopped + ulimits: + nofile: + soft: 100000 + hard: 100000 + ports: + - "${BIND_HOST:-0.0.0.0}:${SERVER_PORT:-8080}:8080" + volumes: + - sub2api_data:/app/data + extra_hosts: + - "host.docker.internal:host-gateway" + environment: + # ======================================================================= + # Auto Setup + # ======================================================================= + - AUTO_SETUP=true + + # ======================================================================= + # Server Configuration + # ======================================================================= + - SERVER_HOST=0.0.0.0 + - SERVER_PORT=8080 + - SERVER_MODE=${SERVER_MODE:-release} + - RUN_MODE=${RUN_MODE:-standard} + + # ======================================================================= + # Database Configuration (PostgreSQL) - Required + # ======================================================================= + - DATABASE_HOST=${DATABASE_HOST:?DATABASE_HOST is required} + - DATABASE_PORT=${DATABASE_PORT:-5432} + - DATABASE_USER=${DATABASE_USER:-sub2api} + - DATABASE_PASSWORD=${DATABASE_PASSWORD:?DATABASE_PASSWORD is required} + - DATABASE_DBNAME=${DATABASE_DBNAME:-sub2api} + - DATABASE_SSLMODE=${DATABASE_SSLMODE:-disable} + + # ======================================================================= + # Redis Configuration - Required + # ======================================================================= + - REDIS_HOST=${REDIS_HOST:?REDIS_HOST is required} + - REDIS_PORT=${REDIS_PORT:-6379} + - REDIS_PASSWORD=${REDIS_PASSWORD:-} + - REDIS_DB=${REDIS_DB:-0} + + # ======================================================================= + # Admin Account (auto-created on first run) + # ======================================================================= + - ADMIN_EMAIL=${ADMIN_EMAIL:-admin@sub2api.local} + - ADMIN_PASSWORD=${ADMIN_PASSWORD:-} + + # ======================================================================= + # JWT Configuration + # ======================================================================= + - JWT_SECRET=${JWT_SECRET:-} + - JWT_EXPIRE_HOUR=${JWT_EXPIRE_HOUR:-24} + + # ======================================================================= + # Timezone Configuration + # ======================================================================= + - TZ=${TZ:-Asia/Shanghai} + + # ======================================================================= + # Gemini OAuth Configuration (optional) + # ======================================================================= + - GEMINI_OAUTH_CLIENT_ID=${GEMINI_OAUTH_CLIENT_ID:-} + - GEMINI_OAUTH_CLIENT_SECRET=${GEMINI_OAUTH_CLIENT_SECRET:-} + - GEMINI_OAUTH_SCOPES=${GEMINI_OAUTH_SCOPES:-} + - GEMINI_QUOTA_POLICY=${GEMINI_QUOTA_POLICY:-} + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8080/health"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 30s + +volumes: + sub2api_data: + driver: local diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index 6a370e9a..484df3a8 100644 --- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -173,11 +173,12 @@ services: volumes: - redis_data:/data command: > - redis-server - --save 60 1 - --appendonly yes - --appendfsync everysec - ${REDIS_PASSWORD:+--requirepass ${REDIS_PASSWORD}} + sh -c ' + redis-server + --save 60 1 + --appendonly yes + --appendfsync everysec + ${REDIS_PASSWORD:+--requirepass "$REDIS_PASSWORD"}' environment: - TZ=${TZ:-Asia/Shanghai} # REDISCLI_AUTH is used by redis-cli for authentication (safer than -a flag) diff --git a/frontend/src/api/admin/groups.ts b/frontend/src/api/admin/groups.ts index 23db9104..44eebc99 100644 --- a/frontend/src/api/admin/groups.ts +++ b/frontend/src/api/admin/groups.ts @@ -16,7 +16,7 @@ import type { * List all groups with pagination * @param page - Page number (default: 1) * @param pageSize - Items per page (default: 20) - * @param filters - Optional filters (platform, status, is_exclusive) + * @param filters - Optional filters (platform, status, is_exclusive, search) * @returns Paginated list of groups */ export async function list( @@ -26,6 +26,7 @@ export async function list( platform?: GroupPlatform status?: 'active' | 'inactive' is_exclusive?: boolean + search?: string }, options?: { signal?: AbortSignal diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 6b46de7d..2f6991e7 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -34,6 +34,11 @@ export interface SystemSettings { turnstile_enabled: boolean turnstile_site_key: string turnstile_secret_key_configured: boolean + // LinuxDo Connect OAuth 登录(终端用户 SSO) + linuxdo_connect_enabled: boolean + linuxdo_connect_client_id: string + linuxdo_connect_client_secret_configured: boolean + linuxdo_connect_redirect_url: string // Identity patch configuration (Claude -> Gemini) enable_identity_patch: boolean identity_patch_prompt: string @@ -60,6 +65,10 @@ export interface UpdateSettingsRequest { turnstile_enabled?: boolean turnstile_site_key?: string turnstile_secret_key?: string + linuxdo_connect_enabled?: boolean + linuxdo_connect_client_id?: string + linuxdo_connect_client_secret?: string + linuxdo_connect_redirect_url?: string enable_identity_patch?: boolean identity_patch_prompt?: string } diff --git a/frontend/src/api/admin/usage.ts b/frontend/src/api/admin/usage.ts index 4712dafd..ca76234b 100644 --- a/frontend/src/api/admin/usage.ts +++ b/frontend/src/api/admin/usage.ts @@ -64,7 +64,6 @@ export async function getStats(params: { group_id?: number model?: string stream?: boolean - billing_type?: number period?: string start_date?: string end_date?: string diff --git a/frontend/src/api/keys.ts b/frontend/src/api/keys.ts index caa339e4..cdae1359 100644 --- a/frontend/src/api/keys.ts +++ b/frontend/src/api/keys.ts @@ -42,12 +42,16 @@ export async function getById(id: number): Promise { * @param name - Key name * @param groupId - Optional group ID * @param customKey - Optional custom key value + * @param ipWhitelist - Optional IP whitelist + * @param ipBlacklist - Optional IP blacklist * @returns Created API key */ export async function create( name: string, groupId?: number | null, - customKey?: string + customKey?: string, + ipWhitelist?: string[], + ipBlacklist?: string[] ): Promise { const payload: CreateApiKeyRequest = { name } if (groupId !== undefined) { @@ -56,6 +60,12 @@ export async function create( if (customKey) { payload.custom_key = customKey } + if (ipWhitelist && ipWhitelist.length > 0) { + payload.ip_whitelist = ipWhitelist + } + if (ipBlacklist && ipBlacklist.length > 0) { + payload.ip_blacklist = ipBlacklist + } const { data } = await apiClient.post('/keys', payload) return data diff --git a/frontend/src/components/admin/account/AccountBulkActionsBar.vue b/frontend/src/components/admin/account/AccountBulkActionsBar.vue index 17bd634d..41111484 100644 --- a/frontend/src/components/admin/account/AccountBulkActionsBar.vue +++ b/frontend/src/components/admin/account/AccountBulkActionsBar.vue @@ -1,8 +1,27 @@ - - + + @@ -249,11 +248,11 @@ const cols = computed(() => [ { key: 'stream', label: t('usage.type'), sortable: false }, { key: 'tokens', label: t('usage.tokens'), sortable: false }, { key: 'cost', label: t('usage.cost'), sortable: false }, - { key: 'billing_type', label: t('usage.billingType'), sortable: false }, { key: 'first_token', label: t('usage.firstToken'), sortable: false }, { key: 'duration', label: t('usage.duration'), sortable: false }, { key: 'created_at', label: t('usage.time'), sortable: true }, - { key: 'user_agent', label: t('usage.userAgent'), sortable: false } + { key: 'user_agent', label: t('usage.userAgent'), sortable: false }, + { key: 'ip_address', label: t('admin.usage.ipAddress'), sortable: false } ]) const formatCacheTokens = (tokens: number): string => { diff --git a/frontend/src/components/auth/LinuxDoOAuthSection.vue b/frontend/src/components/auth/LinuxDoOAuthSection.vue new file mode 100644 index 00000000..8012b101 --- /dev/null +++ b/frontend/src/components/auth/LinuxDoOAuthSection.vue @@ -0,0 +1,61 @@ + + + + diff --git a/frontend/src/composables/useTableLoader.ts b/frontend/src/composables/useTableLoader.ts index 01703ee1..5fb6c5e0 100644 --- a/frontend/src/composables/useTableLoader.ts +++ b/frontend/src/composables/useTableLoader.ts @@ -43,7 +43,8 @@ export function useTableLoader>(options: TableL if (abortController) { abortController.abort() } - abortController = new AbortController() + const currentController = new AbortController() + abortController = currentController loading.value = true try { @@ -51,9 +52,9 @@ export function useTableLoader>(options: TableL pagination.page, pagination.page_size, toRaw(params) as P, - { signal: abortController.signal } + { signal: currentController.signal } ) - + items.value = response.items || [] pagination.total = response.total || 0 pagination.pages = response.pages || 0 @@ -63,7 +64,7 @@ export function useTableLoader>(options: TableL throw error } } finally { - if (abortController && !abortController.signal.aborted) { + if (abortController === currentController) { loading.value = false } } @@ -77,7 +78,9 @@ export function useTableLoader>(options: TableL const debouncedReload = useDebounceFn(reload, debounceMs) const handlePageChange = (page: number) => { - pagination.page = page + // 确保页码在有效范围内 + const validPage = Math.max(1, Math.min(page, pagination.pages || 1)) + pagination.page = validPage load() } diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index c4cf6cc6..e7d3a28d 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -229,6 +229,15 @@ export default { sendingCode: 'Sending...', clickToResend: 'Click to resend code', resendCode: 'Resend verification code', + linuxdo: { + signIn: 'Continue with Linux.do', + orContinue: 'or continue with email', + callbackTitle: 'Signing you in', + callbackProcessing: 'Completing login, please wait...', + callbackHint: 'If you are not redirected automatically, go back to the login page and try again.', + callbackMissingToken: 'Missing login token, please try again.', + backToLogin: 'Back to Login' + }, oauth: { code: 'Code', state: 'State', @@ -361,6 +370,14 @@ export default { customKeyTooShort: 'Custom key must be at least 16 characters', customKeyInvalidChars: 'Custom key can only contain letters, numbers, underscores, and hyphens', customKeyRequired: 'Please enter a custom key', + ipRestriction: 'IP Restriction', + ipWhitelist: 'IP Whitelist', + ipWhitelistPlaceholder: '192.168.1.100\n10.0.0.0/8', + ipWhitelistHint: 'One IP or CIDR per line. Only these IPs can use this key when set.', + ipBlacklist: 'IP Blacklist', + ipBlacklistPlaceholder: '1.2.3.4\n5.6.0.0/16', + ipBlacklistHint: 'One IP or CIDR per line. These IPs will be blocked from using this key.', + ipRestrictionEnabled: 'IP restriction enabled', ccSwitchNotInstalled: 'CC-Switch is not installed or the protocol handler is not registered. Please install CC-Switch first or manually copy the API key.', ccsClientSelect: { title: 'Select Client', @@ -421,9 +438,6 @@ export default { exportFailed: 'Failed to export usage data', exportExcelSuccess: 'Usage data exported successfully (Excel format)', exportExcelFailed: 'Failed to export usage data', - billingType: 'Billing', - balance: 'Balance', - subscription: 'Subscription', imageUnit: ' images', userAgent: 'User-Agent' }, @@ -1076,12 +1090,16 @@ export default { tokenRefreshed: 'Token refreshed successfully', accountDeleted: 'Account deleted successfully', rateLimitCleared: 'Rate limit cleared successfully', + bulkSchedulableEnabled: 'Successfully enabled scheduling for {count} account(s)', + bulkSchedulableDisabled: 'Successfully disabled scheduling for {count} account(s)', bulkActions: { selected: '{count} account(s) selected', selectCurrentPage: 'Select this page', clear: 'Clear selection', edit: 'Bulk Edit', - delete: 'Bulk Delete' + delete: 'Bulk Delete', + enableScheduling: 'Enable Scheduling', + disableScheduling: 'Disable Scheduling' }, bulkEdit: { title: 'Bulk Edit Accounts', @@ -1486,6 +1504,7 @@ export default { testing: 'Testing...', retry: 'Retry', copyOutput: 'Copy output', + outputCopied: 'Output copied', startingTestForAccount: 'Starting test for account: {name}', testAccountTypeLabel: 'Account type: {type}', selectTestModel: 'Select Test Model', @@ -1721,7 +1740,6 @@ export default { allAccounts: 'All Accounts', allGroups: 'All Groups', allTypes: 'All Types', - allBillingTypes: 'All Billing', inputCost: 'Input Cost', outputCost: 'Output Cost', cacheCreationCost: 'Cache Creation Cost', @@ -1730,7 +1748,8 @@ export default { outputTokens: 'Output Tokens', cacheCreationTokens: 'Cache Creation Tokens', cacheReadTokens: 'Cache Read Tokens', - failedToLoad: 'Failed to load usage records' + failedToLoad: 'Failed to load usage records', + ipAddress: 'IP' }, // Settings @@ -1756,6 +1775,26 @@ export default { cloudflareDashboard: 'Cloudflare Dashboard', secretKeyHint: 'Server-side verification key (keep this secret)', secretKeyConfiguredHint: 'Secret key configured. Leave empty to keep the current value.' }, + linuxdo: { + title: 'LinuxDo Connect Login', + description: 'Configure LinuxDo Connect OAuth for Sub2API end-user login', + enable: 'Enable LinuxDo Login', + enableHint: 'Show LinuxDo login on the login/register pages', + clientId: 'Client ID', + clientIdPlaceholder: 'e.g., hprJ5pC3...', + clientIdHint: 'Get this from Connect.Linux.Do', + clientSecret: 'Client Secret', + clientSecretPlaceholder: '********', + clientSecretHint: 'Used by backend to exchange tokens (keep it secret)', + clientSecretConfiguredPlaceholder: '********', + clientSecretConfiguredHint: 'Secret configured. Leave empty to keep the current value.', + redirectUrl: 'Redirect URL', + redirectUrlPlaceholder: 'https://your-domain.com/api/v1/auth/oauth/linuxdo/callback', + redirectUrlHint: + 'Must match the redirect URL configured in Connect.Linux.Do (must be an absolute http(s) URL)', + quickSetCopy: 'Generate & Copy (current site)', + redirectUrlSetAndCopied: 'Redirect URL generated and copied to clipboard' + }, defaults: { title: 'Default User Settings', description: 'Default values for new users', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 79ddf6cc..fc1e6fff 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -227,6 +227,15 @@ export default { sendingCode: '发送中...', clickToResend: '点击重新发送验证码', resendCode: '重新发送验证码', + linuxdo: { + signIn: '使用 Linux.do 登录', + orContinue: '或使用邮箱密码继续', + callbackTitle: '正在完成登录', + callbackProcessing: '正在验证登录信息,请稍候...', + callbackHint: '如果页面未自动跳转,请返回登录页重试。', + callbackMissingToken: '登录信息缺失,请返回重试。', + backToLogin: '返回登录' + }, oauth: { code: '授权码', state: '状态', @@ -358,6 +367,14 @@ export default { customKeyTooShort: '自定义密钥至少需要16个字符', customKeyInvalidChars: '自定义密钥只能包含字母、数字、下划线和连字符', customKeyRequired: '请输入自定义密钥', + ipRestriction: 'IP 限制', + ipWhitelist: 'IP 白名单', + ipWhitelistPlaceholder: '192.168.1.100\n10.0.0.0/8', + ipWhitelistHint: '每行一个 IP 或 CIDR,设置后仅允许这些 IP 使用此密钥', + ipBlacklist: 'IP 黑名单', + ipBlacklistPlaceholder: '1.2.3.4\n5.6.0.0/16', + ipBlacklistHint: '每行一个 IP 或 CIDR,这些 IP 将被禁止使用此密钥', + ipRestrictionEnabled: '已配置 IP 限制', ccSwitchNotInstalled: 'CC-Switch 未安装或协议处理程序未注册。请先安装 CC-Switch 或手动复制 API 密钥。', ccsClientSelect: { title: '选择客户端', @@ -418,9 +435,6 @@ export default { exportFailed: '使用数据导出失败', exportExcelSuccess: '使用数据导出成功(Excel格式)', exportExcelFailed: '使用数据导出失败', - billingType: '消费类型', - balance: '余额', - subscription: '订阅', imageUnit: '张', userAgent: 'User-Agent' }, @@ -1212,12 +1226,16 @@ export default { accountCreatedSuccess: '账号添加成功', accountUpdatedSuccess: '账号更新成功', accountDeletedSuccess: '账号删除成功', + bulkSchedulableEnabled: '成功启用 {count} 个账号的调度', + bulkSchedulableDisabled: '成功停止 {count} 个账号的调度', bulkActions: { selected: '已选择 {count} 个账号', selectCurrentPage: '本页全选', clear: '清除选择', edit: '批量编辑账号', - delete: '批量删除' + delete: '批量删除', + enableScheduling: '批量启用调度', + disableScheduling: '批量停止调度' }, bulkEdit: { title: '批量编辑账号', @@ -1601,6 +1619,7 @@ export default { startTest: '开始测试', retry: '重试', copyOutput: '复制输出', + outputCopied: '输出已复制', startingTestForAccount: '开始测试账号:{name}', testAccountTypeLabel: '账号类型:{type}', selectTestModel: '选择测试模型', @@ -1866,7 +1885,6 @@ export default { allAccounts: '全部账户', allGroups: '全部分组', allTypes: '全部类型', - allBillingTypes: '全部计费', inputCost: '输入成本', outputCost: '输出成本', cacheCreationCost: '缓存创建成本', @@ -1875,7 +1893,8 @@ export default { outputTokens: '输出 Token', cacheCreationTokens: '缓存创建 Token', cacheReadTokens: '缓存读取 Token', - failedToLoad: '加载使用记录失败' + failedToLoad: '加载使用记录失败', + ipAddress: 'IP' }, // Settings @@ -1901,6 +1920,25 @@ export default { cloudflareDashboard: 'Cloudflare Dashboard', secretKeyHint: '服务端验证密钥(请保密)', secretKeyConfiguredHint: '密钥已配置,留空以保留当前值。' }, + linuxdo: { + title: 'LinuxDo Connect 登录', + description: '配置 LinuxDo Connect OAuth,用于 Sub2API 用户登录', + enable: '启用 LinuxDo 登录', + enableHint: '在登录/注册页面显示 LinuxDo 登录入口', + clientId: 'Client ID', + clientIdPlaceholder: '例如:hprJ5pC3...', + clientIdHint: '从 Connect.Linux.Do 后台获取', + clientSecret: 'Client Secret', + clientSecretPlaceholder: '********', + clientSecretHint: '用于后端交换 token(请保密)', + clientSecretConfiguredPlaceholder: '********', + clientSecretConfiguredHint: '密钥已配置,留空以保留当前值。', + redirectUrl: '回调地址(Redirect URL)', + redirectUrlPlaceholder: 'https://your-domain.com/api/v1/auth/oauth/linuxdo/callback', + redirectUrlHint: '需与 Connect.Linux.Do 中配置的回调地址一致(必须是 http(s) 完整 URL)', + quickSetCopy: '使用当前站点生成并复制', + redirectUrlSetAndCopied: '已使用当前站点生成回调地址并复制到剪贴板' + }, defaults: { title: '用户默认设置', description: '新用户的默认值', diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts index 48a6f0fd..238982ef 100644 --- a/frontend/src/router/index.ts +++ b/frontend/src/router/index.ts @@ -67,6 +67,15 @@ const routes: RouteRecordRaw[] = [ title: 'OAuth Callback' } }, + { + path: '/auth/linuxdo/callback', + name: 'LinuxDoOAuthCallback', + component: () => import('@/views/auth/LinuxDoCallbackView.vue'), + meta: { + requiresAuth: false, + title: 'LinuxDo OAuth Callback' + } + }, // ==================== User Routes ==================== { diff --git a/frontend/src/stores/app.ts b/frontend/src/stores/app.ts index cfc9d677..ce7081e1 100644 --- a/frontend/src/stores/app.ts +++ b/frontend/src/stores/app.ts @@ -30,6 +30,7 @@ export const useAppStore = defineStore('app', () => { const contactInfo = ref('') const apiBaseUrl = ref('') const docUrl = ref('') + const cachedPublicSettings = ref(null) // Version cache state const versionLoaded = ref(false) @@ -285,6 +286,9 @@ export const useAppStore = defineStore('app', () => { async function fetchPublicSettings(force = false): Promise { // Return cached data if available and not forcing refresh if (publicSettingsLoaded.value && !force) { + if (cachedPublicSettings.value) { + return { ...cachedPublicSettings.value } + } return { registration_enabled: false, email_verify_enabled: false, @@ -296,6 +300,7 @@ export const useAppStore = defineStore('app', () => { api_base_url: apiBaseUrl.value, contact_info: contactInfo.value, doc_url: docUrl.value, + linuxdo_oauth_enabled: false, version: siteVersion.value } } @@ -308,6 +313,7 @@ export const useAppStore = defineStore('app', () => { publicSettingsLoading.value = true try { const data = await fetchPublicSettingsAPI() + cachedPublicSettings.value = data siteName.value = data.site_name || 'Sub2API' siteLogo.value = data.site_logo || '' siteVersion.value = data.version || '' @@ -329,6 +335,7 @@ export const useAppStore = defineStore('app', () => { */ function clearPublicSettingsCache(): void { publicSettingsLoaded.value = false + cachedPublicSettings.value = null } // ==================== Return Store API ==================== diff --git a/frontend/src/stores/auth.ts b/frontend/src/stores/auth.ts index 27faaf4b..4076e154 100644 --- a/frontend/src/stores/auth.ts +++ b/frontend/src/stores/auth.ts @@ -159,6 +159,27 @@ export const useAuthStore = defineStore('auth', () => { } } + /** + * 直接设置 token(用于 OAuth/SSO 回调),并加载当前用户信息。 + * @param newToken - 后端签发的 JWT access token + */ + async function setToken(newToken: string): Promise { + // Clear any previous state first (avoid mixing sessions) + clearAuth() + + token.value = newToken + localStorage.setItem(AUTH_TOKEN_KEY, newToken) + + try { + const userData = await refreshUser() + startAutoRefresh() + return userData + } catch (error) { + clearAuth() + throw error + } + } + /** * User logout * Clears all authentication state and persisted data @@ -233,6 +254,7 @@ export const useAuthStore = defineStore('auth', () => { // Actions login, register, + setToken, logout, checkAuth, refreshUser diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index eaea24be..bc858c6a 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -73,6 +73,7 @@ export interface PublicSettings { api_base_url: string contact_info: string doc_url: string + linuxdo_oauth_enabled: boolean version: string } @@ -278,6 +279,8 @@ export interface ApiKey { name: string group_id: number | null status: 'active' | 'inactive' + ip_whitelist: string[] + ip_blacklist: string[] created_at: string updated_at: string group?: Group @@ -287,12 +290,16 @@ export interface CreateApiKeyRequest { name: string group_id?: number | null custom_key?: string // Optional custom API Key + ip_whitelist?: string[] + ip_blacklist?: string[] } export interface UpdateApiKeyRequest { name?: string group_id?: number | null status?: 'active' | 'inactive' + ip_whitelist?: string[] + ip_blacklist?: string[] } export interface CreateGroupRequest { @@ -559,9 +566,6 @@ export interface UpdateProxyRequest { export type RedeemCodeType = 'balance' | 'concurrency' | 'subscription' -// 消费类型: 0=钱包余额, 1=订阅套餐 -export type BillingType = 0 | 1 - export interface UsageLog { id: number user_id: number @@ -588,7 +592,6 @@ export interface UsageLog { actual_cost: number rate_multiplier: number - billing_type: BillingType stream: boolean duration_ms: number first_token_ms: number | null @@ -600,6 +603,9 @@ export interface UsageLog { // User-Agent user_agent: string | null + // IP 地址(仅管理员可见) + ip_address: string | null + created_at: string user?: User @@ -829,7 +835,6 @@ export interface UsageQueryParams { group_id?: number model?: string stream?: boolean - billing_type?: number start_date?: string end_date?: string } diff --git a/frontend/src/views/admin/AccountsView.vue b/frontend/src/views/admin/AccountsView.vue index 0ca22a76..79c6072c 100644 --- a/frontend/src/views/admin/AccountsView.vue +++ b/frontend/src/views/admin/AccountsView.vue @@ -7,7 +7,7 @@ v-model:searchQuery="params.search" :filters="params" @update:filters="(newFilters) => Object.assign(params, newFilters)" - @change="reload" + @change="debouncedReload" @update:searchQuery="debouncedReload" /> - + @@ -175,7 +175,7 @@ const statsAcc = ref(null) const togglingSchedulable = ref(null) const menu = reactive<{show:boolean, acc:Account|null, pos:{top:number, left:number}|null}>({ show: false, acc: null, pos: null }) -const { items: accounts, loading, params, pagination, load, reload, debouncedReload, handlePageChange } = useTableLoader({ +const { items: accounts, loading, params, pagination, load, reload, debouncedReload, handlePageChange, handlePageSizeChange } = useTableLoader({ fetchFn: adminAPI.accounts.list, initialParams: { platform: '', type: '', status: '', search: '' } }) @@ -209,6 +209,21 @@ const openMenu = (a: Account, e: MouseEvent) => { menu.acc = a; menu.pos = { top const toggleSel = (id: number) => { const i = selIds.value.indexOf(id); if(i === -1) selIds.value.push(id); else selIds.value.splice(i, 1) } const selectPage = () => { selIds.value = [...new Set([...selIds.value, ...accounts.value.map(a => a.id)])] } const handleBulkDelete = async () => { if(!confirm(t('common.confirm'))) return; try { await Promise.all(selIds.value.map(id => adminAPI.accounts.delete(id))); selIds.value = []; reload() } catch (error) { console.error('Failed to bulk delete accounts:', error) } } +const handleBulkToggleSchedulable = async (schedulable: boolean) => { + const count = selIds.value.length + try { + const result = await adminAPI.accounts.bulkUpdate(selIds.value, { schedulable }); + const message = schedulable + ? t('admin.accounts.bulkSchedulableEnabled', { count: result.success || count }) + : t('admin.accounts.bulkSchedulableDisabled', { count: result.success || count }); + appStore.showSuccess(message); + selIds.value = []; + reload() + } catch (error) { + console.error('Failed to bulk toggle schedulable:', error); + appStore.showError(t('common.error')) + } +} const handleBulkUpdated = () => { showBulkEdit.value = false; selIds.value = []; reload() } const closeTestModal = () => { showTest.value = false; testingAcc.value = null } const closeStatsModal = () => { showStats.value = false; statsAcc.value = null } diff --git a/frontend/src/views/admin/GroupsView.vue b/frontend/src/views/admin/GroupsView.vue index f7ef2339..d8322154 100644 --- a/frontend/src/views/admin/GroupsView.vue +++ b/frontend/src/views/admin/GroupsView.vue @@ -16,6 +16,7 @@ type="text" :placeholder="t('admin.groups.searchGroups')" class="input pl-10" + @input="handleSearch" /> +

+ {{ t('admin.settings.linuxdo.clientIdHint') }} +

+ + +
+ + +

+ {{ + form.linuxdo_connect_client_secret_configured + ? t('admin.settings.linuxdo.clientSecretConfiguredHint') + : t('admin.settings.linuxdo.clientSecretHint') + }} +

+
+ +
+ + +
+ + + {{ linuxdoRedirectUrlSuggestion }} + +
+

+ {{ t('admin.settings.linuxdo.redirectUrlHint') }} +

+
+ + + + +
@@ -692,17 +792,19 @@ diff --git a/frontend/src/views/auth/LinuxDoCallbackView.vue b/frontend/src/views/auth/LinuxDoCallbackView.vue new file mode 100644 index 00000000..c6f93e6b --- /dev/null +++ b/frontend/src/views/auth/LinuxDoCallbackView.vue @@ -0,0 +1,119 @@ + + + + + + diff --git a/frontend/src/views/auth/LoginView.vue b/frontend/src/views/auth/LoginView.vue index 903db100..6e6cee27 100644 --- a/frontend/src/views/auth/LoginView.vue +++ b/frontend/src/views/auth/LoginView.vue @@ -11,6 +11,9 @@

+ + +
@@ -157,6 +160,7 @@ import { ref, reactive, onMounted } from 'vue' import { useRouter } from 'vue-router' import { useI18n } from 'vue-i18n' import { AuthLayout } from '@/components/layout' +import LinuxDoOAuthSection from '@/components/auth/LinuxDoOAuthSection.vue' import Icon from '@/components/icons/Icon.vue' import TurnstileWidget from '@/components/TurnstileWidget.vue' import { useAuthStore, useAppStore } from '@/stores' @@ -179,6 +183,7 @@ const showPassword = ref(false) // Public settings const turnstileEnabled = ref(false) const turnstileSiteKey = ref('') +const linuxdoOAuthEnabled = ref(false) // Turnstile const turnstileRef = ref | null>(null) @@ -210,6 +215,7 @@ onMounted(async () => { const settings = await getPublicSettings() turnstileEnabled.value = settings.turnstile_enabled turnstileSiteKey.value = settings.turnstile_site_key || '' + linuxdoOAuthEnabled.value = settings.linuxdo_oauth_enabled } catch (error) { console.error('Failed to load public settings:', error) } diff --git a/frontend/src/views/auth/RegisterView.vue b/frontend/src/views/auth/RegisterView.vue index 9f3555d4..e48120b5 100644 --- a/frontend/src/views/auth/RegisterView.vue +++ b/frontend/src/views/auth/RegisterView.vue @@ -11,6 +11,9 @@

+ + +
(false) const turnstileEnabled = ref(false) const turnstileSiteKey = ref('') const siteName = ref('Sub2API') +const linuxdoOAuthEnabled = ref(false) // Turnstile const turnstileRef = ref | null>(null) @@ -233,6 +238,7 @@ onMounted(async () => { turnstileEnabled.value = settings.turnstile_enabled turnstileSiteKey.value = settings.turnstile_site_key || '' siteName.value = settings.site_name || 'Sub2API' + linuxdoOAuthEnabled.value = settings.linuxdo_oauth_enabled } catch (error) { console.error('Failed to load public settings:', error) } finally { diff --git a/frontend/src/views/user/KeysView.vue b/frontend/src/views/user/KeysView.vue index 6d4e3c96..0787c467 100644 --- a/frontend/src/views/user/KeysView.vue +++ b/frontend/src/views/user/KeysView.vue @@ -46,8 +46,17 @@
-