mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-10 18:14:48 +08:00
Compare commits
24 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3cc407bc0e | ||
|
|
00a0a12138 | ||
|
|
b08767a4f9 | ||
|
|
ac6bde7a98 | ||
|
|
d2d41d68dd | ||
|
|
944b7f7617 | ||
|
|
53825eb073 | ||
|
|
1a7f49513f | ||
|
|
885a2ce7ef | ||
|
|
14ba80a0af | ||
|
|
5fa22fdf82 | ||
|
|
bcaae2eb91 | ||
|
|
767a41e263 | ||
|
|
252d6c5301 | ||
|
|
7a4e65ad4b | ||
|
|
a582aa89a9 | ||
|
|
acefa1da12 | ||
|
|
a88698f3fc | ||
|
|
ebc6755b33 | ||
|
|
c8eff34388 | ||
|
|
f19b03825b | ||
|
|
b43ee62947 | ||
|
|
106b20cdbf | ||
|
|
c069b3b1e8 |
28
README.md
28
README.md
@@ -150,14 +150,14 @@ mkdir -p sub2api-deploy && cd sub2api-deploy
|
|||||||
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
|
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
|
||||||
|
|
||||||
# Start services
|
# Start services
|
||||||
docker-compose -f docker-compose.local.yml up -d
|
docker-compose up -d
|
||||||
|
|
||||||
# View logs
|
# View logs
|
||||||
docker-compose -f docker-compose.local.yml logs -f sub2api
|
docker-compose logs -f sub2api
|
||||||
```
|
```
|
||||||
|
|
||||||
**What the script does:**
|
**What the script does:**
|
||||||
- Downloads `docker-compose.local.yml` and `.env.example`
|
- Downloads `docker-compose.local.yml` (saved as `docker-compose.yml`) and `.env.example`
|
||||||
- Generates secure credentials (JWT_SECRET, TOTP_ENCRYPTION_KEY, POSTGRES_PASSWORD)
|
- Generates secure credentials (JWT_SECRET, TOTP_ENCRYPTION_KEY, POSTGRES_PASSWORD)
|
||||||
- Creates `.env` file with auto-generated secrets
|
- Creates `.env` file with auto-generated secrets
|
||||||
- Creates data directories (uses local directories for easy backup/migration)
|
- Creates data directories (uses local directories for easy backup/migration)
|
||||||
@@ -522,6 +522,28 @@ sub2api/
|
|||||||
└── install.sh # One-click installation script
|
└── install.sh # One-click installation script
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Disclaimer
|
||||||
|
|
||||||
|
> **Please read carefully before using this project:**
|
||||||
|
>
|
||||||
|
> :rotating_light: **Terms of Service Risk**: Using this project may violate Anthropic's Terms of Service. Please read Anthropic's user agreement carefully before use. All risks arising from the use of this project are borne solely by the user.
|
||||||
|
>
|
||||||
|
> :book: **Disclaimer**: This project is for technical learning and research purposes only. The author assumes no responsibility for account suspension, service interruption, or any other losses caused by the use of this project.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Star History
|
||||||
|
|
||||||
|
<a href="https://star-history.com/#Wei-Shaw/sub2api&Date">
|
||||||
|
<picture>
|
||||||
|
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date&theme=dark" />
|
||||||
|
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||||
|
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||||
|
</picture>
|
||||||
|
</a>
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
MIT License
|
MIT License
|
||||||
|
|||||||
28
README_CN.md
28
README_CN.md
@@ -154,14 +154,14 @@ mkdir -p sub2api-deploy && cd sub2api-deploy
|
|||||||
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
|
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
|
||||||
|
|
||||||
# 启动服务
|
# 启动服务
|
||||||
docker-compose -f docker-compose.local.yml up -d
|
docker-compose up -d
|
||||||
|
|
||||||
# 查看日志
|
# 查看日志
|
||||||
docker-compose -f docker-compose.local.yml logs -f sub2api
|
docker-compose logs -f sub2api
|
||||||
```
|
```
|
||||||
|
|
||||||
**脚本功能:**
|
**脚本功能:**
|
||||||
- 下载 `docker-compose.local.yml` 和 `.env.example`
|
- 下载 `docker-compose.local.yml`(本地保存为 `docker-compose.yml`)和 `.env.example`
|
||||||
- 自动生成安全凭证(JWT_SECRET、TOTP_ENCRYPTION_KEY、POSTGRES_PASSWORD)
|
- 自动生成安全凭证(JWT_SECRET、TOTP_ENCRYPTION_KEY、POSTGRES_PASSWORD)
|
||||||
- 创建 `.env` 文件并填充自动生成的密钥
|
- 创建 `.env` 文件并填充自动生成的密钥
|
||||||
- 创建数据目录(使用本地目录,便于备份和迁移)
|
- 创建数据目录(使用本地目录,便于备份和迁移)
|
||||||
@@ -588,6 +588,28 @@ sub2api/
|
|||||||
└── install.sh # 一键安装脚本
|
└── install.sh # 一键安装脚本
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## 免责声明
|
||||||
|
|
||||||
|
> **使用本项目前请仔细阅读:**
|
||||||
|
>
|
||||||
|
> :rotating_light: **服务条款风险**: 使用本项目可能违反 Anthropic 的服务条款。请在使用前仔细阅读 Anthropic 的用户协议,使用本项目的一切风险由用户自行承担。
|
||||||
|
>
|
||||||
|
> :book: **免责声明**: 本项目仅供技术学习和研究使用,作者不对因使用本项目导致的账户封禁、服务中断或其他损失承担任何责任。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Star History
|
||||||
|
|
||||||
|
<a href="https://star-history.com/#Wei-Shaw/sub2api&Date">
|
||||||
|
<picture>
|
||||||
|
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date&theme=dark" />
|
||||||
|
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||||
|
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||||
|
</picture>
|
||||||
|
</a>
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## 许可证
|
## 许可证
|
||||||
|
|
||||||
MIT License
|
MIT License
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ func main() {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
userRepo := repository.NewUserRepository(client, sqlDB)
|
userRepo := repository.NewUserRepository(client, sqlDB)
|
||||||
authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
|
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
|
||||||
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||||
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
|
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
|
||||||
authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
|
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
|
||||||
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
|
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
|
||||||
redeemCache := repository.NewRedeemCache(redisClient)
|
redeemCache := repository.NewRedeemCache(redisClient)
|
||||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||||
@@ -104,7 +104,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
proxyRepository := repository.NewProxyRepository(client, db)
|
proxyRepository := repository.NewProxyRepository(client, db)
|
||||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||||
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
||||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService)
|
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository)
|
||||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||||
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||||
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
||||||
|
|||||||
@@ -8,6 +8,9 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"log/slog"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -292,6 +295,8 @@ func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enrichCredentialsFromIDToken(&item)
|
||||||
|
|
||||||
accountInput := &service.CreateAccountInput{
|
accountInput := &service.CreateAccountInput{
|
||||||
Name: item.Name,
|
Name: item.Name,
|
||||||
Notes: item.Notes,
|
Notes: item.Notes,
|
||||||
@@ -535,6 +540,57 @@ func defaultProxyName(name string) string {
|
|||||||
return name
|
return name
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// enrichCredentialsFromIDToken performs best-effort extraction of user info fields
|
||||||
|
// (email, plan_type, chatgpt_account_id, etc.) from id_token in credentials.
|
||||||
|
// Only applies to OpenAI/Sora OAuth accounts. Skips expired token errors silently.
|
||||||
|
// Existing credential values are never overwritten — only missing fields are filled.
|
||||||
|
func enrichCredentialsFromIDToken(item *DataAccount) {
|
||||||
|
if item.Credentials == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Only enrich OpenAI/Sora OAuth accounts
|
||||||
|
platform := strings.ToLower(strings.TrimSpace(item.Platform))
|
||||||
|
if platform != service.PlatformOpenAI && platform != service.PlatformSora {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.ToLower(strings.TrimSpace(item.Type)) != service.AccountTypeOAuth {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
idToken, _ := item.Credentials["id_token"].(string)
|
||||||
|
if strings.TrimSpace(idToken) == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeIDToken skips expiry validation — safe for imported data
|
||||||
|
claims, err := openai.DecodeIDToken(idToken)
|
||||||
|
if err != nil {
|
||||||
|
slog.Debug("import_enrich_id_token_decode_failed", "account", item.Name, "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userInfo := claims.GetUserInfo()
|
||||||
|
if userInfo == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fill missing fields only (never overwrite existing values)
|
||||||
|
setIfMissing := func(key, value string) {
|
||||||
|
if value == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if existing, _ := item.Credentials[key].(string); existing == "" {
|
||||||
|
item.Credentials[key] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
setIfMissing("email", userInfo.Email)
|
||||||
|
setIfMissing("plan_type", userInfo.PlanType)
|
||||||
|
setIfMissing("chatgpt_account_id", userInfo.ChatGPTAccountID)
|
||||||
|
setIfMissing("chatgpt_user_id", userInfo.ChatGPTUserID)
|
||||||
|
setIfMissing("organization_id", userInfo.OrganizationID)
|
||||||
|
}
|
||||||
|
|
||||||
func normalizeProxyStatus(status string) string {
|
func normalizeProxyStatus(status string) string {
|
||||||
normalized := strings.TrimSpace(strings.ToLower(status))
|
normalized := strings.TrimSpace(strings.ToLower(status))
|
||||||
switch normalized {
|
switch normalized {
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -18,6 +19,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||||
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
@@ -751,52 +753,31 @@ func (h *AccountHandler) PreviewFromCRS(c *gin.Context) {
|
|||||||
response.Success(c, result)
|
response.Success(c, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Refresh handles refreshing account credentials
|
// refreshSingleAccount refreshes credentials for a single OAuth account.
|
||||||
// POST /api/v1/admin/accounts/:id/refresh
|
// Returns (updatedAccount, warning, error) where warning is used for Antigravity ProjectIDMissing scenario.
|
||||||
func (h *AccountHandler) Refresh(c *gin.Context) {
|
func (h *AccountHandler) refreshSingleAccount(ctx context.Context, account *service.Account) (*service.Account, string, error) {
|
||||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
response.BadRequest(c, "Invalid account ID")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get account
|
|
||||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
|
||||||
if err != nil {
|
|
||||||
response.NotFound(c, "Account not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only refresh OAuth-based accounts (oauth and setup-token)
|
|
||||||
if !account.IsOAuth() {
|
if !account.IsOAuth() {
|
||||||
response.BadRequest(c, "Cannot refresh non-OAuth account credentials")
|
return nil, "", infraerrors.BadRequest("NOT_OAUTH", "cannot refresh non-OAuth account")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var newCredentials map[string]any
|
var newCredentials map[string]any
|
||||||
|
|
||||||
if account.IsOpenAI() {
|
if account.IsOpenAI() {
|
||||||
// Use OpenAI OAuth service to refresh token
|
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(ctx, account)
|
||||||
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
return nil, "", err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build new credentials from token info
|
|
||||||
newCredentials = h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
newCredentials = h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||||
|
|
||||||
// Preserve non-token settings from existing credentials
|
|
||||||
for k, v := range account.Credentials {
|
for k, v := range account.Credentials {
|
||||||
if _, exists := newCredentials[k]; !exists {
|
if _, exists := newCredentials[k]; !exists {
|
||||||
newCredentials[k] = v
|
newCredentials[k] = v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if account.Platform == service.PlatformGemini {
|
} else if account.Platform == service.PlatformGemini {
|
||||||
tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(ctx, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
|
return nil, "", fmt.Errorf("failed to refresh credentials: %w", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
newCredentials = h.geminiOAuthService.BuildAccountCredentials(tokenInfo)
|
newCredentials = h.geminiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||||
@@ -806,10 +787,9 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if account.Platform == service.PlatformAntigravity {
|
} else if account.Platform == service.PlatformAntigravity {
|
||||||
tokenInfo, err := h.antigravityOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
tokenInfo, err := h.antigravityOAuthService.RefreshAccountToken(ctx, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
return nil, "", err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
newCredentials = h.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
newCredentials = h.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||||
@@ -828,37 +808,27 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 如果 project_id 获取失败,更新凭证但不标记为 error
|
// 如果 project_id 获取失败,更新凭证但不标记为 error
|
||||||
// LoadCodeAssist 失败可能是临时网络问题,给它机会在下次自动刷新时重试
|
|
||||||
if tokenInfo.ProjectIDMissing {
|
if tokenInfo.ProjectIDMissing {
|
||||||
// 先更新凭证(token 本身刷新成功了)
|
updatedAccount, updateErr := h.adminService.UpdateAccount(ctx, account.ID, &service.UpdateAccountInput{
|
||||||
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
|
||||||
Credentials: newCredentials,
|
Credentials: newCredentials,
|
||||||
})
|
})
|
||||||
if updateErr != nil {
|
if updateErr != nil {
|
||||||
response.InternalError(c, "Failed to update credentials: "+updateErr.Error())
|
return nil, "", fmt.Errorf("failed to update credentials: %w", updateErr)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
// 不标记为 error,只返回警告信息
|
return updatedAccount, "missing_project_id_temporary", nil
|
||||||
response.Success(c, gin.H{
|
|
||||||
"message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)",
|
|
||||||
"warning": "missing_project_id_temporary",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 成功获取到 project_id,如果之前是 missing_project_id 错误则清除
|
// 成功获取到 project_id,如果之前是 missing_project_id 错误则清除
|
||||||
if account.Status == service.StatusError && strings.Contains(account.ErrorMessage, "missing_project_id:") {
|
if account.Status == service.StatusError && strings.Contains(account.ErrorMessage, "missing_project_id:") {
|
||||||
if _, clearErr := h.adminService.ClearAccountError(c.Request.Context(), accountID); clearErr != nil {
|
if _, clearErr := h.adminService.ClearAccountError(ctx, account.ID); clearErr != nil {
|
||||||
response.InternalError(c, "Failed to clear account error: "+clearErr.Error())
|
return nil, "", fmt.Errorf("failed to clear account error: %w", clearErr)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Use Anthropic/Claude OAuth service to refresh token
|
// Use Anthropic/Claude OAuth service to refresh token
|
||||||
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
|
tokenInfo, err := h.oauthService.RefreshAccountToken(ctx, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
return nil, "", err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy existing credentials to preserve non-token settings (e.g., intercept_warmup_requests)
|
// Copy existing credentials to preserve non-token settings (e.g., intercept_warmup_requests)
|
||||||
@@ -880,20 +850,51 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
updatedAccount, err := h.adminService.UpdateAccount(ctx, account.ID, &service.UpdateAccountInput{
|
||||||
Credentials: newCredentials,
|
Credentials: newCredentials,
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 刷新成功后,清除 token 缓存,确保下次请求使用新 token
|
||||||
|
if h.tokenCacheInvalidator != nil {
|
||||||
|
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(ctx, updatedAccount); invalidateErr != nil {
|
||||||
|
log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", updatedAccount.ID, invalidateErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return updatedAccount, "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refresh handles refreshing account credentials
|
||||||
|
// POST /api/v1/admin/accounts/:id/refresh
|
||||||
|
func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||||
|
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid account ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get account
|
||||||
|
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||||
|
if err != nil {
|
||||||
|
response.NotFound(c, "Account not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedAccount, warning, err := h.refreshSingleAccount(c.Request.Context(), account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 刷新成功后,清除 token 缓存,确保下次请求使用新 token
|
if warning == "missing_project_id_temporary" {
|
||||||
if h.tokenCacheInvalidator != nil {
|
response.Success(c, gin.H{
|
||||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), updatedAccount); invalidateErr != nil {
|
"message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)",
|
||||||
// 缓存失效失败只记录日志,不影响主流程
|
"warning": "missing_project_id_temporary",
|
||||||
_ = c.Error(invalidateErr)
|
})
|
||||||
}
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), updatedAccount))
|
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), updatedAccount))
|
||||||
@@ -949,14 +950,175 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
|
|||||||
// 这解决了管理员重置账号状态后,旧的失效 token 仍在缓存中导致立即再次 401 的问题
|
// 这解决了管理员重置账号状态后,旧的失效 token 仍在缓存中导致立即再次 401 的问题
|
||||||
if h.tokenCacheInvalidator != nil && account.IsOAuth() {
|
if h.tokenCacheInvalidator != nil && account.IsOAuth() {
|
||||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), account); invalidateErr != nil {
|
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), account); invalidateErr != nil {
|
||||||
// 缓存失效失败只记录日志,不影响主流程
|
log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", accountID, invalidateErr)
|
||||||
_ = c.Error(invalidateErr)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BatchClearError handles batch clearing account errors
|
||||||
|
// POST /api/v1/admin/accounts/batch-clear-error
|
||||||
|
func (h *AccountHandler) BatchClearError(c *gin.Context) {
|
||||||
|
var req struct {
|
||||||
|
AccountIDs []int64 `json:"account_ids"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(req.AccountIDs) == 0 {
|
||||||
|
response.BadRequest(c, "account_ids is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
|
const maxConcurrency = 10
|
||||||
|
g, gctx := errgroup.WithContext(ctx)
|
||||||
|
g.SetLimit(maxConcurrency)
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
var successCount, failedCount int
|
||||||
|
var errors []gin.H
|
||||||
|
|
||||||
|
// 注意:所有 goroutine 必须 return nil,避免 errgroup cancel 其他并发任务
|
||||||
|
for _, id := range req.AccountIDs {
|
||||||
|
accountID := id // 闭包捕获
|
||||||
|
g.Go(func() error {
|
||||||
|
account, err := h.adminService.ClearAccountError(gctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
mu.Lock()
|
||||||
|
failedCount++
|
||||||
|
errors = append(errors, gin.H{
|
||||||
|
"account_id": accountID,
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清除错误后,同时清除 token 缓存
|
||||||
|
if h.tokenCacheInvalidator != nil && account.IsOAuth() {
|
||||||
|
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(gctx, account); invalidateErr != nil {
|
||||||
|
log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", accountID, invalidateErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
successCount++
|
||||||
|
mu.Unlock()
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := g.Wait(); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{
|
||||||
|
"total": len(req.AccountIDs),
|
||||||
|
"success": successCount,
|
||||||
|
"failed": failedCount,
|
||||||
|
"errors": errors,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchRefresh handles batch refreshing account credentials
|
||||||
|
// POST /api/v1/admin/accounts/batch-refresh
|
||||||
|
func (h *AccountHandler) BatchRefresh(c *gin.Context) {
|
||||||
|
var req struct {
|
||||||
|
AccountIDs []int64 `json:"account_ids"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(req.AccountIDs) == 0 {
|
||||||
|
response.BadRequest(c, "account_ids is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
|
accounts, err := h.adminService.GetAccountsByIDs(ctx, req.AccountIDs)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 建立已获取账号的 ID 集合,检测缺失的 ID
|
||||||
|
foundIDs := make(map[int64]bool, len(accounts))
|
||||||
|
for _, acc := range accounts {
|
||||||
|
if acc != nil {
|
||||||
|
foundIDs[acc.ID] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const maxConcurrency = 10
|
||||||
|
g, gctx := errgroup.WithContext(ctx)
|
||||||
|
g.SetLimit(maxConcurrency)
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
var successCount, failedCount int
|
||||||
|
var errors []gin.H
|
||||||
|
var warnings []gin.H
|
||||||
|
|
||||||
|
// 将不存在的账号 ID 标记为失败
|
||||||
|
for _, id := range req.AccountIDs {
|
||||||
|
if !foundIDs[id] {
|
||||||
|
failedCount++
|
||||||
|
errors = append(errors, gin.H{
|
||||||
|
"account_id": id,
|
||||||
|
"error": "account not found",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 注意:所有 goroutine 必须 return nil,避免 errgroup cancel 其他并发任务
|
||||||
|
for _, account := range accounts {
|
||||||
|
acc := account // 闭包捕获
|
||||||
|
if acc == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
g.Go(func() error {
|
||||||
|
_, warning, err := h.refreshSingleAccount(gctx, acc)
|
||||||
|
mu.Lock()
|
||||||
|
if err != nil {
|
||||||
|
failedCount++
|
||||||
|
errors = append(errors, gin.H{
|
||||||
|
"account_id": acc.ID,
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
successCount++
|
||||||
|
if warning != "" {
|
||||||
|
warnings = append(warnings, gin.H{
|
||||||
|
"account_id": acc.ID,
|
||||||
|
"warning": warning,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mu.Unlock()
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := g.Wait(); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{
|
||||||
|
"total": len(req.AccountIDs),
|
||||||
|
"success": successCount,
|
||||||
|
"failed": failedCount,
|
||||||
|
"errors": errors,
|
||||||
|
"warnings": warnings,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// BatchCreate handles batch creating accounts
|
// BatchCreate handles batch creating accounts
|
||||||
// POST /api/v1/admin/accounts/batch
|
// POST /api/v1/admin/accounts/batch
|
||||||
func (h *AccountHandler) BatchCreate(c *gin.Context) {
|
func (h *AccountHandler) BatchCreate(c *gin.Context) {
|
||||||
|
|||||||
@@ -1405,6 +1405,61 @@ func (h *SettingHandler) UpdateRectifierSettings(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetBetaPolicySettings 获取 Beta 策略配置
|
||||||
|
// GET /api/v1/admin/settings/beta-policy
|
||||||
|
func (h *SettingHandler) GetBetaPolicySettings(c *gin.Context) {
|
||||||
|
settings, err := h.settingService.GetBetaPolicySettings(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rules := make([]dto.BetaPolicyRule, len(settings.Rules))
|
||||||
|
for i, r := range settings.Rules {
|
||||||
|
rules[i] = dto.BetaPolicyRule(r)
|
||||||
|
}
|
||||||
|
response.Success(c, dto.BetaPolicySettings{Rules: rules})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateBetaPolicySettingsRequest 更新 Beta 策略配置请求
|
||||||
|
type UpdateBetaPolicySettingsRequest struct {
|
||||||
|
Rules []dto.BetaPolicyRule `json:"rules"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateBetaPolicySettings 更新 Beta 策略配置
|
||||||
|
// PUT /api/v1/admin/settings/beta-policy
|
||||||
|
func (h *SettingHandler) UpdateBetaPolicySettings(c *gin.Context) {
|
||||||
|
var req UpdateBetaPolicySettingsRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rules := make([]service.BetaPolicyRule, len(req.Rules))
|
||||||
|
for i, r := range req.Rules {
|
||||||
|
rules[i] = service.BetaPolicyRule(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
settings := &service.BetaPolicySettings{Rules: rules}
|
||||||
|
if err := h.settingService.SetBetaPolicySettings(c.Request.Context(), settings); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-fetch to return updated settings
|
||||||
|
updated, err := h.settingService.GetBetaPolicySettings(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
outRules := make([]dto.BetaPolicyRule, len(updated.Rules))
|
||||||
|
for i, r := range updated.Rules {
|
||||||
|
outRules[i] = dto.BetaPolicyRule(r)
|
||||||
|
}
|
||||||
|
response.Success(c, dto.BetaPolicySettings{Rules: outRules})
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateStreamTimeoutSettingsRequest 更新流超时配置请求
|
// UpdateStreamTimeoutSettingsRequest 更新流超时配置请求
|
||||||
type UpdateStreamTimeoutSettingsRequest struct {
|
type UpdateStreamTimeoutSettingsRequest struct {
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
|
|||||||
@@ -211,8 +211,22 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
|
|||||||
email = linuxDoSyntheticEmail(subject)
|
email = linuxDoSyntheticEmail(subject)
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username)
|
// 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
|
||||||
|
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, service.ErrOAuthInvitationRequired) {
|
||||||
|
pendingToken, tokenErr := h.authService.CreatePendingOAuthToken(email, username)
|
||||||
|
if tokenErr != nil {
|
||||||
|
redirectOAuthError(c, frontendCallback, "login_failed", "service_error", "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fragment := url.Values{}
|
||||||
|
fragment.Set("error", "invitation_required")
|
||||||
|
fragment.Set("pending_oauth_token", pendingToken)
|
||||||
|
fragment.Set("redirect", redirectTo)
|
||||||
|
redirectWithFragment(c, frontendCallback, fragment)
|
||||||
|
return
|
||||||
|
}
|
||||||
// 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。
|
// 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。
|
||||||
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
|
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
|
||||||
return
|
return
|
||||||
@@ -227,6 +241,41 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
|
|||||||
redirectWithFragment(c, frontendCallback, fragment)
|
redirectWithFragment(c, frontendCallback, fragment)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type completeLinuxDoOAuthRequest struct {
|
||||||
|
PendingOAuthToken string `json:"pending_oauth_token" binding:"required"`
|
||||||
|
InvitationCode string `json:"invitation_code" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompleteLinuxDoOAuthRegistration completes a pending OAuth registration by validating
|
||||||
|
// the invitation code and creating the user account.
|
||||||
|
// POST /api/v1/auth/oauth/linuxdo/complete-registration
|
||||||
|
func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
|
||||||
|
var req completeLinuxDoOAuthRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "INVALID_REQUEST", "message": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
email, username, err := h.authService.VerifyPendingOAuthToken(req.PendingOAuthToken)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "INVALID_TOKEN", "message": "invalid or expired registration token"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"access_token": tokenPair.AccessToken,
|
||||||
|
"refresh_token": tokenPair.RefreshToken,
|
||||||
|
"expires_in": tokenPair.ExpiresIn,
|
||||||
|
"token_type": "Bearer",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (h *AuthHandler) getLinuxDoOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) {
|
func (h *AuthHandler) getLinuxDoOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) {
|
||||||
if h != nil && h.settingSvc != nil {
|
if h != nil && h.settingSvc != nil {
|
||||||
return h.settingSvc.GetLinuxDoConnectOAuthConfig(ctx)
|
return h.settingSvc.GetLinuxDoConnectOAuthConfig(ctx)
|
||||||
|
|||||||
@@ -168,6 +168,19 @@ type RectifierSettings struct {
|
|||||||
ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"`
|
ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BetaPolicyRule Beta 策略规则 DTO
|
||||||
|
type BetaPolicyRule struct {
|
||||||
|
BetaToken string `json:"beta_token"`
|
||||||
|
Action string `json:"action"`
|
||||||
|
Scope string `json:"scope"`
|
||||||
|
ErrorMessage string `json:"error_message,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BetaPolicySettings Beta 策略配置 DTO
|
||||||
|
type BetaPolicySettings struct {
|
||||||
|
Rules []BetaPolicyRule `json:"rules"`
|
||||||
|
}
|
||||||
|
|
||||||
// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem.
|
// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem.
|
||||||
// Returns empty slice on empty/invalid input.
|
// Returns empty slice on empty/invalid input.
|
||||||
func ParseCustomMenuItems(raw string) []CustomMenuItem {
|
func ParseCustomMenuItems(raw string) []CustomMenuItem {
|
||||||
|
|||||||
@@ -652,6 +652,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
accountReleaseFunc()
|
accountReleaseFunc()
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// Beta policy block: return 400 immediately, no failover
|
||||||
|
var betaBlockedErr *service.BetaBlockedError
|
||||||
|
if errors.As(err, &betaBlockedErr) {
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", betaBlockedErr.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var promptTooLongErr *service.PromptTooLongError
|
var promptTooLongErr *service.PromptTooLongError
|
||||||
if errors.As(err, &promptTooLongErr) {
|
if errors.As(err, &promptTooLongErr) {
|
||||||
reqLog.Warn("gateway.prompt_too_long_from_antigravity",
|
reqLog.Warn("gateway.prompt_too_long_from_antigravity",
|
||||||
|
|||||||
@@ -127,6 +127,7 @@ func (f *fakeConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, acc
|
|||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil }
|
func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil }
|
||||||
|
func (f *fakeConcurrencyCache) CleanupStaleProcessSlots(context.Context, string) error { return nil }
|
||||||
|
|
||||||
func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) {
|
func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|||||||
@@ -89,6 +89,10 @@ func (m *concurrencyCacheMock) CleanupExpiredAccountSlots(ctx context.Context, a
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *concurrencyCacheMock) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestConcurrencyHelper_TryAcquireUserSlot(t *testing.T) {
|
func TestConcurrencyHelper_TryAcquireUserSlot(t *testing.T) {
|
||||||
cache := &concurrencyCacheMock{
|
cache := &concurrencyCacheMock{
|
||||||
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||||
|
|||||||
@@ -120,6 +120,10 @@ func (s *helperConcurrencyCacheStub) CleanupExpiredAccountSlots(ctx context.Cont
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *helperConcurrencyCacheStub) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func newHelperTestContext(method, path string) (*gin.Context, *httptest.ResponseRecorder) {
|
func newHelperTestContext(method, path string) (*gin.Context, *httptest.ResponseRecorder) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ const (
|
|||||||
|
|
||||||
// DroppedBetas 是转发时需要从 anthropic-beta header 中移除的 beta token 列表。
|
// DroppedBetas 是转发时需要从 anthropic-beta header 中移除的 beta token 列表。
|
||||||
// 这些 token 是客户端特有的,不应透传给上游 API。
|
// 这些 token 是客户端特有的,不应透传给上游 API。
|
||||||
var DroppedBetas = []string{BetaFastMode}
|
var DroppedBetas = []string{}
|
||||||
|
|
||||||
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
||||||
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||||
|
|||||||
@@ -268,6 +268,7 @@ type IDTokenClaims struct {
|
|||||||
type OpenAIAuthClaims struct {
|
type OpenAIAuthClaims struct {
|
||||||
ChatGPTAccountID string `json:"chatgpt_account_id"`
|
ChatGPTAccountID string `json:"chatgpt_account_id"`
|
||||||
ChatGPTUserID string `json:"chatgpt_user_id"`
|
ChatGPTUserID string `json:"chatgpt_user_id"`
|
||||||
|
ChatGPTPlanType string `json:"chatgpt_plan_type"`
|
||||||
UserID string `json:"user_id"`
|
UserID string `json:"user_id"`
|
||||||
Organizations []OrganizationClaim `json:"organizations"`
|
Organizations []OrganizationClaim `json:"organizations"`
|
||||||
}
|
}
|
||||||
@@ -325,12 +326,9 @@ func (r *RefreshTokenRequest) ToFormData() string {
|
|||||||
return params.Encode()
|
return params.Encode()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseIDToken parses the ID Token JWT and extracts claims.
|
// DecodeIDToken decodes the ID Token JWT payload without validating expiration.
|
||||||
// 注意:当前仅解码 payload 并校验 exp,未验证 JWT 签名。
|
// Use this for best-effort extraction (e.g., during data import) where the token may be expired.
|
||||||
// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名:
|
func DecodeIDToken(idToken string) (*IDTokenClaims, error) {
|
||||||
//
|
|
||||||
// https://auth.openai.com/.well-known/jwks.json
|
|
||||||
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
|
||||||
parts := strings.Split(idToken, ".")
|
parts := strings.Split(idToken, ".")
|
||||||
if len(parts) != 3 {
|
if len(parts) != 3 {
|
||||||
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
|
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
|
||||||
@@ -360,6 +358,20 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
|||||||
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
|
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return &claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseIDToken parses the ID Token JWT and extracts claims.
|
||||||
|
// 注意:当前仅解码 payload 并校验 exp,未验证 JWT 签名。
|
||||||
|
// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名:
|
||||||
|
//
|
||||||
|
// https://auth.openai.com/.well-known/jwks.json
|
||||||
|
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||||
|
claims, err := DecodeIDToken(idToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// 校验 ID Token 是否已过期(允许 2 分钟时钟偏差,防止因服务器时钟略有差异误判刚颁发的令牌)
|
// 校验 ID Token 是否已过期(允许 2 分钟时钟偏差,防止因服务器时钟略有差异误判刚颁发的令牌)
|
||||||
const clockSkewTolerance = 120 // 秒
|
const clockSkewTolerance = 120 // 秒
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
@@ -367,7 +379,7 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
|||||||
return nil, fmt.Errorf("id_token has expired (exp: %d, now: %d, skew_tolerance: %ds)", claims.Exp, now, clockSkewTolerance)
|
return nil, fmt.Errorf("id_token has expired (exp: %d, now: %d, skew_tolerance: %ds)", claims.Exp, now, clockSkewTolerance)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &claims, nil
|
return claims, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UserInfo represents user information extracted from ID Token claims.
|
// UserInfo represents user information extracted from ID Token claims.
|
||||||
@@ -375,6 +387,7 @@ type UserInfo struct {
|
|||||||
Email string
|
Email string
|
||||||
ChatGPTAccountID string
|
ChatGPTAccountID string
|
||||||
ChatGPTUserID string
|
ChatGPTUserID string
|
||||||
|
PlanType string
|
||||||
UserID string
|
UserID string
|
||||||
OrganizationID string
|
OrganizationID string
|
||||||
Organizations []OrganizationClaim
|
Organizations []OrganizationClaim
|
||||||
@@ -389,6 +402,7 @@ func (c *IDTokenClaims) GetUserInfo() *UserInfo {
|
|||||||
if c.OpenAIAuth != nil {
|
if c.OpenAIAuth != nil {
|
||||||
info.ChatGPTAccountID = c.OpenAIAuth.ChatGPTAccountID
|
info.ChatGPTAccountID = c.OpenAIAuth.ChatGPTAccountID
|
||||||
info.ChatGPTUserID = c.OpenAIAuth.ChatGPTUserID
|
info.ChatGPTUserID = c.OpenAIAuth.ChatGPTUserID
|
||||||
|
info.PlanType = c.OpenAIAuth.ChatGPTPlanType
|
||||||
info.UserID = c.OpenAIAuth.UserID
|
info.UserID = c.OpenAIAuth.UserID
|
||||||
info.Organizations = c.OpenAIAuth.Organizations
|
info.Organizations = c.OpenAIAuth.Organizations
|
||||||
|
|
||||||
|
|||||||
@@ -147,17 +147,47 @@ var (
|
|||||||
return 1
|
return 1
|
||||||
`)
|
`)
|
||||||
|
|
||||||
// cleanupExpiredSlotsScript - remove expired slots
|
// cleanupExpiredSlotsScript 清理单个账号/用户有序集合中过期槽位
|
||||||
// KEYS[1] = concurrency:account:{accountID}
|
// KEYS[1] = 有序集合键
|
||||||
// ARGV[1] = TTL (seconds)
|
// ARGV[1] = TTL(秒)
|
||||||
cleanupExpiredSlotsScript = redis.NewScript(`
|
cleanupExpiredSlotsScript = redis.NewScript(`
|
||||||
local key = KEYS[1]
|
local key = KEYS[1]
|
||||||
local ttl = tonumber(ARGV[1])
|
local ttl = tonumber(ARGV[1])
|
||||||
local timeResult = redis.call('TIME')
|
local timeResult = redis.call('TIME')
|
||||||
local now = tonumber(timeResult[1])
|
local now = tonumber(timeResult[1])
|
||||||
local expireBefore = now - ttl
|
local expireBefore = now - ttl
|
||||||
return redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
||||||
`)
|
if redis.call('ZCARD', key) == 0 then
|
||||||
|
redis.call('DEL', key)
|
||||||
|
else
|
||||||
|
redis.call('EXPIRE', key, ttl)
|
||||||
|
end
|
||||||
|
return 1
|
||||||
|
`)
|
||||||
|
|
||||||
|
// startupCleanupScript 清理非当前进程前缀的槽位成员。
|
||||||
|
// KEYS 是有序集合键列表,ARGV[1] 是当前进程前缀,ARGV[2] 是槽位 TTL。
|
||||||
|
// 遍历每个 KEYS[i],移除前缀不匹配的成员,清空后删 key,否则刷新 EXPIRE。
|
||||||
|
startupCleanupScript = redis.NewScript(`
|
||||||
|
local activePrefix = ARGV[1]
|
||||||
|
local slotTTL = tonumber(ARGV[2])
|
||||||
|
local removed = 0
|
||||||
|
for i = 1, #KEYS do
|
||||||
|
local key = KEYS[i]
|
||||||
|
local members = redis.call('ZRANGE', key, 0, -1)
|
||||||
|
for _, member in ipairs(members) do
|
||||||
|
if string.sub(member, 1, string.len(activePrefix)) ~= activePrefix then
|
||||||
|
removed = removed + redis.call('ZREM', key, member)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
if redis.call('ZCARD', key) == 0 then
|
||||||
|
redis.call('DEL', key)
|
||||||
|
else
|
||||||
|
redis.call('EXPIRE', key, slotTTL)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
return removed
|
||||||
|
`)
|
||||||
)
|
)
|
||||||
|
|
||||||
type concurrencyCache struct {
|
type concurrencyCache struct {
|
||||||
@@ -463,3 +493,72 @@ func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accou
|
|||||||
_, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result()
|
_, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *concurrencyCache) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
|
||||||
|
if activeRequestPrefix == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1. 清理有序集合中非当前进程前缀的成员
|
||||||
|
slotPatterns := []string{accountSlotKeyPrefix + "*", userSlotKeyPrefix + "*"}
|
||||||
|
for _, pattern := range slotPatterns {
|
||||||
|
if err := c.cleanupSlotsByPattern(ctx, pattern, activeRequestPrefix); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 删除所有等待队列计数器(重启后计数器失效)
|
||||||
|
waitPatterns := []string{accountWaitKeyPrefix + "*", waitQueueKeyPrefix + "*"}
|
||||||
|
for _, pattern := range waitPatterns {
|
||||||
|
if err := c.deleteKeysByPattern(ctx, pattern); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupSlotsByPattern 扫描匹配 pattern 的有序集合键,批量调用 Lua 脚本清理非当前进程成员。
|
||||||
|
func (c *concurrencyCache) cleanupSlotsByPattern(ctx context.Context, pattern, activePrefix string) error {
|
||||||
|
const scanCount = 200
|
||||||
|
var cursor uint64
|
||||||
|
for {
|
||||||
|
keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, scanCount).Result()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("scan %s: %w", pattern, err)
|
||||||
|
}
|
||||||
|
if len(keys) > 0 {
|
||||||
|
_, err := startupCleanupScript.Run(ctx, c.rdb, keys, activePrefix, c.slotTTLSeconds).Result()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cleanup slots %s: %w", pattern, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cursor = nextCursor
|
||||||
|
if cursor == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// deleteKeysByPattern 扫描匹配 pattern 的键并删除。
|
||||||
|
func (c *concurrencyCache) deleteKeysByPattern(ctx context.Context, pattern string) error {
|
||||||
|
const scanCount = 200
|
||||||
|
var cursor uint64
|
||||||
|
for {
|
||||||
|
keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, scanCount).Result()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("scan %s: %w", pattern, err)
|
||||||
|
}
|
||||||
|
if len(keys) > 0 {
|
||||||
|
if err := c.rdb.Del(ctx, keys...).Err(); err != nil {
|
||||||
|
return fmt.Errorf("del %s: %w", pattern, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cursor = nextCursor
|
||||||
|
if cursor == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -25,6 +25,10 @@ type ConcurrencyCacheSuite struct {
|
|||||||
cache service.ConcurrencyCache
|
cache service.ConcurrencyCache
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConcurrencyCacheSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(ConcurrencyCacheSuite))
|
||||||
|
}
|
||||||
|
|
||||||
func (s *ConcurrencyCacheSuite) SetupTest() {
|
func (s *ConcurrencyCacheSuite) SetupTest() {
|
||||||
s.IntegrationRedisSuite.SetupTest()
|
s.IntegrationRedisSuite.SetupTest()
|
||||||
s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds()))
|
s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds()))
|
||||||
@@ -247,17 +251,41 @@ func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {
|
|||||||
require.Equal(s.T(), 1, val, "expected account wait count 1")
|
require.Equal(s.T(), 1, val, "expected account wait count 1")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_DecrementNoNegative() {
|
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots() {
|
||||||
accountID := int64(301)
|
accountID := int64(901)
|
||||||
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
userID := int64(902)
|
||||||
|
accountKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||||
|
userKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
|
||||||
|
userWaitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||||
|
accountWaitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||||
|
|
||||||
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount on non-existent key")
|
now := time.Now().Unix()
|
||||||
|
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountKey,
|
||||||
|
redis.Z{Score: float64(now), Member: "oldproc-1"},
|
||||||
|
redis.Z{Score: float64(now), Member: "keep-1"},
|
||||||
|
).Err())
|
||||||
|
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, userKey,
|
||||||
|
redis.Z{Score: float64(now), Member: "oldproc-2"},
|
||||||
|
redis.Z{Score: float64(now), Member: "keep-2"},
|
||||||
|
).Err())
|
||||||
|
require.NoError(s.T(), s.rdb.Set(s.ctx, userWaitKey, 3, time.Minute).Err())
|
||||||
|
require.NoError(s.T(), s.rdb.Set(s.ctx, accountWaitKey, 2, time.Minute).Err())
|
||||||
|
|
||||||
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "keep-"))
|
||||||
if !errors.Is(err, redis.Nil) {
|
|
||||||
require.NoError(s.T(), err, "Get waitKey")
|
accountMembers, err := s.rdb.ZRange(s.ctx, accountKey, 0, -1).Result()
|
||||||
}
|
require.NoError(s.T(), err)
|
||||||
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative account wait count after decrement on empty")
|
require.Equal(s.T(), []string{"keep-1"}, accountMembers)
|
||||||
|
|
||||||
|
userMembers, err := s.rdb.ZRange(s.ctx, userKey, 0, -1).Result()
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
require.Equal(s.T(), []string{"keep-2"}, userMembers)
|
||||||
|
|
||||||
|
_, err = s.rdb.Get(s.ctx, userWaitKey).Result()
|
||||||
|
require.True(s.T(), errors.Is(err, redis.Nil))
|
||||||
|
|
||||||
|
_, err = s.rdb.Get(s.ctx, accountWaitKey).Result()
|
||||||
|
require.True(s.T(), errors.Is(err, redis.Nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
|
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
|
||||||
@@ -407,6 +435,53 @@ func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots_NoExpired() {
|
|||||||
require.Equal(s.T(), 2, cur)
|
require.Equal(s.T(), 2, cur)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConcurrencyCacheSuite(t *testing.T) {
|
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots_RemovesOldPrefixesAndWaitCounters() {
|
||||||
suite.Run(t, new(ConcurrencyCacheSuite))
|
accountID := int64(901)
|
||||||
|
userID := int64(902)
|
||||||
|
accountSlotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||||
|
userSlotKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
|
||||||
|
userWaitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||||
|
accountWaitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||||
|
|
||||||
|
now := float64(time.Now().Unix())
|
||||||
|
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountSlotKey,
|
||||||
|
redis.Z{Score: now, Member: "oldproc-1"},
|
||||||
|
redis.Z{Score: now, Member: "activeproc-1"},
|
||||||
|
).Err())
|
||||||
|
require.NoError(s.T(), s.rdb.Expire(s.ctx, accountSlotKey, testSlotTTL).Err())
|
||||||
|
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, userSlotKey,
|
||||||
|
redis.Z{Score: now, Member: "oldproc-2"},
|
||||||
|
redis.Z{Score: now, Member: "activeproc-2"},
|
||||||
|
).Err())
|
||||||
|
require.NoError(s.T(), s.rdb.Expire(s.ctx, userSlotKey, testSlotTTL).Err())
|
||||||
|
require.NoError(s.T(), s.rdb.Set(s.ctx, userWaitKey, 3, testSlotTTL).Err())
|
||||||
|
require.NoError(s.T(), s.rdb.Set(s.ctx, accountWaitKey, 2, testSlotTTL).Err())
|
||||||
|
|
||||||
|
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "activeproc-"))
|
||||||
|
|
||||||
|
accountMembers, err := s.rdb.ZRange(s.ctx, accountSlotKey, 0, -1).Result()
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
require.Equal(s.T(), []string{"activeproc-1"}, accountMembers)
|
||||||
|
|
||||||
|
userMembers, err := s.rdb.ZRange(s.ctx, userSlotKey, 0, -1).Result()
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
require.Equal(s.T(), []string{"activeproc-2"}, userMembers)
|
||||||
|
|
||||||
|
_, err = s.rdb.Get(s.ctx, userWaitKey).Result()
|
||||||
|
require.ErrorIs(s.T(), err, redis.Nil)
|
||||||
|
_, err = s.rdb.Get(s.ctx, accountWaitKey).Result()
|
||||||
|
require.ErrorIs(s.T(), err, redis.Nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots_DeletesEmptySlotKeys() {
|
||||||
|
accountID := int64(903)
|
||||||
|
accountSlotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||||
|
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountSlotKey, redis.Z{Score: float64(time.Now().Unix()), Member: "oldproc-1"}).Err())
|
||||||
|
require.NoError(s.T(), s.rdb.Expire(s.ctx, accountSlotKey, testSlotTTL).Err())
|
||||||
|
|
||||||
|
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "activeproc-"))
|
||||||
|
|
||||||
|
exists, err := s.rdb.Exists(s.ctx, accountSlotKey).Result()
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
require.EqualValues(s.T(), 0, exists)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -645,7 +645,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
|||||||
settingRepo := newStubSettingRepo()
|
settingRepo := newStubSettingRepo()
|
||||||
settingService := service.NewSettingService(settingRepo, cfg)
|
settingService := service.NewSettingService(settingRepo, cfg)
|
||||||
|
|
||||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil)
|
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
|
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
|
||||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
|
|||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
|
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
|
||||||
authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||||
|
|
||||||
admin := &service.User{
|
admin := &service.User{
|
||||||
ID: 1,
|
ID: 1,
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer
|
|||||||
cfg.JWT.AccessTokenExpireMinutes = 60
|
cfg.JWT.AccessTokenExpireMinutes = 60
|
||||||
|
|
||||||
userRepo := &stubJWTUserRepo{users: users}
|
userRepo := &stubJWTUserRepo{users: users}
|
||||||
authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||||
userSvc := service.NewUserService(userRepo, nil, nil)
|
userSvc := service.NewUserService(userRepo, nil, nil)
|
||||||
mw := NewJWTAuthMiddleware(authSvc, userSvc)
|
mw := NewJWTAuthMiddleware(authSvc, userSvc)
|
||||||
|
|
||||||
|
|||||||
@@ -264,6 +264,8 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials)
|
accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials)
|
||||||
accounts.POST("/batch-refresh-tier", h.Admin.Account.BatchRefreshTier)
|
accounts.POST("/batch-refresh-tier", h.Admin.Account.BatchRefreshTier)
|
||||||
accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate)
|
accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate)
|
||||||
|
accounts.POST("/batch-clear-error", h.Admin.Account.BatchClearError)
|
||||||
|
accounts.POST("/batch-refresh", h.Admin.Account.BatchRefresh)
|
||||||
|
|
||||||
// Antigravity 默认模型映射
|
// Antigravity 默认模型映射
|
||||||
accounts.GET("/antigravity/default-model-mapping", h.Admin.Account.GetAntigravityDefaultModelMapping)
|
accounts.GET("/antigravity/default-model-mapping", h.Admin.Account.GetAntigravityDefaultModelMapping)
|
||||||
@@ -396,6 +398,9 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
// 请求整流器配置
|
// 请求整流器配置
|
||||||
adminSettings.GET("/rectifier", h.Admin.Setting.GetRectifierSettings)
|
adminSettings.GET("/rectifier", h.Admin.Setting.GetRectifierSettings)
|
||||||
adminSettings.PUT("/rectifier", h.Admin.Setting.UpdateRectifierSettings)
|
adminSettings.PUT("/rectifier", h.Admin.Setting.UpdateRectifierSettings)
|
||||||
|
// Beta 策略配置
|
||||||
|
adminSettings.GET("/beta-policy", h.Admin.Setting.GetBetaPolicySettings)
|
||||||
|
adminSettings.PUT("/beta-policy", h.Admin.Setting.UpdateBetaPolicySettings)
|
||||||
// Sora S3 存储配置
|
// Sora S3 存储配置
|
||||||
adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings)
|
adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings)
|
||||||
adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings)
|
adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings)
|
||||||
|
|||||||
@@ -61,6 +61,12 @@ func RegisterAuthRoutes(
|
|||||||
}), h.Auth.ResetPassword)
|
}), h.Auth.ResetPassword)
|
||||||
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
|
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
|
||||||
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
|
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
|
||||||
|
auth.POST("/oauth/linuxdo/complete-registration",
|
||||||
|
rateLimiter.LimitWithOptions("oauth-linuxdo-complete", 10, time.Minute, middleware.RateLimitOptions{
|
||||||
|
FailureMode: middleware.RateLimitFailClose,
|
||||||
|
}),
|
||||||
|
h.Auth.CompleteLinuxDoOAuthRegistration,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 公开设置(无需认证)
|
// 公开设置(无需认证)
|
||||||
|
|||||||
@@ -432,6 +432,7 @@ type adminServiceImpl struct {
|
|||||||
entClient *dbent.Client // 用于开启数据库事务
|
entClient *dbent.Client // 用于开启数据库事务
|
||||||
settingService *SettingService
|
settingService *SettingService
|
||||||
defaultSubAssigner DefaultSubscriptionAssigner
|
defaultSubAssigner DefaultSubscriptionAssigner
|
||||||
|
userSubRepo UserSubscriptionRepository
|
||||||
}
|
}
|
||||||
|
|
||||||
type userGroupRateBatchReader interface {
|
type userGroupRateBatchReader interface {
|
||||||
@@ -459,6 +460,7 @@ func NewAdminService(
|
|||||||
entClient *dbent.Client,
|
entClient *dbent.Client,
|
||||||
settingService *SettingService,
|
settingService *SettingService,
|
||||||
defaultSubAssigner DefaultSubscriptionAssigner,
|
defaultSubAssigner DefaultSubscriptionAssigner,
|
||||||
|
userSubRepo UserSubscriptionRepository,
|
||||||
) AdminService {
|
) AdminService {
|
||||||
return &adminServiceImpl{
|
return &adminServiceImpl{
|
||||||
userRepo: userRepo,
|
userRepo: userRepo,
|
||||||
@@ -476,6 +478,7 @@ func NewAdminService(
|
|||||||
entClient: entClient,
|
entClient: entClient,
|
||||||
settingService: settingService,
|
settingService: settingService,
|
||||||
defaultSubAssigner: defaultSubAssigner,
|
defaultSubAssigner: defaultSubAssigner,
|
||||||
|
userSubRepo: userSubRepo,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1277,9 +1280,17 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
|
|||||||
if group.Status != StatusActive {
|
if group.Status != StatusActive {
|
||||||
return nil, infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active")
|
return nil, infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active")
|
||||||
}
|
}
|
||||||
// 订阅类型分组:不允许通过此 API 直接绑定,需通过订阅管理流程
|
// 订阅类型分组:用户须持有该分组的有效订阅才可绑定
|
||||||
if group.IsSubscriptionType() {
|
if group.IsSubscriptionType() {
|
||||||
return nil, infraerrors.BadRequest("SUBSCRIPTION_GROUP_NOT_ALLOWED", "subscription groups must be managed through the subscription workflow")
|
if s.userSubRepo == nil {
|
||||||
|
return nil, infraerrors.InternalServer("SUBSCRIPTION_REPOSITORY_UNAVAILABLE", "subscription repository is not configured")
|
||||||
|
}
|
||||||
|
if _, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, apiKey.UserID, *groupID); err != nil {
|
||||||
|
if errors.Is(err, ErrSubscriptionNotFound) {
|
||||||
|
return nil, infraerrors.BadRequest("SUBSCRIPTION_REQUIRED", "user does not have an active subscription for this group")
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
gid := *groupID
|
gid := *groupID
|
||||||
@@ -1287,7 +1298,7 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
|
|||||||
apiKey.Group = group
|
apiKey.Group = group
|
||||||
|
|
||||||
// 专属标准分组:使用事务保证「添加分组权限」与「更新 API Key」的原子性
|
// 专属标准分组:使用事务保证「添加分组权限」与「更新 API Key」的原子性
|
||||||
if group.IsExclusive {
|
if group.IsExclusive && !group.IsSubscriptionType() {
|
||||||
opCtx := ctx
|
opCtx := ctx
|
||||||
var tx *dbent.Tx
|
var tx *dbent.Tx
|
||||||
if s.entClient == nil {
|
if s.entClient == nil {
|
||||||
|
|||||||
@@ -32,28 +32,44 @@ func (s *userRepoStubForGroupUpdate) AddGroupToAllowedGroups(_ context.Context,
|
|||||||
return s.addGroupErr
|
return s.addGroupErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *userRepoStubForGroupUpdate) Create(context.Context, *User) error { panic("unexpected") }
|
func (s *userRepoStubForGroupUpdate) Create(context.Context, *User) error { panic("unexpected") }
|
||||||
func (s *userRepoStubForGroupUpdate) GetByID(context.Context, int64) (*User, error) { panic("unexpected") }
|
func (s *userRepoStubForGroupUpdate) GetByID(context.Context, int64) (*User, error) {
|
||||||
func (s *userRepoStubForGroupUpdate) GetByEmail(context.Context, string) (*User, error) { panic("unexpected") }
|
panic("unexpected")
|
||||||
func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, error) { panic("unexpected") }
|
}
|
||||||
func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") }
|
func (s *userRepoStubForGroupUpdate) GetByEmail(context.Context, string) (*User, error) {
|
||||||
func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
|
panic("unexpected")
|
||||||
|
}
|
||||||
|
func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, error) {
|
||||||
|
panic("unexpected")
|
||||||
|
}
|
||||||
|
func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") }
|
||||||
|
func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
|
||||||
func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
|
func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
|
||||||
panic("unexpected")
|
panic("unexpected")
|
||||||
}
|
}
|
||||||
func (s *userRepoStubForGroupUpdate) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
|
func (s *userRepoStubForGroupUpdate) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
|
||||||
panic("unexpected")
|
panic("unexpected")
|
||||||
}
|
}
|
||||||
func (s *userRepoStubForGroupUpdate) UpdateBalance(context.Context, int64, float64) error { panic("unexpected") }
|
func (s *userRepoStubForGroupUpdate) UpdateBalance(context.Context, int64, float64) error {
|
||||||
func (s *userRepoStubForGroupUpdate) DeductBalance(context.Context, int64, float64) error { panic("unexpected") }
|
panic("unexpected")
|
||||||
func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error { panic("unexpected") }
|
}
|
||||||
func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) { panic("unexpected") }
|
func (s *userRepoStubForGroupUpdate) DeductBalance(context.Context, int64, float64) error {
|
||||||
|
panic("unexpected")
|
||||||
|
}
|
||||||
|
func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error {
|
||||||
|
panic("unexpected")
|
||||||
|
}
|
||||||
|
func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) {
|
||||||
|
panic("unexpected")
|
||||||
|
}
|
||||||
func (s *userRepoStubForGroupUpdate) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
|
func (s *userRepoStubForGroupUpdate) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
|
||||||
panic("unexpected")
|
panic("unexpected")
|
||||||
}
|
}
|
||||||
func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error { panic("unexpected") }
|
func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error {
|
||||||
func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") }
|
panic("unexpected")
|
||||||
func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") }
|
}
|
||||||
|
func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") }
|
||||||
|
func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") }
|
||||||
|
|
||||||
// apiKeyRepoStubForGroupUpdate implements APIKeyRepository for AdminUpdateAPIKeyGroupID tests.
|
// apiKeyRepoStubForGroupUpdate implements APIKeyRepository for AdminUpdateAPIKeyGroupID tests.
|
||||||
type apiKeyRepoStubForGroupUpdate struct {
|
type apiKeyRepoStubForGroupUpdate struct {
|
||||||
@@ -194,6 +210,29 @@ func (s *groupRepoStubForGroupUpdate) UpdateSortOrders(context.Context, []GroupS
|
|||||||
panic("unexpected")
|
panic("unexpected")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type userSubRepoStubForGroupUpdate struct {
|
||||||
|
userSubRepoNoop
|
||||||
|
getActiveSub *UserSubscription
|
||||||
|
getActiveErr error
|
||||||
|
called bool
|
||||||
|
calledUserID int64
|
||||||
|
calledGroupID int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userSubRepoStubForGroupUpdate) GetActiveByUserIDAndGroupID(_ context.Context, userID, groupID int64) (*UserSubscription, error) {
|
||||||
|
s.called = true
|
||||||
|
s.calledUserID = userID
|
||||||
|
s.calledGroupID = groupID
|
||||||
|
if s.getActiveErr != nil {
|
||||||
|
return nil, s.getActiveErr
|
||||||
|
}
|
||||||
|
if s.getActiveSub == nil {
|
||||||
|
return nil, ErrSubscriptionNotFound
|
||||||
|
}
|
||||||
|
clone := *s.getActiveSub
|
||||||
|
return &clone, nil
|
||||||
|
}
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
// Tests
|
// Tests
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
@@ -386,14 +425,49 @@ func TestAdminService_AdminUpdateAPIKeyGroupID_NonExclusiveGroup_NoAllowedGroupU
|
|||||||
func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_Blocked(t *testing.T) {
|
func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_Blocked(t *testing.T) {
|
||||||
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
|
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
|
||||||
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||||
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}}
|
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeSubscription}}
|
||||||
|
userRepo := &userRepoStubForGroupUpdate{}
|
||||||
|
userSubRepo := &userSubRepoStubForGroupUpdate{getActiveErr: ErrSubscriptionNotFound}
|
||||||
|
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, userSubRepo: userSubRepo}
|
||||||
|
|
||||||
|
// 无有效订阅时应拒绝绑定
|
||||||
|
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, "SUBSCRIPTION_REQUIRED", infraerrors.Reason(err))
|
||||||
|
require.True(t, userSubRepo.called)
|
||||||
|
require.Equal(t, int64(42), userSubRepo.calledUserID)
|
||||||
|
require.Equal(t, int64(10), userSubRepo.calledGroupID)
|
||||||
|
require.False(t, userRepo.addGroupCalled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_RequiresRepo(t *testing.T) {
|
||||||
|
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
|
||||||
|
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||||
|
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeSubscription}}
|
||||||
userRepo := &userRepoStubForGroupUpdate{}
|
userRepo := &userRepoStubForGroupUpdate{}
|
||||||
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo}
|
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo}
|
||||||
|
|
||||||
// 订阅类型分组应被阻止绑定
|
|
||||||
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
|
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Equal(t, "SUBSCRIPTION_GROUP_NOT_ALLOWED", infraerrors.Reason(err))
|
require.Equal(t, "SUBSCRIPTION_REPOSITORY_UNAVAILABLE", infraerrors.Reason(err))
|
||||||
|
require.False(t, userRepo.addGroupCalled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_AllowsActiveSubscription(t *testing.T) {
|
||||||
|
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
|
||||||
|
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||||
|
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}}
|
||||||
|
userRepo := &userRepoStubForGroupUpdate{}
|
||||||
|
userSubRepo := &userSubRepoStubForGroupUpdate{
|
||||||
|
getActiveSub: &UserSubscription{ID: 99, UserID: 42, GroupID: 10},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, userSubRepo: userSubRepo}
|
||||||
|
|
||||||
|
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, userSubRepo.called)
|
||||||
|
require.NotNil(t, got.APIKey.GroupID)
|
||||||
|
require.Equal(t, int64(10), *got.APIKey.GroupID)
|
||||||
require.False(t, userRepo.addGroupCalled)
|
require.False(t, userRepo.addGroupCalled)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
@@ -21,24 +22,25 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
|
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
|
||||||
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
|
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
|
||||||
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
|
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
|
||||||
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
|
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
|
||||||
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
|
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
|
||||||
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
|
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
|
||||||
ErrAccessTokenExpired = infraerrors.Unauthorized("ACCESS_TOKEN_EXPIRED", "access token has expired")
|
ErrAccessTokenExpired = infraerrors.Unauthorized("ACCESS_TOKEN_EXPIRED", "access token has expired")
|
||||||
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
|
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
|
||||||
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
|
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
|
||||||
ErrRefreshTokenInvalid = infraerrors.Unauthorized("REFRESH_TOKEN_INVALID", "invalid refresh token")
|
ErrRefreshTokenInvalid = infraerrors.Unauthorized("REFRESH_TOKEN_INVALID", "invalid refresh token")
|
||||||
ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired")
|
ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired")
|
||||||
ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused")
|
ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused")
|
||||||
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
|
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
|
||||||
ErrEmailSuffixNotAllowed = infraerrors.BadRequest("EMAIL_SUFFIX_NOT_ALLOWED", "email suffix is not allowed")
|
ErrEmailSuffixNotAllowed = infraerrors.BadRequest("EMAIL_SUFFIX_NOT_ALLOWED", "email suffix is not allowed")
|
||||||
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
||||||
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
|
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
|
||||||
ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required")
|
ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required")
|
||||||
ErrInvitationCodeInvalid = infraerrors.BadRequest("INVITATION_CODE_INVALID", "invalid or used invitation code")
|
ErrInvitationCodeInvalid = infraerrors.BadRequest("INVITATION_CODE_INVALID", "invalid or used invitation code")
|
||||||
|
ErrOAuthInvitationRequired = infraerrors.Forbidden("OAUTH_INVITATION_REQUIRED", "invitation code required to complete oauth registration")
|
||||||
)
|
)
|
||||||
|
|
||||||
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
|
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
|
||||||
@@ -58,6 +60,7 @@ type JWTClaims struct {
|
|||||||
|
|
||||||
// AuthService 认证服务
|
// AuthService 认证服务
|
||||||
type AuthService struct {
|
type AuthService struct {
|
||||||
|
entClient *dbent.Client
|
||||||
userRepo UserRepository
|
userRepo UserRepository
|
||||||
redeemRepo RedeemCodeRepository
|
redeemRepo RedeemCodeRepository
|
||||||
refreshTokenCache RefreshTokenCache
|
refreshTokenCache RefreshTokenCache
|
||||||
@@ -76,6 +79,7 @@ type DefaultSubscriptionAssigner interface {
|
|||||||
|
|
||||||
// NewAuthService 创建认证服务实例
|
// NewAuthService 创建认证服务实例
|
||||||
func NewAuthService(
|
func NewAuthService(
|
||||||
|
entClient *dbent.Client,
|
||||||
userRepo UserRepository,
|
userRepo UserRepository,
|
||||||
redeemRepo RedeemCodeRepository,
|
redeemRepo RedeemCodeRepository,
|
||||||
refreshTokenCache RefreshTokenCache,
|
refreshTokenCache RefreshTokenCache,
|
||||||
@@ -88,6 +92,7 @@ func NewAuthService(
|
|||||||
defaultSubAssigner DefaultSubscriptionAssigner,
|
defaultSubAssigner DefaultSubscriptionAssigner,
|
||||||
) *AuthService {
|
) *AuthService {
|
||||||
return &AuthService{
|
return &AuthService{
|
||||||
|
entClient: entClient,
|
||||||
userRepo: userRepo,
|
userRepo: userRepo,
|
||||||
redeemRepo: redeemRepo,
|
redeemRepo: redeemRepo,
|
||||||
refreshTokenCache: refreshTokenCache,
|
refreshTokenCache: refreshTokenCache,
|
||||||
@@ -523,9 +528,10 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
|||||||
return token, user, nil
|
return token, user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair
|
// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair。
|
||||||
// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token
|
// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token。
|
||||||
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username string) (*TokenPair, *User, error) {
|
// invitationCode 仅在邀请码注册模式下新用户注册时使用;已有账号登录时忽略。
|
||||||
|
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode string) (*TokenPair, *User, error) {
|
||||||
// 检查 refreshTokenCache 是否可用
|
// 检查 refreshTokenCache 是否可用
|
||||||
if s.refreshTokenCache == nil {
|
if s.refreshTokenCache == nil {
|
||||||
return nil, nil, errors.New("refresh token cache not configured")
|
return nil, nil, errors.New("refresh token cache not configured")
|
||||||
@@ -552,6 +558,22 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
|||||||
return nil, nil, ErrRegDisabled
|
return nil, nil, ErrRegDisabled
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检查是否需要邀请码
|
||||||
|
var invitationRedeemCode *RedeemCode
|
||||||
|
if s.settingService != nil && s.settingService.IsInvitationCodeEnabled(ctx) {
|
||||||
|
if invitationCode == "" {
|
||||||
|
return nil, nil, ErrOAuthInvitationRequired
|
||||||
|
}
|
||||||
|
redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, ErrInvitationCodeInvalid
|
||||||
|
}
|
||||||
|
if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
|
||||||
|
return nil, nil, ErrInvitationCodeInvalid
|
||||||
|
}
|
||||||
|
invitationRedeemCode = redeemCode
|
||||||
|
}
|
||||||
|
|
||||||
randomPassword, err := randomHexString(32)
|
randomPassword, err := randomHexString(32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to generate random password for oauth signup: %v", err)
|
logger.LegacyPrintf("service.auth", "[Auth] Failed to generate random password for oauth signup: %v", err)
|
||||||
@@ -579,20 +601,58 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
|||||||
Status: StatusActive,
|
Status: StatusActive,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.userRepo.Create(ctx, newUser); err != nil {
|
if s.entClient != nil && invitationRedeemCode != nil {
|
||||||
if errors.Is(err, ErrEmailExists) {
|
tx, err := s.entClient.Tx(ctx)
|
||||||
user, err = s.userRepo.GetByEmail(ctx, email)
|
if err != nil {
|
||||||
if err != nil {
|
logger.LegacyPrintf("service.auth", "[Auth] Failed to begin transaction for oauth registration: %v", err)
|
||||||
logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err)
|
return nil, nil, ErrServiceUnavailable
|
||||||
|
}
|
||||||
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
txCtx := dbent.NewTxContext(ctx, tx)
|
||||||
|
|
||||||
|
if err := s.userRepo.Create(txCtx, newUser); err != nil {
|
||||||
|
if errors.Is(err, ErrEmailExists) {
|
||||||
|
user, err = s.userRepo.GetByEmail(ctx, email)
|
||||||
|
if err != nil {
|
||||||
|
logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err)
|
||||||
|
return nil, nil, ErrServiceUnavailable
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err)
|
||||||
return nil, nil, ErrServiceUnavailable
|
return nil, nil, ErrServiceUnavailable
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err)
|
if err := s.redeemRepo.Use(txCtx, invitationRedeemCode.ID, newUser.ID); err != nil {
|
||||||
return nil, nil, ErrServiceUnavailable
|
return nil, nil, ErrInvitationCodeInvalid
|
||||||
|
}
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
logger.LegacyPrintf("service.auth", "[Auth] Failed to commit oauth registration transaction: %v", err)
|
||||||
|
return nil, nil, ErrServiceUnavailable
|
||||||
|
}
|
||||||
|
user = newUser
|
||||||
|
s.assignDefaultSubscriptions(ctx, user.ID)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
user = newUser
|
if err := s.userRepo.Create(ctx, newUser); err != nil {
|
||||||
s.assignDefaultSubscriptions(ctx, user.ID)
|
if errors.Is(err, ErrEmailExists) {
|
||||||
|
user, err = s.userRepo.GetByEmail(ctx, email)
|
||||||
|
if err != nil {
|
||||||
|
logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err)
|
||||||
|
return nil, nil, ErrServiceUnavailable
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err)
|
||||||
|
return nil, nil, ErrServiceUnavailable
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
user = newUser
|
||||||
|
s.assignDefaultSubscriptions(ctx, user.ID)
|
||||||
|
if invitationRedeemCode != nil {
|
||||||
|
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
|
||||||
|
return nil, nil, ErrInvitationCodeInvalid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err)
|
logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err)
|
||||||
@@ -618,6 +678,63 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
|||||||
return tokenPair, user, nil
|
return tokenPair, user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// pendingOAuthTokenTTL is the validity period for pending OAuth tokens.
|
||||||
|
const pendingOAuthTokenTTL = 10 * time.Minute
|
||||||
|
|
||||||
|
// pendingOAuthPurpose is the purpose claim value for pending OAuth registration tokens.
|
||||||
|
const pendingOAuthPurpose = "pending_oauth_registration"
|
||||||
|
|
||||||
|
type pendingOAuthClaims struct {
|
||||||
|
Email string `json:"email"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
Purpose string `json:"purpose"`
|
||||||
|
jwt.RegisteredClaims
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatePendingOAuthToken generates a short-lived JWT that carries the OAuth identity
|
||||||
|
// while waiting for the user to supply an invitation code.
|
||||||
|
func (s *AuthService) CreatePendingOAuthToken(email, username string) (string, error) {
|
||||||
|
now := time.Now()
|
||||||
|
claims := &pendingOAuthClaims{
|
||||||
|
Email: email,
|
||||||
|
Username: username,
|
||||||
|
Purpose: pendingOAuthPurpose,
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
ExpiresAt: jwt.NewNumericDate(now.Add(pendingOAuthTokenTTL)),
|
||||||
|
IssuedAt: jwt.NewNumericDate(now),
|
||||||
|
NotBefore: jwt.NewNumericDate(now),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
|
return token.SignedString([]byte(s.cfg.JWT.Secret))
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyPendingOAuthToken validates a pending OAuth token and returns the embedded identity.
|
||||||
|
// Returns ErrInvalidToken when the token is invalid or expired.
|
||||||
|
func (s *AuthService) VerifyPendingOAuthToken(tokenStr string) (email, username string, err error) {
|
||||||
|
if len(tokenStr) > maxTokenLength {
|
||||||
|
return "", "", ErrInvalidToken
|
||||||
|
}
|
||||||
|
parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name}))
|
||||||
|
token, parseErr := parser.ParseWithClaims(tokenStr, &pendingOAuthClaims{}, func(t *jwt.Token) (any, error) {
|
||||||
|
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||||
|
}
|
||||||
|
return []byte(s.cfg.JWT.Secret), nil
|
||||||
|
})
|
||||||
|
if parseErr != nil {
|
||||||
|
return "", "", ErrInvalidToken
|
||||||
|
}
|
||||||
|
claims, ok := token.Claims.(*pendingOAuthClaims)
|
||||||
|
if !ok || !token.Valid {
|
||||||
|
return "", "", ErrInvalidToken
|
||||||
|
}
|
||||||
|
if claims.Purpose != pendingOAuthPurpose {
|
||||||
|
return "", "", ErrInvalidToken
|
||||||
|
}
|
||||||
|
return claims.Email, claims.Username, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) {
|
func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) {
|
||||||
if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
|
if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
|
||||||
return
|
return
|
||||||
|
|||||||
146
backend/internal/service/auth_service_pending_oauth_test.go
Normal file
146
backend/internal/service/auth_service_pending_oauth_test.go
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newAuthServiceForPendingOAuthTest() *AuthService {
|
||||||
|
cfg := &config.Config{
|
||||||
|
JWT: config.JWTConfig{
|
||||||
|
Secret: "test-secret-pending-oauth",
|
||||||
|
ExpireHour: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestVerifyPendingOAuthToken_ValidToken 验证正常签发的 pending token 可以被成功解析。
|
||||||
|
func TestVerifyPendingOAuthToken_ValidToken(t *testing.T) {
|
||||||
|
svc := newAuthServiceForPendingOAuthTest()
|
||||||
|
|
||||||
|
token, err := svc.CreatePendingOAuthToken("user@example.com", "alice")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, token)
|
||||||
|
|
||||||
|
email, username, err := svc.VerifyPendingOAuthToken(token)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "user@example.com", email)
|
||||||
|
require.Equal(t, "alice", username)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestVerifyPendingOAuthToken_RegularJWTRejected 用普通 access token 尝试验证,应返回 ErrInvalidToken。
|
||||||
|
func TestVerifyPendingOAuthToken_RegularJWTRejected(t *testing.T) {
|
||||||
|
svc := newAuthServiceForPendingOAuthTest()
|
||||||
|
|
||||||
|
// 签发一个普通 access token(JWTClaims,无 Purpose 字段)
|
||||||
|
accessToken, err := svc.GenerateToken(&User{
|
||||||
|
ID: 1,
|
||||||
|
Email: "user@example.com",
|
||||||
|
Role: RoleUser,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, _, err = svc.VerifyPendingOAuthToken(accessToken)
|
||||||
|
require.ErrorIs(t, err, ErrInvalidToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestVerifyPendingOAuthToken_WrongPurpose 手动构造 purpose 字段不匹配的 JWT,应返回 ErrInvalidToken。
|
||||||
|
func TestVerifyPendingOAuthToken_WrongPurpose(t *testing.T) {
|
||||||
|
svc := newAuthServiceForPendingOAuthTest()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
claims := &pendingOAuthClaims{
|
||||||
|
Email: "user@example.com",
|
||||||
|
Username: "alice",
|
||||||
|
Purpose: "some_other_purpose",
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
|
||||||
|
IssuedAt: jwt.NewNumericDate(now),
|
||||||
|
NotBefore: jwt.NewNumericDate(now),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
|
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
|
||||||
|
require.ErrorIs(t, err, ErrInvalidToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestVerifyPendingOAuthToken_MissingPurpose 手动构造无 purpose 字段的 JWT(模拟旧 token),应返回 ErrInvalidToken。
|
||||||
|
func TestVerifyPendingOAuthToken_MissingPurpose(t *testing.T) {
|
||||||
|
svc := newAuthServiceForPendingOAuthTest()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
claims := &pendingOAuthClaims{
|
||||||
|
Email: "user@example.com",
|
||||||
|
Username: "alice",
|
||||||
|
Purpose: "", // 旧 token 无此字段,反序列化后为零值
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
|
||||||
|
IssuedAt: jwt.NewNumericDate(now),
|
||||||
|
NotBefore: jwt.NewNumericDate(now),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
|
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
|
||||||
|
require.ErrorIs(t, err, ErrInvalidToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestVerifyPendingOAuthToken_ExpiredToken 过期 token 应返回 ErrInvalidToken。
|
||||||
|
func TestVerifyPendingOAuthToken_ExpiredToken(t *testing.T) {
|
||||||
|
svc := newAuthServiceForPendingOAuthTest()
|
||||||
|
|
||||||
|
past := time.Now().Add(-1 * time.Hour)
|
||||||
|
claims := &pendingOAuthClaims{
|
||||||
|
Email: "user@example.com",
|
||||||
|
Username: "alice",
|
||||||
|
Purpose: pendingOAuthPurpose,
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
ExpiresAt: jwt.NewNumericDate(past),
|
||||||
|
IssuedAt: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
|
||||||
|
NotBefore: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
|
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
|
||||||
|
require.ErrorIs(t, err, ErrInvalidToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestVerifyPendingOAuthToken_WrongSecret 不同密钥签发的 token 应返回 ErrInvalidToken。
|
||||||
|
func TestVerifyPendingOAuthToken_WrongSecret(t *testing.T) {
|
||||||
|
other := NewAuthService(nil, nil, nil, nil, &config.Config{
|
||||||
|
JWT: config.JWTConfig{Secret: "other-secret"},
|
||||||
|
}, nil, nil, nil, nil, nil, nil)
|
||||||
|
|
||||||
|
token, err := other.CreatePendingOAuthToken("user@example.com", "alice")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
svc := newAuthServiceForPendingOAuthTest()
|
||||||
|
_, _, err = svc.VerifyPendingOAuthToken(token)
|
||||||
|
require.ErrorIs(t, err, ErrInvalidToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestVerifyPendingOAuthToken_TooLong 超长 token 应返回 ErrInvalidToken。
|
||||||
|
func TestVerifyPendingOAuthToken_TooLong(t *testing.T) {
|
||||||
|
svc := newAuthServiceForPendingOAuthTest()
|
||||||
|
giant := make([]byte, maxTokenLength+1)
|
||||||
|
for i := range giant {
|
||||||
|
giant[i] = 'a'
|
||||||
|
}
|
||||||
|
_, _, err := svc.VerifyPendingOAuthToken(string(giant))
|
||||||
|
require.ErrorIs(t, err, ErrInvalidToken)
|
||||||
|
}
|
||||||
@@ -130,6 +130,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
|
|||||||
}
|
}
|
||||||
|
|
||||||
return NewAuthService(
|
return NewAuthService(
|
||||||
|
nil, // entClient
|
||||||
repo,
|
repo,
|
||||||
nil, // redeemRepo
|
nil, // redeemRepo
|
||||||
nil, // refreshTokenCache
|
nil, // refreshTokenCache
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier
|
|||||||
turnstileService := NewTurnstileService(settingService, verifier)
|
turnstileService := NewTurnstileService(settingService, verifier)
|
||||||
|
|
||||||
return NewAuthService(
|
return NewAuthService(
|
||||||
|
nil, // entClient
|
||||||
&userRepoStub{},
|
&userRepoStub{},
|
||||||
nil, // redeemRepo
|
nil, // redeemRepo
|
||||||
nil, // refreshTokenCache
|
nil, // refreshTokenCache
|
||||||
|
|||||||
@@ -43,6 +43,9 @@ type ConcurrencyCache interface {
|
|||||||
|
|
||||||
// 清理过期槽位(后台任务)
|
// 清理过期槽位(后台任务)
|
||||||
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
|
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
|
||||||
|
|
||||||
|
// 启动时清理旧进程遗留槽位与等待计数
|
||||||
|
CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -59,13 +62,22 @@ func initRequestIDPrefix() string {
|
|||||||
return "r" + strconv.FormatUint(fallback, 36)
|
return "r" + strconv.FormatUint(fallback, 36)
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateRequestID generates a unique request ID for concurrency slot tracking.
|
func RequestIDPrefix() string {
|
||||||
// Format: {process_random_prefix}-{base36_counter}
|
return requestIDPrefix
|
||||||
|
}
|
||||||
|
|
||||||
func generateRequestID() string {
|
func generateRequestID() string {
|
||||||
seq := requestIDCounter.Add(1)
|
seq := requestIDCounter.Add(1)
|
||||||
return requestIDPrefix + "-" + strconv.FormatUint(seq, 36)
|
return requestIDPrefix + "-" + strconv.FormatUint(seq, 36)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *ConcurrencyService) CleanupStaleProcessSlots(ctx context.Context) error {
|
||||||
|
if s == nil || s.cache == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s.cache.CleanupStaleProcessSlots(ctx, RequestIDPrefix())
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// Default extra wait slots beyond concurrency limit
|
// Default extra wait slots beyond concurrency limit
|
||||||
defaultExtraWaitSlots = 20
|
defaultExtraWaitSlots = 20
|
||||||
|
|||||||
@@ -91,6 +91,32 @@ func (c *stubConcurrencyCacheForTest) CleanupExpiredAccountSlots(_ context.Conte
|
|||||||
return c.cleanupErr
|
return c.cleanupErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *stubConcurrencyCacheForTest) CleanupStaleProcessSlots(_ context.Context, _ string) error {
|
||||||
|
return c.cleanupErr
|
||||||
|
}
|
||||||
|
|
||||||
|
type trackingConcurrencyCache struct {
|
||||||
|
stubConcurrencyCacheForTest
|
||||||
|
cleanupPrefix string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *trackingConcurrencyCache) CleanupStaleProcessSlots(_ context.Context, prefix string) error {
|
||||||
|
c.cleanupPrefix = prefix
|
||||||
|
return c.cleanupErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCleanupStaleProcessSlots_NilCache(t *testing.T) {
|
||||||
|
svc := &ConcurrencyService{cache: nil}
|
||||||
|
require.NoError(t, svc.CleanupStaleProcessSlots(context.Background()))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCleanupStaleProcessSlots_DelegatesPrefix(t *testing.T) {
|
||||||
|
cache := &trackingConcurrencyCache{}
|
||||||
|
svc := NewConcurrencyService(cache)
|
||||||
|
require.NoError(t, svc.CleanupStaleProcessSlots(context.Background()))
|
||||||
|
require.Equal(t, RequestIDPrefix(), cache.cleanupPrefix)
|
||||||
|
}
|
||||||
|
|
||||||
func TestAcquireAccountSlot_Success(t *testing.T) {
|
func TestAcquireAccountSlot_Success(t *testing.T) {
|
||||||
cache := &stubConcurrencyCacheForTest{acquireResult: true}
|
cache := &stubConcurrencyCacheForTest{acquireResult: true}
|
||||||
svc := NewConcurrencyService(cache)
|
svc := NewConcurrencyService(cache)
|
||||||
|
|||||||
@@ -182,6 +182,13 @@ const (
|
|||||||
// SettingKeyRectifierSettings stores JSON config for rectifier settings (thinking signature + budget).
|
// SettingKeyRectifierSettings stores JSON config for rectifier settings (thinking signature + budget).
|
||||||
SettingKeyRectifierSettings = "rectifier_settings"
|
SettingKeyRectifierSettings = "rectifier_settings"
|
||||||
|
|
||||||
|
// =========================
|
||||||
|
// Beta Policy Settings
|
||||||
|
// =========================
|
||||||
|
|
||||||
|
// SettingKeyBetaPolicySettings stores JSON config for beta policy rules.
|
||||||
|
SettingKeyBetaPolicySettings = "beta_policy_settings"
|
||||||
|
|
||||||
// =========================
|
// =========================
|
||||||
// Sora S3 存储配置
|
// Sora S3 存储配置
|
||||||
// =========================
|
// =========================
|
||||||
|
|||||||
@@ -86,10 +86,10 @@ func TestStripBetaTokens(t *testing.T) {
|
|||||||
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
|
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "DroppedBetas removes fast-mode only",
|
name: "DroppedBetas is empty (filtering moved to configurable beta policy)",
|
||||||
header: "oauth-2025-04-20,context-1m-2025-08-07,fast-mode-2026-02-01,interleaved-thinking-2025-05-14",
|
header: "oauth-2025-04-20,context-1m-2025-08-07,fast-mode-2026-02-01,interleaved-thinking-2025-05-14",
|
||||||
tokens: claude.DroppedBetas,
|
tokens: claude.DroppedBetas,
|
||||||
want: "oauth-2025-04-20,context-1m-2025-08-07,interleaved-thinking-2025-05-14",
|
want: "oauth-2025-04-20,context-1m-2025-08-07,fast-mode-2026-02-01,interleaved-thinking-2025-05-14",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -114,25 +114,23 @@ func TestMergeAnthropicBetaDropping_Context1M(t *testing.T) {
|
|||||||
func TestMergeAnthropicBetaDropping_DroppedBetas(t *testing.T) {
|
func TestMergeAnthropicBetaDropping_DroppedBetas(t *testing.T) {
|
||||||
required := []string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"}
|
required := []string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"}
|
||||||
incoming := "context-1m-2025-08-07,fast-mode-2026-02-01,foo-beta,oauth-2025-04-20"
|
incoming := "context-1m-2025-08-07,fast-mode-2026-02-01,foo-beta,oauth-2025-04-20"
|
||||||
|
// DroppedBetas is now empty — filtering moved to configurable beta policy.
|
||||||
|
// Without a policy filter set, nothing gets dropped from the static set.
|
||||||
drop := droppedBetaSet()
|
drop := droppedBetaSet()
|
||||||
|
|
||||||
got := mergeAnthropicBetaDropping(required, incoming, drop)
|
got := mergeAnthropicBetaDropping(required, incoming, drop)
|
||||||
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,context-1m-2025-08-07,foo-beta", got)
|
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,context-1m-2025-08-07,fast-mode-2026-02-01,foo-beta", got)
|
||||||
require.Contains(t, got, "context-1m-2025-08-07")
|
require.Contains(t, got, "context-1m-2025-08-07")
|
||||||
require.NotContains(t, got, "fast-mode-2026-02-01")
|
require.Contains(t, got, "fast-mode-2026-02-01")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDroppedBetaSet(t *testing.T) {
|
func TestDroppedBetaSet(t *testing.T) {
|
||||||
// Base set contains DroppedBetas
|
// Base set contains DroppedBetas (now empty — filtering moved to configurable beta policy)
|
||||||
base := droppedBetaSet()
|
base := droppedBetaSet()
|
||||||
require.NotContains(t, base, claude.BetaContext1M)
|
|
||||||
require.Contains(t, base, claude.BetaFastMode)
|
|
||||||
require.Len(t, base, len(claude.DroppedBetas))
|
require.Len(t, base, len(claude.DroppedBetas))
|
||||||
|
|
||||||
// With extra tokens
|
// With extra tokens
|
||||||
extended := droppedBetaSet(claude.BetaClaudeCode)
|
extended := droppedBetaSet(claude.BetaClaudeCode)
|
||||||
require.NotContains(t, extended, claude.BetaContext1M)
|
|
||||||
require.Contains(t, extended, claude.BetaFastMode)
|
|
||||||
require.Contains(t, extended, claude.BetaClaudeCode)
|
require.Contains(t, extended, claude.BetaClaudeCode)
|
||||||
require.Len(t, extended, len(claude.DroppedBetas)+1)
|
require.Len(t, extended, len(claude.DroppedBetas)+1)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1986,6 +1986,10 @@ func (m *mockConcurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, a
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockConcurrencyCache) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *mockConcurrencyCache) GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error) {
|
func (m *mockConcurrencyCache) GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error) {
|
||||||
result := make(map[int64]*UserLoadInfo, len(users))
|
result := make(map[int64]*UserLoadInfo, len(users))
|
||||||
for _, user := range users {
|
for _, user := range users {
|
||||||
|
|||||||
@@ -3948,6 +3948,20 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody, passthroughModel, parsed.Stream, startTime)
|
return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody, passthroughModel, parsed.Stream, startTime)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest.
|
||||||
|
// Always overwrite the cache to prevent stale values from a previous retry with a different account.
|
||||||
|
if account.Platform == PlatformAnthropic && c != nil {
|
||||||
|
policy := s.evaluateBetaPolicy(ctx, c.GetHeader("anthropic-beta"), account)
|
||||||
|
if policy.blockErr != nil {
|
||||||
|
return nil, policy.blockErr
|
||||||
|
}
|
||||||
|
filterSet := policy.filterSet
|
||||||
|
if filterSet == nil {
|
||||||
|
filterSet = map[string]struct{}{}
|
||||||
|
}
|
||||||
|
c.Set(betaPolicyFilterSetKey, filterSet)
|
||||||
|
}
|
||||||
|
|
||||||
body := parsed.Body
|
body := parsed.Body
|
||||||
reqModel := parsed.Model
|
reqModel := parsed.Model
|
||||||
reqStream := parsed.Stream
|
reqStream := parsed.Stream
|
||||||
@@ -5133,6 +5147,11 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
|||||||
applyClaudeOAuthHeaderDefaults(req, reqStream)
|
applyClaudeOAuthHeaderDefaults(req, reqStream)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Build effective drop set: merge static defaults with dynamic beta policy filter rules
|
||||||
|
policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account)
|
||||||
|
effectiveDropSet := mergeDropSets(policyFilterSet)
|
||||||
|
effectiveDropWithClaudeCodeSet := mergeDropSets(policyFilterSet, claude.BetaClaudeCode)
|
||||||
|
|
||||||
// 处理 anthropic-beta header(OAuth 账号需要包含 oauth beta)
|
// 处理 anthropic-beta header(OAuth 账号需要包含 oauth beta)
|
||||||
if tokenType == "oauth" {
|
if tokenType == "oauth" {
|
||||||
if mimicClaudeCode {
|
if mimicClaudeCode {
|
||||||
@@ -5146,17 +5165,22 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
|||||||
// messages requests typically use only oauth + interleaved-thinking.
|
// messages requests typically use only oauth + interleaved-thinking.
|
||||||
// Also drop claude-code beta if a downstream client added it.
|
// Also drop claude-code beta if a downstream client added it.
|
||||||
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
|
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
|
||||||
req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, droppedBetasWithClaudeCodeSet))
|
req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropWithClaudeCodeSet))
|
||||||
} else {
|
} else {
|
||||||
// Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta
|
// Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta
|
||||||
clientBetaHeader := req.Header.Get("anthropic-beta")
|
clientBetaHeader := req.Header.Get("anthropic-beta")
|
||||||
req.Header.Set("anthropic-beta", stripBetaTokensWithSet(s.getBetaHeader(modelID, clientBetaHeader), defaultDroppedBetasSet))
|
req.Header.Set("anthropic-beta", stripBetaTokensWithSet(s.getBetaHeader(modelID, clientBetaHeader), effectiveDropSet))
|
||||||
}
|
}
|
||||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
|
} else {
|
||||||
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
|
// API-key accounts: apply beta policy filter to strip controlled tokens
|
||||||
if requestNeedsBetaFeatures(body) {
|
if existingBeta := req.Header.Get("anthropic-beta"); existingBeta != "" {
|
||||||
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
|
req.Header.Set("anthropic-beta", stripBetaTokensWithSet(existingBeta, effectiveDropSet))
|
||||||
req.Header.Set("anthropic-beta", beta)
|
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey {
|
||||||
|
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
|
||||||
|
if requestNeedsBetaFeatures(body) {
|
||||||
|
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
|
||||||
|
req.Header.Set("anthropic-beta", beta)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -5334,6 +5358,104 @@ func stripBetaTokensWithSet(header string, drop map[string]struct{}) string {
|
|||||||
return strings.Join(out, ",")
|
return strings.Join(out, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BetaBlockedError indicates a request was blocked by a beta policy rule.
|
||||||
|
type BetaBlockedError struct {
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *BetaBlockedError) Error() string { return e.Message }
|
||||||
|
|
||||||
|
// betaPolicyResult holds the evaluated result of beta policy rules for a single request.
|
||||||
|
type betaPolicyResult struct {
|
||||||
|
blockErr *BetaBlockedError // non-nil if a block rule matched
|
||||||
|
filterSet map[string]struct{} // tokens to filter (may be nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// evaluateBetaPolicy loads settings once and evaluates all rules against the given request.
|
||||||
|
func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader string, account *Account) betaPolicyResult {
|
||||||
|
if s.settingService == nil {
|
||||||
|
return betaPolicyResult{}
|
||||||
|
}
|
||||||
|
settings, err := s.settingService.GetBetaPolicySettings(ctx)
|
||||||
|
if err != nil || settings == nil {
|
||||||
|
return betaPolicyResult{}
|
||||||
|
}
|
||||||
|
isOAuth := account.IsOAuth()
|
||||||
|
var result betaPolicyResult
|
||||||
|
for _, rule := range settings.Rules {
|
||||||
|
if !betaPolicyScopeMatches(rule.Scope, isOAuth) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch rule.Action {
|
||||||
|
case BetaPolicyActionBlock:
|
||||||
|
if result.blockErr == nil && betaHeader != "" && containsBetaToken(betaHeader, rule.BetaToken) {
|
||||||
|
msg := rule.ErrorMessage
|
||||||
|
if msg == "" {
|
||||||
|
msg = "beta feature " + rule.BetaToken + " is not allowed"
|
||||||
|
}
|
||||||
|
result.blockErr = &BetaBlockedError{Message: msg}
|
||||||
|
}
|
||||||
|
case BetaPolicyActionFilter:
|
||||||
|
if result.filterSet == nil {
|
||||||
|
result.filterSet = make(map[string]struct{})
|
||||||
|
}
|
||||||
|
result.filterSet[rule.BetaToken] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeDropSets merges the static defaultDroppedBetasSet with dynamic policy filter tokens.
|
||||||
|
// Returns defaultDroppedBetasSet directly when policySet is empty (zero allocation).
|
||||||
|
func mergeDropSets(policySet map[string]struct{}, extra ...string) map[string]struct{} {
|
||||||
|
if len(policySet) == 0 && len(extra) == 0 {
|
||||||
|
return defaultDroppedBetasSet
|
||||||
|
}
|
||||||
|
m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(policySet)+len(extra))
|
||||||
|
for t := range defaultDroppedBetasSet {
|
||||||
|
m[t] = struct{}{}
|
||||||
|
}
|
||||||
|
for t := range policySet {
|
||||||
|
m[t] = struct{}{}
|
||||||
|
}
|
||||||
|
for _, t := range extra {
|
||||||
|
m[t] = struct{}{}
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// betaPolicyFilterSetKey is the gin.Context key for caching the policy filter set within a request.
|
||||||
|
const betaPolicyFilterSetKey = "betaPolicyFilterSet"
|
||||||
|
|
||||||
|
// getBetaPolicyFilterSet returns the beta policy filter set, using the gin context cache if available.
|
||||||
|
// In the /v1/messages path, Forward() evaluates the policy first and caches the result;
|
||||||
|
// buildUpstreamRequest reuses it (zero extra DB calls). In the count_tokens path, this
|
||||||
|
// evaluates on demand (one DB call).
|
||||||
|
func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Context, account *Account) map[string]struct{} {
|
||||||
|
if c != nil {
|
||||||
|
if v, ok := c.Get(betaPolicyFilterSetKey); ok {
|
||||||
|
if fs, ok := v.(map[string]struct{}); ok {
|
||||||
|
return fs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s.evaluateBetaPolicy(ctx, "", account).filterSet
|
||||||
|
}
|
||||||
|
|
||||||
|
// betaPolicyScopeMatches checks whether a rule's scope matches the current account type.
|
||||||
|
func betaPolicyScopeMatches(scope string, isOAuth bool) bool {
|
||||||
|
switch scope {
|
||||||
|
case BetaPolicyScopeAll:
|
||||||
|
return true
|
||||||
|
case BetaPolicyScopeOAuth:
|
||||||
|
return isOAuth
|
||||||
|
case BetaPolicyScopeAPIKey:
|
||||||
|
return !isOAuth
|
||||||
|
default:
|
||||||
|
return true // unknown scope → match all (fail-open)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// droppedBetaSet returns claude.DroppedBetas as a set, with optional extra tokens.
|
// droppedBetaSet returns claude.DroppedBetas as a set, with optional extra tokens.
|
||||||
func droppedBetaSet(extra ...string) map[string]struct{} {
|
func droppedBetaSet(extra ...string) map[string]struct{} {
|
||||||
m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(extra))
|
m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(extra))
|
||||||
@@ -5370,10 +5492,7 @@ func buildBetaTokenSet(tokens []string) map[string]struct{} {
|
|||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var defaultDroppedBetasSet = buildBetaTokenSet(claude.DroppedBetas)
|
||||||
defaultDroppedBetasSet = buildBetaTokenSet(claude.DroppedBetas)
|
|
||||||
droppedBetasWithClaudeCodeSet = droppedBetaSet(claude.BetaClaudeCode)
|
|
||||||
)
|
|
||||||
|
|
||||||
// applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers.
|
// applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers.
|
||||||
// This mirrors opencode-anthropic-auth behavior: do not trust downstream
|
// This mirrors opencode-anthropic-auth behavior: do not trust downstream
|
||||||
@@ -7311,6 +7430,9 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
|||||||
applyClaudeOAuthHeaderDefaults(req, false)
|
applyClaudeOAuthHeaderDefaults(req, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Build effective drop set for count_tokens: merge static defaults with dynamic beta policy filter rules
|
||||||
|
ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account))
|
||||||
|
|
||||||
// OAuth 账号:处理 anthropic-beta header
|
// OAuth 账号:处理 anthropic-beta header
|
||||||
if tokenType == "oauth" {
|
if tokenType == "oauth" {
|
||||||
if mimicClaudeCode {
|
if mimicClaudeCode {
|
||||||
@@ -7318,8 +7440,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
|||||||
|
|
||||||
incomingBeta := req.Header.Get("anthropic-beta")
|
incomingBeta := req.Header.Get("anthropic-beta")
|
||||||
requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting}
|
requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting}
|
||||||
drop := droppedBetaSet()
|
req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, ctEffectiveDropSet))
|
||||||
req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop))
|
|
||||||
} else {
|
} else {
|
||||||
clientBetaHeader := req.Header.Get("anthropic-beta")
|
clientBetaHeader := req.Header.Get("anthropic-beta")
|
||||||
if clientBetaHeader == "" {
|
if clientBetaHeader == "" {
|
||||||
@@ -7329,14 +7450,19 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
|||||||
if !strings.Contains(beta, claude.BetaTokenCounting) {
|
if !strings.Contains(beta, claude.BetaTokenCounting) {
|
||||||
beta = beta + "," + claude.BetaTokenCounting
|
beta = beta + "," + claude.BetaTokenCounting
|
||||||
}
|
}
|
||||||
req.Header.Set("anthropic-beta", stripBetaTokensWithSet(beta, defaultDroppedBetasSet))
|
req.Header.Set("anthropic-beta", stripBetaTokensWithSet(beta, ctEffectiveDropSet))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
|
} else {
|
||||||
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
|
// API-key accounts: apply beta policy filter to strip controlled tokens
|
||||||
if requestNeedsBetaFeatures(body) {
|
if existingBeta := req.Header.Get("anthropic-beta"); existingBeta != "" {
|
||||||
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
|
req.Header.Set("anthropic-beta", stripBetaTokensWithSet(existingBeta, ctEffectiveDropSet))
|
||||||
req.Header.Set("anthropic-beta", beta)
|
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey {
|
||||||
|
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
|
||||||
|
if requestNeedsBetaFeatures(body) {
|
||||||
|
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
|
||||||
|
req.Header.Set("anthropic-beta", beta)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
@@ -140,12 +141,13 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
|||||||
|
|
||||||
// 8. Handle error response with failover
|
// 8. Handle error response with failover
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
_ = resp.Body.Close()
|
||||||
_ = resp.Body.Close()
|
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||||
|
|
||||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
|
||||||
upstreamDetail := ""
|
upstreamDetail := ""
|
||||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
@@ -167,7 +169,11 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
|||||||
if s.rateLimitService != nil {
|
if s.rateLimitService != nil {
|
||||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||||
}
|
}
|
||||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
return nil, &UpstreamFailoverError{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
ResponseBody: respBody,
|
||||||
|
RetryableOnSameAccount: account.IsPoolMode() && isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// Non-failover error: return Anthropic-formatted error to client
|
// Non-failover error: return Anthropic-formatted error to client
|
||||||
return s.handleAnthropicErrorResponse(resp, c, account)
|
return s.handleAnthropicErrorResponse(resp, c, account)
|
||||||
@@ -279,7 +285,11 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
|
|||||||
requestID := resp.Header.Get("x-request-id")
|
requestID := resp.Header.Get("x-request-id")
|
||||||
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
maxLineSize := defaultMaxLineSize
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||||
|
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||||
|
}
|
||||||
|
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||||
|
|
||||||
var finalResponse *apicompat.ResponsesResponse
|
var finalResponse *apicompat.ResponsesResponse
|
||||||
var usage OpenAIUsage
|
var usage OpenAIUsage
|
||||||
@@ -378,7 +388,11 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
|||||||
firstChunk := true
|
firstChunk := true
|
||||||
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
maxLineSize := defaultMaxLineSize
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||||
|
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||||
|
}
|
||||||
|
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||||
|
|
||||||
// resultWithUsage builds the final result snapshot.
|
// resultWithUsage builds the final result snapshot.
|
||||||
resultWithUsage := func() *OpenAIForwardResult {
|
resultWithUsage := func() *OpenAIForwardResult {
|
||||||
|
|||||||
@@ -911,6 +911,36 @@ func isOpenAIInstructionsRequiredError(upstreamStatusCode int, upstreamMsg strin
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isOpenAITransientProcessingError(upstreamStatusCode int, upstreamMsg string, upstreamBody []byte) bool {
|
||||||
|
if upstreamStatusCode != http.StatusBadRequest {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
match := func(text string) bool {
|
||||||
|
lower := strings.ToLower(strings.TrimSpace(text))
|
||||||
|
if lower == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if strings.Contains(lower, "an error occurred while processing your request") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return strings.Contains(lower, "you can retry your request") &&
|
||||||
|
strings.Contains(lower, "help.openai.com") &&
|
||||||
|
strings.Contains(lower, "request id")
|
||||||
|
}
|
||||||
|
|
||||||
|
if match(upstreamMsg) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if len(upstreamBody) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if match(gjson.GetBytes(upstreamBody, "error.message").String()) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return match(string(upstreamBody))
|
||||||
|
}
|
||||||
|
|
||||||
// ExtractSessionID extracts the raw session ID from headers or body without hashing.
|
// ExtractSessionID extracts the raw session ID from headers or body without hashing.
|
||||||
// Used by ForwardAsAnthropic to pass as prompt_cache_key for upstream cache.
|
// Used by ForwardAsAnthropic to pass as prompt_cache_key for upstream cache.
|
||||||
func (s *OpenAIGatewayService) ExtractSessionID(c *gin.Context, body []byte) string {
|
func (s *OpenAIGatewayService) ExtractSessionID(c *gin.Context, body []byte) string {
|
||||||
@@ -1518,6 +1548,13 @@ func (s *OpenAIGatewayService) shouldFailoverUpstreamError(statusCode int) bool
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) shouldFailoverOpenAIUpstreamResponse(statusCode int, upstreamMsg string, upstreamBody []byte) bool {
|
||||||
|
if s.shouldFailoverUpstreamError(statusCode) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return isOpenAITransientProcessingError(statusCode, upstreamMsg, upstreamBody)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
|
func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
|
||||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||||
@@ -2016,13 +2053,13 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
|
|
||||||
// Handle error response
|
// Handle error response
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
_ = resp.Body.Close()
|
||||||
_ = resp.Body.Close()
|
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
|
||||||
|
|
||||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
|
||||||
upstreamDetail := ""
|
upstreamDetail := ""
|
||||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
@@ -2046,7 +2083,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
return nil, &UpstreamFailoverError{
|
return nil, &UpstreamFailoverError{
|
||||||
StatusCode: resp.StatusCode,
|
StatusCode: resp.StatusCode,
|
||||||
ResponseBody: respBody,
|
ResponseBody: respBody,
|
||||||
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
|
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return s.handleErrorResponse(ctx, resp, c, account, body)
|
return s.handleErrorResponse(ctx, resp, c, account, body)
|
||||||
|
|||||||
@@ -211,6 +211,26 @@ func TestLogOpenAIInstructionsRequiredDebug_NonTargetErrorSkipped(t *testing.T)
|
|||||||
require.False(t, logSink.ContainsMessage("OpenAI 上游返回 Instructions are required,已记录请求详情用于排查"))
|
require.False(t, logSink.ContainsMessage("OpenAI 上游返回 Instructions are required,已记录请求详情用于排查"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsOpenAITransientProcessingError(t *testing.T) {
|
||||||
|
require.True(t, isOpenAITransientProcessingError(
|
||||||
|
http.StatusBadRequest,
|
||||||
|
"An error occurred while processing your request.",
|
||||||
|
nil,
|
||||||
|
))
|
||||||
|
|
||||||
|
require.True(t, isOpenAITransientProcessingError(
|
||||||
|
http.StatusBadRequest,
|
||||||
|
"",
|
||||||
|
[]byte(`{"error":{"message":"An error occurred while processing your request. You can retry your request, or contact us through our help center at help.openai.com if the error persists. Please include the request ID req_123 in your message."}}`),
|
||||||
|
))
|
||||||
|
|
||||||
|
require.False(t, isOpenAITransientProcessingError(
|
||||||
|
http.StatusBadRequest,
|
||||||
|
"Missing required parameter: 'instructions'",
|
||||||
|
[]byte(`{"error":{"message":"Missing required parameter: 'instructions'"}}`),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayService_Forward_LogsInstructionsRequiredDetails(t *testing.T) {
|
func TestOpenAIGatewayService_Forward_LogsInstructionsRequiredDetails(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
logSink, restore := captureStructuredLog(t)
|
logSink, restore := captureStructuredLog(t)
|
||||||
@@ -264,3 +284,51 @@ func TestOpenAIGatewayService_Forward_LogsInstructionsRequiredDetails(t *testing
|
|||||||
require.True(t, logSink.ContainsField("request_body_size"))
|
require.True(t, logSink.ContainsField("request_body_size"))
|
||||||
require.False(t, logSink.ContainsField("request_body_preview"))
|
require.False(t, logSink.ContainsField("request_body_preview"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_Forward_TransientProcessingErrorTriggersFailover(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
|
||||||
|
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
upstream := &httpUpstreamRecorder{
|
||||||
|
resp: &http.Response{
|
||||||
|
StatusCode: http.StatusBadRequest,
|
||||||
|
Header: http.Header{
|
||||||
|
"Content-Type": []string{"application/json"},
|
||||||
|
"x-request-id": []string{"rid-processing-400"},
|
||||||
|
},
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"error":{"message":"An error occurred while processing your request. You can retry your request, or contact us through our help center at help.openai.com if the error persists. Please include the request ID req_123 in your message.","type":"invalid_request_error"}}`)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
cfg: &config.Config{
|
||||||
|
Gateway: config.GatewayConfig{ForceCodexCLI: false},
|
||||||
|
},
|
||||||
|
httpUpstream: upstream,
|
||||||
|
}
|
||||||
|
account := &Account{
|
||||||
|
ID: 1001,
|
||||||
|
Name: "codex max套餐",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{"api_key": "sk-test"},
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
RateMultiplier: f64p(1),
|
||||||
|
}
|
||||||
|
body := []byte(`{"model":"gpt-5.1-codex","stream":false,"input":[{"type":"text","text":"hello"}]}`)
|
||||||
|
|
||||||
|
_, err := svc.Forward(context.Background(), c, account, body)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
var failoverErr *UpstreamFailoverError
|
||||||
|
require.ErrorAs(t, err, &failoverErr)
|
||||||
|
require.Equal(t, http.StatusBadRequest, failoverErr.StatusCode)
|
||||||
|
require.Contains(t, string(failoverErr.ResponseBody), "An error occurred while processing your request")
|
||||||
|
require.False(t, c.Writer.Written(), "service 层应返回 failover 错误给上层换号,而不是直接向客户端写响应")
|
||||||
|
}
|
||||||
|
|||||||
@@ -130,6 +130,7 @@ type OpenAITokenInfo struct {
|
|||||||
ChatGPTAccountID string `json:"chatgpt_account_id,omitempty"`
|
ChatGPTAccountID string `json:"chatgpt_account_id,omitempty"`
|
||||||
ChatGPTUserID string `json:"chatgpt_user_id,omitempty"`
|
ChatGPTUserID string `json:"chatgpt_user_id,omitempty"`
|
||||||
OrganizationID string `json:"organization_id,omitempty"`
|
OrganizationID string `json:"organization_id,omitempty"`
|
||||||
|
PlanType string `json:"plan_type,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExchangeCode exchanges authorization code for tokens
|
// ExchangeCode exchanges authorization code for tokens
|
||||||
@@ -202,6 +203,7 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
|
|||||||
tokenInfo.ChatGPTAccountID = userInfo.ChatGPTAccountID
|
tokenInfo.ChatGPTAccountID = userInfo.ChatGPTAccountID
|
||||||
tokenInfo.ChatGPTUserID = userInfo.ChatGPTUserID
|
tokenInfo.ChatGPTUserID = userInfo.ChatGPTUserID
|
||||||
tokenInfo.OrganizationID = userInfo.OrganizationID
|
tokenInfo.OrganizationID = userInfo.OrganizationID
|
||||||
|
tokenInfo.PlanType = userInfo.PlanType
|
||||||
}
|
}
|
||||||
|
|
||||||
return tokenInfo, nil
|
return tokenInfo, nil
|
||||||
@@ -246,6 +248,7 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre
|
|||||||
tokenInfo.ChatGPTAccountID = userInfo.ChatGPTAccountID
|
tokenInfo.ChatGPTAccountID = userInfo.ChatGPTAccountID
|
||||||
tokenInfo.ChatGPTUserID = userInfo.ChatGPTUserID
|
tokenInfo.ChatGPTUserID = userInfo.ChatGPTUserID
|
||||||
tokenInfo.OrganizationID = userInfo.OrganizationID
|
tokenInfo.OrganizationID = userInfo.OrganizationID
|
||||||
|
tokenInfo.PlanType = userInfo.PlanType
|
||||||
}
|
}
|
||||||
|
|
||||||
return tokenInfo, nil
|
return tokenInfo, nil
|
||||||
@@ -510,6 +513,9 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo)
|
|||||||
if tokenInfo.OrganizationID != "" {
|
if tokenInfo.OrganizationID != "" {
|
||||||
creds["organization_id"] = tokenInfo.OrganizationID
|
creds["organization_id"] = tokenInfo.OrganizationID
|
||||||
}
|
}
|
||||||
|
if tokenInfo.PlanType != "" {
|
||||||
|
creds["plan_type"] = tokenInfo.PlanType
|
||||||
|
}
|
||||||
if strings.TrimSpace(tokenInfo.ClientID) != "" {
|
if strings.TrimSpace(tokenInfo.ClientID) != "" {
|
||||||
creds["client_id"] = strings.TrimSpace(tokenInfo.ClientID)
|
creds["client_id"] = strings.TrimSpace(tokenInfo.ClientID)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1247,6 +1247,60 @@ func (s *SettingService) IsBudgetRectifierEnabled(ctx context.Context) bool {
|
|||||||
return settings.Enabled && settings.ThinkingBudgetEnabled
|
return settings.Enabled && settings.ThinkingBudgetEnabled
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetBetaPolicySettings 获取 Beta 策略配置
|
||||||
|
func (s *SettingService) GetBetaPolicySettings(ctx context.Context) (*BetaPolicySettings, error) {
|
||||||
|
value, err := s.settingRepo.GetValue(ctx, SettingKeyBetaPolicySettings)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrSettingNotFound) {
|
||||||
|
return DefaultBetaPolicySettings(), nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("get beta policy settings: %w", err)
|
||||||
|
}
|
||||||
|
if value == "" {
|
||||||
|
return DefaultBetaPolicySettings(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var settings BetaPolicySettings
|
||||||
|
if err := json.Unmarshal([]byte(value), &settings); err != nil {
|
||||||
|
return DefaultBetaPolicySettings(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &settings, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBetaPolicySettings 设置 Beta 策略配置
|
||||||
|
func (s *SettingService) SetBetaPolicySettings(ctx context.Context, settings *BetaPolicySettings) error {
|
||||||
|
if settings == nil {
|
||||||
|
return fmt.Errorf("settings cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
validActions := map[string]bool{
|
||||||
|
BetaPolicyActionPass: true, BetaPolicyActionFilter: true, BetaPolicyActionBlock: true,
|
||||||
|
}
|
||||||
|
validScopes := map[string]bool{
|
||||||
|
BetaPolicyScopeAll: true, BetaPolicyScopeOAuth: true, BetaPolicyScopeAPIKey: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, rule := range settings.Rules {
|
||||||
|
if rule.BetaToken == "" {
|
||||||
|
return fmt.Errorf("rule[%d]: beta_token cannot be empty", i)
|
||||||
|
}
|
||||||
|
if !validActions[rule.Action] {
|
||||||
|
return fmt.Errorf("rule[%d]: invalid action %q", i, rule.Action)
|
||||||
|
}
|
||||||
|
if !validScopes[rule.Scope] {
|
||||||
|
return fmt.Errorf("rule[%d]: invalid scope %q", i, rule.Scope)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(settings)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal beta policy settings: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.settingRepo.Set(ctx, SettingKeyBetaPolicySettings, string(data))
|
||||||
|
}
|
||||||
|
|
||||||
// SetStreamTimeoutSettings 设置流超时处理配置
|
// SetStreamTimeoutSettings 设置流超时处理配置
|
||||||
func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings *StreamTimeoutSettings) error {
|
func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings *StreamTimeoutSettings) error {
|
||||||
if settings == nil {
|
if settings == nil {
|
||||||
|
|||||||
@@ -191,3 +191,45 @@ func DefaultRectifierSettings() *RectifierSettings {
|
|||||||
ThinkingBudgetEnabled: true,
|
ThinkingBudgetEnabled: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Beta Policy 策略常量
|
||||||
|
const (
|
||||||
|
BetaPolicyActionPass = "pass" // 透传,不做任何处理
|
||||||
|
BetaPolicyActionFilter = "filter" // 过滤,从 beta header 中移除该 token
|
||||||
|
BetaPolicyActionBlock = "block" // 拦截,直接返回错误
|
||||||
|
|
||||||
|
BetaPolicyScopeAll = "all" // 所有账号类型
|
||||||
|
BetaPolicyScopeOAuth = "oauth" // 仅 OAuth 账号
|
||||||
|
BetaPolicyScopeAPIKey = "apikey" // 仅 API Key 账号
|
||||||
|
)
|
||||||
|
|
||||||
|
// BetaPolicyRule 单条 Beta 策略规则
|
||||||
|
type BetaPolicyRule struct {
|
||||||
|
BetaToken string `json:"beta_token"` // beta token 值
|
||||||
|
Action string `json:"action"` // "pass" | "filter" | "block"
|
||||||
|
Scope string `json:"scope"` // "all" | "oauth" | "apikey"
|
||||||
|
ErrorMessage string `json:"error_message,omitempty"` // 自定义错误消息 (action=block 时生效)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BetaPolicySettings Beta 策略配置
|
||||||
|
type BetaPolicySettings struct {
|
||||||
|
Rules []BetaPolicyRule `json:"rules"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultBetaPolicySettings 返回默认的 Beta 策略配置
|
||||||
|
func DefaultBetaPolicySettings() *BetaPolicySettings {
|
||||||
|
return &BetaPolicySettings{
|
||||||
|
Rules: []BetaPolicyRule{
|
||||||
|
{
|
||||||
|
BetaToken: "fast-mode-2026-02-01",
|
||||||
|
Action: BetaPolicyActionFilter,
|
||||||
|
Scope: BetaPolicyScopeAll,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
BetaToken: "context-1m-2025-08-07",
|
||||||
|
Action: BetaPolicyActionFilter,
|
||||||
|
Scope: BetaPolicyScopeAll,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -105,6 +105,9 @@ func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWh
|
|||||||
// ProvideConcurrencyService creates ConcurrencyService and starts slot cleanup worker.
|
// ProvideConcurrencyService creates ConcurrencyService and starts slot cleanup worker.
|
||||||
func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountRepository, cfg *config.Config) *ConcurrencyService {
|
func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountRepository, cfg *config.Config) *ConcurrencyService {
|
||||||
svc := NewConcurrencyService(cache)
|
svc := NewConcurrencyService(cache)
|
||||||
|
if err := svc.CleanupStaleProcessSlots(context.Background()); err != nil {
|
||||||
|
logger.LegacyPrintf("service.concurrency", "Warning: startup cleanup stale process slots failed: %v", err)
|
||||||
|
}
|
||||||
if cfg != nil {
|
if cfg != nil {
|
||||||
svc.StartSlotCleanupWorker(accountRepo, cfg.Gateway.Scheduling.SlotCleanupInterval)
|
svc.StartSlotCleanupWorker(accountRepo, cfg.Gateway.Scheduling.SlotCleanupInterval)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -76,6 +76,9 @@ func (c StubConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, acco
|
|||||||
func (c StubConcurrencyCache) CleanupExpiredAccountSlots(_ context.Context, _ int64) error {
|
func (c StubConcurrencyCache) CleanupExpiredAccountSlots(_ context.Context, _ int64) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
func (c StubConcurrencyCache) CleanupStaleProcessSlots(_ context.Context, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// ============================================================
|
// ============================================================
|
||||||
// StubGatewayCache — service.GatewayCache 的空实现
|
// StubGatewayCache — service.GatewayCache 的空实现
|
||||||
|
|||||||
@@ -8,7 +8,7 @@
|
|||||||
# - Creates necessary data directories
|
# - Creates necessary data directories
|
||||||
#
|
#
|
||||||
# After running this script, you can start services with:
|
# After running this script, you can start services with:
|
||||||
# docker-compose -f docker-compose.local.yml up -d
|
# docker-compose up -d
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
set -e
|
set -e
|
||||||
@@ -65,7 +65,7 @@ main() {
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
# Check if deployment already exists
|
# Check if deployment already exists
|
||||||
if [ -f "docker-compose.local.yml" ] && [ -f ".env" ]; then
|
if [ -f "docker-compose.yml" ] && [ -f ".env" ]; then
|
||||||
print_warning "Deployment files already exist in current directory."
|
print_warning "Deployment files already exist in current directory."
|
||||||
read -p "Overwrite existing files? (y/N): " -r
|
read -p "Overwrite existing files? (y/N): " -r
|
||||||
echo
|
echo
|
||||||
@@ -75,17 +75,17 @@ main() {
|
|||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Download docker-compose.local.yml
|
# Download docker-compose.local.yml and save as docker-compose.yml
|
||||||
print_info "Downloading docker-compose.local.yml..."
|
print_info "Downloading docker-compose.yml..."
|
||||||
if command_exists curl; then
|
if command_exists curl; then
|
||||||
curl -sSL "${GITHUB_RAW_URL}/docker-compose.local.yml" -o docker-compose.local.yml
|
curl -sSL "${GITHUB_RAW_URL}/docker-compose.local.yml" -o docker-compose.yml
|
||||||
elif command_exists wget; then
|
elif command_exists wget; then
|
||||||
wget -q "${GITHUB_RAW_URL}/docker-compose.local.yml" -O docker-compose.local.yml
|
wget -q "${GITHUB_RAW_URL}/docker-compose.local.yml" -O docker-compose.yml
|
||||||
else
|
else
|
||||||
print_error "Neither curl nor wget is installed. Please install one of them."
|
print_error "Neither curl nor wget is installed. Please install one of them."
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
print_success "Downloaded docker-compose.local.yml"
|
print_success "Downloaded docker-compose.yml"
|
||||||
|
|
||||||
# Download .env.example
|
# Download .env.example
|
||||||
print_info "Downloading .env.example..."
|
print_info "Downloading .env.example..."
|
||||||
@@ -144,7 +144,7 @@ main() {
|
|||||||
print_warning "Please keep them secure and do not share publicly!"
|
print_warning "Please keep them secure and do not share publicly!"
|
||||||
echo ""
|
echo ""
|
||||||
echo "Directory structure:"
|
echo "Directory structure:"
|
||||||
echo " docker-compose.local.yml - Docker Compose configuration"
|
echo " docker-compose.yml - Docker Compose configuration"
|
||||||
echo " .env - Environment variables (generated secrets)"
|
echo " .env - Environment variables (generated secrets)"
|
||||||
echo " .env.example - Example template (for reference)"
|
echo " .env.example - Example template (for reference)"
|
||||||
echo " data/ - Application data (will be created on first run)"
|
echo " data/ - Application data (will be created on first run)"
|
||||||
@@ -154,10 +154,10 @@ main() {
|
|||||||
echo "Next steps:"
|
echo "Next steps:"
|
||||||
echo " 1. (Optional) Edit .env to customize configuration"
|
echo " 1. (Optional) Edit .env to customize configuration"
|
||||||
echo " 2. Start services:"
|
echo " 2. Start services:"
|
||||||
echo " docker-compose -f docker-compose.local.yml up -d"
|
echo " docker-compose up -d"
|
||||||
echo ""
|
echo ""
|
||||||
echo " 3. View logs:"
|
echo " 3. View logs:"
|
||||||
echo " docker-compose -f docker-compose.local.yml logs -f sub2api"
|
echo " docker-compose logs -f sub2api"
|
||||||
echo ""
|
echo ""
|
||||||
echo " 4. Access Web UI:"
|
echo " 4. Access Web UI:"
|
||||||
echo " http://localhost:8080"
|
echo " http://localhost:8080"
|
||||||
|
|||||||
@@ -99,16 +99,17 @@ curl -X POST "${BASE}/api/v1/admin/users/123/balance" \
|
|||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
### 4) 购买页 URL Query 透传(iframe / 新窗口一致)
|
### 4) 购买页 / 自定义页面 URL Query 透传(iframe / 新窗口一致)
|
||||||
当 Sub2API 打开 `purchase_subscription_url` 时,会统一追加:
|
当 Sub2API 打开 `purchase_subscription_url` 或用户侧自定义页面 iframe URL 时,会统一追加:
|
||||||
- `user_id`
|
- `user_id`
|
||||||
- `token`
|
- `token`
|
||||||
- `theme`(`light` / `dark`)
|
- `theme`(`light` / `dark`)
|
||||||
|
- `lang`(例如 `zh` / `en`,用于向嵌入页传递当前界面语言)
|
||||||
- `ui_mode`(固定 `embedded`)
|
- `ui_mode`(固定 `embedded`)
|
||||||
|
|
||||||
示例:
|
示例:
|
||||||
```text
|
```text
|
||||||
https://pay.example.com/pay?user_id=123&token=<jwt>&theme=light&ui_mode=embedded
|
https://pay.example.com/pay?user_id=123&token=<jwt>&theme=light&lang=zh&ui_mode=embedded
|
||||||
```
|
```
|
||||||
|
|
||||||
### 5) 失败处理建议
|
### 5) 失败处理建议
|
||||||
@@ -218,16 +219,17 @@ curl -X POST "${BASE}/api/v1/admin/users/123/balance" \
|
|||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
### 4) Purchase URL query forwarding (iframe and new tab)
|
### 4) Purchase / Custom Page URL query forwarding (iframe and new tab)
|
||||||
When Sub2API opens `purchase_subscription_url`, it appends:
|
When Sub2API opens `purchase_subscription_url` or a user-facing custom page iframe URL, it appends:
|
||||||
- `user_id`
|
- `user_id`
|
||||||
- `token`
|
- `token`
|
||||||
- `theme` (`light` / `dark`)
|
- `theme` (`light` / `dark`)
|
||||||
|
- `lang` (for example `zh` / `en`, used to pass the current UI language to the embedded page)
|
||||||
- `ui_mode` (fixed: `embedded`)
|
- `ui_mode` (fixed: `embedded`)
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
```text
|
```text
|
||||||
https://pay.example.com/pay?user_id=123&token=<jwt>&theme=light&ui_mode=embedded
|
https://pay.example.com/pay?user_id=123&token=<jwt>&theme=light&lang=zh&ui_mode=embedded
|
||||||
```
|
```
|
||||||
|
|
||||||
### 5) Failure handling recommendations
|
### 5) Failure handling recommendations
|
||||||
|
|||||||
@@ -581,6 +581,43 @@ export async function validateSoraSessionToken(
|
|||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Batch operation result type
|
||||||
|
*/
|
||||||
|
export interface BatchOperationResult {
|
||||||
|
total: number
|
||||||
|
success: number
|
||||||
|
failed: number
|
||||||
|
errors?: Array<{ account_id: number; error: string }>
|
||||||
|
warnings?: Array<{ account_id: number; warning: string }>
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Batch clear account errors
|
||||||
|
* @param accountIds - Array of account IDs
|
||||||
|
* @returns Batch operation result
|
||||||
|
*/
|
||||||
|
export async function batchClearError(accountIds: number[]): Promise<BatchOperationResult> {
|
||||||
|
const { data } = await apiClient.post<BatchOperationResult>('/admin/accounts/batch-clear-error', {
|
||||||
|
account_ids: accountIds
|
||||||
|
})
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Batch refresh account credentials
|
||||||
|
* @param accountIds - Array of account IDs
|
||||||
|
* @returns Batch operation result
|
||||||
|
*/
|
||||||
|
export async function batchRefresh(accountIds: number[]): Promise<BatchOperationResult> {
|
||||||
|
const { data } = await apiClient.post<BatchOperationResult>('/admin/accounts/batch-refresh', {
|
||||||
|
account_ids: accountIds,
|
||||||
|
}, {
|
||||||
|
timeout: 120000 // 120s timeout for large batch refreshes
|
||||||
|
})
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
export const accountsAPI = {
|
export const accountsAPI = {
|
||||||
list,
|
list,
|
||||||
listWithEtag,
|
listWithEtag,
|
||||||
@@ -615,7 +652,9 @@ export const accountsAPI = {
|
|||||||
syncFromCrs,
|
syncFromCrs,
|
||||||
exportData,
|
exportData,
|
||||||
importData,
|
importData,
|
||||||
getAntigravityDefaultModelMapping
|
getAntigravityDefaultModelMapping,
|
||||||
|
batchClearError,
|
||||||
|
batchRefresh
|
||||||
}
|
}
|
||||||
|
|
||||||
export default accountsAPI
|
export default accountsAPI
|
||||||
|
|||||||
@@ -308,6 +308,49 @@ export async function updateRectifierSettings(
|
|||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ==================== Beta Policy Settings ====================
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Beta policy rule interface
|
||||||
|
*/
|
||||||
|
export interface BetaPolicyRule {
|
||||||
|
beta_token: string
|
||||||
|
action: 'pass' | 'filter' | 'block'
|
||||||
|
scope: 'all' | 'oauth' | 'apikey'
|
||||||
|
error_message?: string
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Beta policy settings interface
|
||||||
|
*/
|
||||||
|
export interface BetaPolicySettings {
|
||||||
|
rules: BetaPolicyRule[]
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get beta policy settings
|
||||||
|
* @returns Beta policy settings
|
||||||
|
*/
|
||||||
|
export async function getBetaPolicySettings(): Promise<BetaPolicySettings> {
|
||||||
|
const { data } = await apiClient.get<BetaPolicySettings>('/admin/settings/beta-policy')
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Update beta policy settings
|
||||||
|
* @param settings - Beta policy settings to update
|
||||||
|
* @returns Updated settings
|
||||||
|
*/
|
||||||
|
export async function updateBetaPolicySettings(
|
||||||
|
settings: BetaPolicySettings
|
||||||
|
): Promise<BetaPolicySettings> {
|
||||||
|
const { data } = await apiClient.put<BetaPolicySettings>(
|
||||||
|
'/admin/settings/beta-policy',
|
||||||
|
settings
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
// ==================== Sora S3 Settings ====================
|
// ==================== Sora S3 Settings ====================
|
||||||
|
|
||||||
export interface SoraS3Settings {
|
export interface SoraS3Settings {
|
||||||
@@ -456,6 +499,8 @@ export const settingsAPI = {
|
|||||||
updateStreamTimeoutSettings,
|
updateStreamTimeoutSettings,
|
||||||
getRectifierSettings,
|
getRectifierSettings,
|
||||||
updateRectifierSettings,
|
updateRectifierSettings,
|
||||||
|
getBetaPolicySettings,
|
||||||
|
updateBetaPolicySettings,
|
||||||
getSoraS3Settings,
|
getSoraS3Settings,
|
||||||
updateSoraS3Settings,
|
updateSoraS3Settings,
|
||||||
testSoraS3Connection,
|
testSoraS3Connection,
|
||||||
|
|||||||
@@ -335,6 +335,28 @@ export async function resetPassword(request: ResetPasswordRequest): Promise<Rese
|
|||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Complete LinuxDo OAuth registration by supplying an invitation code
|
||||||
|
* @param pendingOAuthToken - Short-lived JWT from the OAuth callback
|
||||||
|
* @param invitationCode - Invitation code entered by the user
|
||||||
|
* @returns Token pair on success
|
||||||
|
*/
|
||||||
|
export async function completeLinuxDoOAuthRegistration(
|
||||||
|
pendingOAuthToken: string,
|
||||||
|
invitationCode: string
|
||||||
|
): Promise<{ access_token: string; refresh_token: string; expires_in: number; token_type: string }> {
|
||||||
|
const { data } = await apiClient.post<{
|
||||||
|
access_token: string
|
||||||
|
refresh_token: string
|
||||||
|
expires_in: number
|
||||||
|
token_type: string
|
||||||
|
}>('/auth/oauth/linuxdo/complete-registration', {
|
||||||
|
pending_oauth_token: pendingOAuthToken,
|
||||||
|
invitation_code: invitationCode
|
||||||
|
})
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
export const authAPI = {
|
export const authAPI = {
|
||||||
login,
|
login,
|
||||||
login2FA,
|
login2FA,
|
||||||
@@ -357,7 +379,8 @@ export const authAPI = {
|
|||||||
forgotPassword,
|
forgotPassword,
|
||||||
resetPassword,
|
resetPassword,
|
||||||
refreshToken,
|
refreshToken,
|
||||||
revokeAllSessions
|
revokeAllSessions,
|
||||||
|
completeLinuxDoOAuthRegistration
|
||||||
}
|
}
|
||||||
|
|
||||||
export default authAPI
|
export default authAPI
|
||||||
|
|||||||
@@ -73,11 +73,9 @@
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- API Key 账号配额限制 -->
|
<!-- API Key 账号配额限制 -->
|
||||||
<div v-if="showDailyQuota || showWeeklyQuota || showTotalQuota" class="flex items-center gap-1">
|
<QuotaBadge v-if="showDailyQuota" :used="account.quota_daily_used ?? 0" :limit="account.quota_daily_limit!" label="D" />
|
||||||
<QuotaBadge v-if="showDailyQuota" :used="account.quota_daily_used ?? 0" :limit="account.quota_daily_limit!" label="D" />
|
<QuotaBadge v-if="showWeeklyQuota" :used="account.quota_weekly_used ?? 0" :limit="account.quota_weekly_limit!" label="W" />
|
||||||
<QuotaBadge v-if="showWeeklyQuota" :used="account.quota_weekly_used ?? 0" :limit="account.quota_weekly_limit!" label="W" />
|
<QuotaBadge v-if="showTotalQuota" :used="account.quota_used ?? 0" :limit="account.quota_limit!" />
|
||||||
<QuotaBadge v-if="showTotalQuota" :used="account.quota_used ?? 0" :limit="account.quota_limit!" />
|
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,8 @@
|
|||||||
</div>
|
</div>
|
||||||
<div class="flex gap-2">
|
<div class="flex gap-2">
|
||||||
<button @click="$emit('delete')" class="btn btn-danger btn-sm">{{ t('admin.accounts.bulkActions.delete') }}</button>
|
<button @click="$emit('delete')" class="btn btn-danger btn-sm">{{ t('admin.accounts.bulkActions.delete') }}</button>
|
||||||
|
<button @click="$emit('reset-status')" class="btn btn-secondary btn-sm">{{ t('admin.accounts.bulkActions.resetStatus') }}</button>
|
||||||
|
<button @click="$emit('refresh-token')" class="btn btn-secondary btn-sm">{{ t('admin.accounts.bulkActions.refreshToken') }}</button>
|
||||||
<button @click="$emit('toggle-schedulable', true)" class="btn btn-success btn-sm">{{ t('admin.accounts.bulkActions.enableScheduling') }}</button>
|
<button @click="$emit('toggle-schedulable', true)" class="btn btn-success btn-sm">{{ t('admin.accounts.bulkActions.enableScheduling') }}</button>
|
||||||
<button @click="$emit('toggle-schedulable', false)" class="btn btn-warning btn-sm">{{ t('admin.accounts.bulkActions.disableScheduling') }}</button>
|
<button @click="$emit('toggle-schedulable', false)" class="btn btn-warning btn-sm">{{ t('admin.accounts.bulkActions.disableScheduling') }}</button>
|
||||||
<button @click="$emit('edit')" class="btn btn-primary btn-sm">{{ t('admin.accounts.bulkActions.edit') }}</button>
|
<button @click="$emit('edit')" class="btn btn-primary btn-sm">{{ t('admin.accounts.bulkActions.edit') }}</button>
|
||||||
@@ -29,5 +31,5 @@
|
|||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { useI18n } from 'vue-i18n'
|
import { useI18n } from 'vue-i18n'
|
||||||
defineProps(['selectedIds']); defineEmits(['delete', 'edit', 'clear', 'select-page', 'toggle-schedulable']); const { t } = useI18n()
|
defineProps(['selectedIds']); defineEmits(['delete', 'edit', 'clear', 'select-page', 'toggle-schedulable', 'reset-status', 'refresh-token']); const { t } = useI18n()
|
||||||
</script>
|
</script>
|
||||||
@@ -162,8 +162,7 @@ const load = async () => {
|
|||||||
const loadGroups = async () => {
|
const loadGroups = async () => {
|
||||||
try {
|
try {
|
||||||
const groups = await adminAPI.groups.getAll()
|
const groups = await adminAPI.groups.getAll()
|
||||||
// 过滤掉订阅类型分组(需通过订阅管理流程绑定)
|
allGroups.value = groups
|
||||||
allGroups.value = groups.filter((g) => g.subscription_type !== 'subscription')
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Failed to load groups:', error)
|
console.error('Failed to load groups:', error)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,18 +1,40 @@
|
|||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref } from 'vue'
|
import { ref, useTemplateRef, nextTick } from 'vue'
|
||||||
|
|
||||||
defineProps<{
|
defineProps<{
|
||||||
content?: string
|
content?: string
|
||||||
}>()
|
}>()
|
||||||
|
|
||||||
const show = ref(false)
|
const show = ref(false)
|
||||||
|
const triggerRef = useTemplateRef<HTMLElement>('trigger')
|
||||||
|
const tooltipStyle = ref({ top: '0px', left: '0px' })
|
||||||
|
|
||||||
|
function onEnter() {
|
||||||
|
show.value = true
|
||||||
|
nextTick(updatePosition)
|
||||||
|
}
|
||||||
|
|
||||||
|
function onLeave() {
|
||||||
|
show.value = false
|
||||||
|
}
|
||||||
|
|
||||||
|
function updatePosition() {
|
||||||
|
const el = triggerRef.value
|
||||||
|
if (!el) return
|
||||||
|
const rect = el.getBoundingClientRect()
|
||||||
|
tooltipStyle.value = {
|
||||||
|
top: `${rect.top + window.scrollY}px`,
|
||||||
|
left: `${rect.left + rect.width / 2 + window.scrollX}px`,
|
||||||
|
}
|
||||||
|
}
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<template>
|
<template>
|
||||||
<div
|
<div
|
||||||
|
ref="trigger"
|
||||||
class="group relative ml-1 inline-flex items-center align-middle"
|
class="group relative ml-1 inline-flex items-center align-middle"
|
||||||
@mouseenter="show = true"
|
@mouseenter="onEnter"
|
||||||
@mouseleave="show = false"
|
@mouseleave="onLeave"
|
||||||
>
|
>
|
||||||
<!-- Trigger Icon -->
|
<!-- Trigger Icon -->
|
||||||
<slot name="trigger">
|
<slot name="trigger">
|
||||||
@@ -31,14 +53,16 @@ const show = ref(false)
|
|||||||
</svg>
|
</svg>
|
||||||
</slot>
|
</slot>
|
||||||
|
|
||||||
<!-- Popover Content -->
|
<!-- Teleport to body to escape modal overflow clipping -->
|
||||||
<div
|
<Teleport to="body">
|
||||||
v-show="show"
|
<div
|
||||||
class="absolute bottom-full left-1/2 z-50 mb-2 w-64 -translate-x-1/2 rounded-lg bg-gray-900 p-3 text-xs leading-relaxed text-white shadow-xl ring-1 ring-white/10 opacity-0 transition-opacity duration-200 group-hover:opacity-100 dark:bg-gray-800"
|
v-show="show"
|
||||||
>
|
class="fixed z-[99999] w-64 -translate-x-1/2 -translate-y-full rounded-lg bg-gray-900 p-3 text-xs leading-relaxed text-white shadow-xl ring-1 ring-white/10 dark:bg-gray-800"
|
||||||
<slot>{{ content }}</slot>
|
:style="{ top: `calc(${tooltipStyle.top} - 8px)`, left: tooltipStyle.left }"
|
||||||
<div class="absolute -bottom-1 left-1/2 h-2 w-2 -translate-x-1/2 rotate-45 bg-gray-900 dark:bg-gray-800"></div>
|
>
|
||||||
</div>
|
<slot>{{ content }}</slot>
|
||||||
|
<div class="absolute -bottom-1 left-1/2 h-2 w-2 -translate-x-1/2 rotate-45 bg-gray-900 dark:bg-gray-800"></div>
|
||||||
|
</div>
|
||||||
|
</Teleport>
|
||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
|
|||||||
@@ -28,6 +28,10 @@
|
|||||||
<Icon v-else name="key" size="xs" />
|
<Icon v-else name="key" size="xs" />
|
||||||
<span>{{ typeLabel }}</span>
|
<span>{{ typeLabel }}</span>
|
||||||
</span>
|
</span>
|
||||||
|
<!-- Plan type part (optional) -->
|
||||||
|
<span v-if="planLabel" :class="['inline-flex items-center gap-1 px-1.5 py-1 border-l border-white/20', typeClass]">
|
||||||
|
<span>{{ planLabel }}</span>
|
||||||
|
</span>
|
||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
@@ -40,6 +44,7 @@ import Icon from '@/components/icons/Icon.vue'
|
|||||||
interface Props {
|
interface Props {
|
||||||
platform: AccountPlatform
|
platform: AccountPlatform
|
||||||
type: AccountType
|
type: AccountType
|
||||||
|
planType?: string
|
||||||
}
|
}
|
||||||
|
|
||||||
const props = defineProps<Props>()
|
const props = defineProps<Props>()
|
||||||
@@ -65,6 +70,24 @@ const typeLabel = computed(() => {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
const planLabel = computed(() => {
|
||||||
|
if (!props.planType) return ''
|
||||||
|
const lower = props.planType.toLowerCase()
|
||||||
|
switch (lower) {
|
||||||
|
case 'plus':
|
||||||
|
return 'Plus'
|
||||||
|
case 'team':
|
||||||
|
return 'Team'
|
||||||
|
case 'chatgptpro':
|
||||||
|
case 'pro':
|
||||||
|
return 'Pro'
|
||||||
|
case 'free':
|
||||||
|
return 'Free'
|
||||||
|
default:
|
||||||
|
return props.planType
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
const platformClass = computed(() => {
|
const platformClass = computed(() => {
|
||||||
if (props.platform === 'anthropic') {
|
if (props.platform === 'anthropic') {
|
||||||
return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400'
|
return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400'
|
||||||
|
|||||||
@@ -434,7 +434,12 @@ export default {
|
|||||||
callbackProcessing: 'Completing login, please wait...',
|
callbackProcessing: 'Completing login, please wait...',
|
||||||
callbackHint: 'If you are not redirected automatically, go back to the login page and try again.',
|
callbackHint: 'If you are not redirected automatically, go back to the login page and try again.',
|
||||||
callbackMissingToken: 'Missing login token, please try again.',
|
callbackMissingToken: 'Missing login token, please try again.',
|
||||||
backToLogin: 'Back to Login'
|
backToLogin: 'Back to Login',
|
||||||
|
invitationRequired: 'This Linux.do account is not yet registered. The site requires an invitation code — please enter one to complete registration.',
|
||||||
|
invalidPendingToken: 'The registration token has expired. Please sign in with Linux.do again.',
|
||||||
|
completeRegistration: 'Complete Registration',
|
||||||
|
completing: 'Completing registration…',
|
||||||
|
completeRegistrationFailed: 'Registration failed. Please check your invitation code and try again.'
|
||||||
},
|
},
|
||||||
oauth: {
|
oauth: {
|
||||||
code: 'Code',
|
code: 'Code',
|
||||||
@@ -1836,7 +1841,12 @@ export default {
|
|||||||
edit: 'Bulk Edit',
|
edit: 'Bulk Edit',
|
||||||
delete: 'Bulk Delete',
|
delete: 'Bulk Delete',
|
||||||
enableScheduling: 'Enable Scheduling',
|
enableScheduling: 'Enable Scheduling',
|
||||||
disableScheduling: 'Disable Scheduling'
|
disableScheduling: 'Disable Scheduling',
|
||||||
|
resetStatus: 'Reset Status',
|
||||||
|
refreshToken: 'Refresh Token',
|
||||||
|
resetStatusSuccess: 'Successfully reset {count} account(s) status',
|
||||||
|
refreshTokenSuccess: 'Successfully refreshed {count} account(s) token',
|
||||||
|
partialSuccess: 'Partially completed: {success} succeeded, {failed} failed'
|
||||||
},
|
},
|
||||||
bulkEdit: {
|
bulkEdit: {
|
||||||
title: 'Bulk Edit Accounts',
|
title: 'Bulk Edit Accounts',
|
||||||
@@ -4033,6 +4043,23 @@ export default {
|
|||||||
saved: 'Rectifier settings saved',
|
saved: 'Rectifier settings saved',
|
||||||
saveFailed: 'Failed to save rectifier settings'
|
saveFailed: 'Failed to save rectifier settings'
|
||||||
},
|
},
|
||||||
|
betaPolicy: {
|
||||||
|
title: 'Beta Policy',
|
||||||
|
description: 'How to handle Beta features when configuring the forwarding of Anthropic API requests. Applicable only to the /v1/messages endpoint.',
|
||||||
|
action: 'Action',
|
||||||
|
actionPass: 'Pass (transparent)',
|
||||||
|
actionFilter: 'Filter (remove)',
|
||||||
|
actionBlock: 'Block (reject)',
|
||||||
|
scope: 'Scope',
|
||||||
|
scopeAll: 'All accounts',
|
||||||
|
scopeOAuth: 'OAuth only',
|
||||||
|
scopeAPIKey: 'API Key only',
|
||||||
|
errorMessage: 'Error message',
|
||||||
|
errorMessagePlaceholder: 'Custom error message when blocked',
|
||||||
|
errorMessageHint: 'Leave empty for default message',
|
||||||
|
saved: 'Beta policy settings saved',
|
||||||
|
saveFailed: 'Failed to save beta policy settings'
|
||||||
|
},
|
||||||
saveSettings: 'Save Settings',
|
saveSettings: 'Save Settings',
|
||||||
saving: 'Saving...',
|
saving: 'Saving...',
|
||||||
settingsSaved: 'Settings saved successfully',
|
settingsSaved: 'Settings saved successfully',
|
||||||
|
|||||||
@@ -433,7 +433,12 @@ export default {
|
|||||||
callbackProcessing: '正在验证登录信息,请稍候...',
|
callbackProcessing: '正在验证登录信息,请稍候...',
|
||||||
callbackHint: '如果页面未自动跳转,请返回登录页重试。',
|
callbackHint: '如果页面未自动跳转,请返回登录页重试。',
|
||||||
callbackMissingToken: '登录信息缺失,请返回重试。',
|
callbackMissingToken: '登录信息缺失,请返回重试。',
|
||||||
backToLogin: '返回登录'
|
backToLogin: '返回登录',
|
||||||
|
invitationRequired: '该 Linux.do 账号尚未注册,站点已开启邀请码注册,请输入邀请码以完成注册。',
|
||||||
|
invalidPendingToken: '注册凭证已失效,请重新使用 Linux.do 登录。',
|
||||||
|
completeRegistration: '完成注册',
|
||||||
|
completing: '正在完成注册...',
|
||||||
|
completeRegistrationFailed: '注册失败,请检查邀请码后重试。'
|
||||||
},
|
},
|
||||||
oauth: {
|
oauth: {
|
||||||
code: '授权码',
|
code: '授权码',
|
||||||
@@ -1983,7 +1988,12 @@ export default {
|
|||||||
edit: '批量编辑账号',
|
edit: '批量编辑账号',
|
||||||
delete: '批量删除',
|
delete: '批量删除',
|
||||||
enableScheduling: '批量启用调度',
|
enableScheduling: '批量启用调度',
|
||||||
disableScheduling: '批量停止调度'
|
disableScheduling: '批量停止调度',
|
||||||
|
resetStatus: '批量重置状态',
|
||||||
|
refreshToken: '批量刷新令牌',
|
||||||
|
resetStatusSuccess: '已成功重置 {count} 个账号状态',
|
||||||
|
refreshTokenSuccess: '已成功刷新 {count} 个账号令牌',
|
||||||
|
partialSuccess: '操作部分完成:{success} 成功,{failed} 失败'
|
||||||
},
|
},
|
||||||
bulkEdit: {
|
bulkEdit: {
|
||||||
title: '批量编辑账号',
|
title: '批量编辑账号',
|
||||||
@@ -4206,6 +4216,23 @@ export default {
|
|||||||
saved: '整流器设置保存成功',
|
saved: '整流器设置保存成功',
|
||||||
saveFailed: '保存整流器设置失败'
|
saveFailed: '保存整流器设置失败'
|
||||||
},
|
},
|
||||||
|
betaPolicy: {
|
||||||
|
title: 'Beta 策略',
|
||||||
|
description: '配置转发 Anthropic API 请求时如何处理 Beta 特性。仅适用于 /v1/messages 接口。',
|
||||||
|
action: '处理方式',
|
||||||
|
actionPass: '透传(不处理)',
|
||||||
|
actionFilter: '过滤(移除)',
|
||||||
|
actionBlock: '拦截(拒绝请求)',
|
||||||
|
scope: '生效范围',
|
||||||
|
scopeAll: '全部账号',
|
||||||
|
scopeOAuth: '仅 OAuth 账号',
|
||||||
|
scopeAPIKey: '仅 API Key 账号',
|
||||||
|
errorMessage: '错误消息',
|
||||||
|
errorMessagePlaceholder: '拦截时返回的自定义错误消息',
|
||||||
|
errorMessageHint: '留空则使用默认错误消息',
|
||||||
|
saved: 'Beta 策略设置保存成功',
|
||||||
|
saveFailed: '保存 Beta 策略设置失败'
|
||||||
|
},
|
||||||
saveSettings: '保存设置',
|
saveSettings: '保存设置',
|
||||||
saving: '保存中...',
|
saving: '保存中...',
|
||||||
settingsSaved: '设置保存成功',
|
settingsSaved: '设置保存成功',
|
||||||
|
|||||||
67
frontend/src/utils/__tests__/embedded-url.spec.ts
Normal file
67
frontend/src/utils/__tests__/embedded-url.spec.ts
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||||
|
import { buildEmbeddedUrl, detectTheme } from '../embedded-url'
|
||||||
|
|
||||||
|
describe('embedded-url', () => {
|
||||||
|
const originalLocation = window.location
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
Object.defineProperty(window, 'location', {
|
||||||
|
value: {
|
||||||
|
origin: 'https://app.example.com',
|
||||||
|
href: 'https://app.example.com/user/purchase',
|
||||||
|
},
|
||||||
|
writable: true,
|
||||||
|
configurable: true,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
Object.defineProperty(window, 'location', {
|
||||||
|
value: originalLocation,
|
||||||
|
writable: true,
|
||||||
|
configurable: true,
|
||||||
|
})
|
||||||
|
document.documentElement.classList.remove('dark')
|
||||||
|
vi.restoreAllMocks()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('adds embedded query parameters including locale and source context', () => {
|
||||||
|
const result = buildEmbeddedUrl(
|
||||||
|
'https://pay.example.com/checkout?plan=pro',
|
||||||
|
42,
|
||||||
|
'token-123',
|
||||||
|
'dark',
|
||||||
|
'zh-CN',
|
||||||
|
)
|
||||||
|
|
||||||
|
const url = new URL(result)
|
||||||
|
expect(url.searchParams.get('plan')).toBe('pro')
|
||||||
|
expect(url.searchParams.get('user_id')).toBe('42')
|
||||||
|
expect(url.searchParams.get('token')).toBe('token-123')
|
||||||
|
expect(url.searchParams.get('theme')).toBe('dark')
|
||||||
|
expect(url.searchParams.get('lang')).toBe('zh-CN')
|
||||||
|
expect(url.searchParams.get('ui_mode')).toBe('embedded')
|
||||||
|
expect(url.searchParams.get('src_host')).toBe('https://app.example.com')
|
||||||
|
expect(url.searchParams.get('src_url')).toBe('https://app.example.com/user/purchase')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('omits optional params when they are empty', () => {
|
||||||
|
const result = buildEmbeddedUrl('https://pay.example.com/checkout', undefined, '', 'light')
|
||||||
|
|
||||||
|
const url = new URL(result)
|
||||||
|
expect(url.searchParams.get('theme')).toBe('light')
|
||||||
|
expect(url.searchParams.get('ui_mode')).toBe('embedded')
|
||||||
|
expect(url.searchParams.has('user_id')).toBe(false)
|
||||||
|
expect(url.searchParams.has('token')).toBe(false)
|
||||||
|
expect(url.searchParams.has('lang')).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('returns original string for invalid url input', () => {
|
||||||
|
expect(buildEmbeddedUrl('not a url', 1, 'token')).toBe('not a url')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('detects dark mode from document root class', () => {
|
||||||
|
document.documentElement.classList.add('dark')
|
||||||
|
expect(detectTheme()).toBe('dark')
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -1,12 +1,13 @@
|
|||||||
/**
|
/**
|
||||||
* Shared URL builder for iframe-embedded pages.
|
* Shared URL builder for iframe-embedded pages.
|
||||||
* Used by PurchaseSubscriptionView and CustomPageView to build consistent URLs
|
* Used by PurchaseSubscriptionView and CustomPageView to build consistent URLs
|
||||||
* with user_id, token, theme, ui_mode, src_host, and src parameters.
|
* with user_id, token, theme, lang, ui_mode, src_host, and src parameters.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
const EMBEDDED_USER_ID_QUERY_KEY = 'user_id'
|
const EMBEDDED_USER_ID_QUERY_KEY = 'user_id'
|
||||||
const EMBEDDED_AUTH_TOKEN_QUERY_KEY = 'token'
|
const EMBEDDED_AUTH_TOKEN_QUERY_KEY = 'token'
|
||||||
const EMBEDDED_THEME_QUERY_KEY = 'theme'
|
const EMBEDDED_THEME_QUERY_KEY = 'theme'
|
||||||
|
const EMBEDDED_LANG_QUERY_KEY = 'lang'
|
||||||
const EMBEDDED_UI_MODE_QUERY_KEY = 'ui_mode'
|
const EMBEDDED_UI_MODE_QUERY_KEY = 'ui_mode'
|
||||||
const EMBEDDED_UI_MODE_VALUE = 'embedded'
|
const EMBEDDED_UI_MODE_VALUE = 'embedded'
|
||||||
const EMBEDDED_SRC_HOST_QUERY_KEY = 'src_host'
|
const EMBEDDED_SRC_HOST_QUERY_KEY = 'src_host'
|
||||||
@@ -17,6 +18,7 @@ export function buildEmbeddedUrl(
|
|||||||
userId?: number,
|
userId?: number,
|
||||||
authToken?: string | null,
|
authToken?: string | null,
|
||||||
theme: 'light' | 'dark' = 'light',
|
theme: 'light' | 'dark' = 'light',
|
||||||
|
lang?: string,
|
||||||
): string {
|
): string {
|
||||||
if (!baseUrl) return baseUrl
|
if (!baseUrl) return baseUrl
|
||||||
try {
|
try {
|
||||||
@@ -28,6 +30,9 @@ export function buildEmbeddedUrl(
|
|||||||
url.searchParams.set(EMBEDDED_AUTH_TOKEN_QUERY_KEY, authToken)
|
url.searchParams.set(EMBEDDED_AUTH_TOKEN_QUERY_KEY, authToken)
|
||||||
}
|
}
|
||||||
url.searchParams.set(EMBEDDED_THEME_QUERY_KEY, theme)
|
url.searchParams.set(EMBEDDED_THEME_QUERY_KEY, theme)
|
||||||
|
if (lang) {
|
||||||
|
url.searchParams.set(EMBEDDED_LANG_QUERY_KEY, lang)
|
||||||
|
}
|
||||||
url.searchParams.set(EMBEDDED_UI_MODE_QUERY_KEY, EMBEDDED_UI_MODE_VALUE)
|
url.searchParams.set(EMBEDDED_UI_MODE_QUERY_KEY, EMBEDDED_UI_MODE_VALUE)
|
||||||
// Source tracking: let the embedded page know where it's being loaded from
|
// Source tracking: let the embedded page know where it's being loaded from
|
||||||
if (typeof window !== 'undefined') {
|
if (typeof window !== 'undefined') {
|
||||||
|
|||||||
@@ -131,7 +131,7 @@
|
|||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
<template #table>
|
<template #table>
|
||||||
<AccountBulkActionsBar :selected-ids="selIds" @delete="handleBulkDelete" @edit="showBulkEdit = true" @clear="clearSelection" @select-page="selectPage" @toggle-schedulable="handleBulkToggleSchedulable" />
|
<AccountBulkActionsBar :selected-ids="selIds" @delete="handleBulkDelete" @reset-status="handleBulkResetStatus" @refresh-token="handleBulkRefreshToken" @edit="showBulkEdit = true" @clear="clearSelection" @select-page="selectPage" @toggle-schedulable="handleBulkToggleSchedulable" />
|
||||||
<div ref="accountTableRef" class="flex min-h-0 flex-1 flex-col overflow-hidden">
|
<div ref="accountTableRef" class="flex min-h-0 flex-1 flex-col overflow-hidden">
|
||||||
<DataTable
|
<DataTable
|
||||||
:columns="cols"
|
:columns="cols"
|
||||||
@@ -171,7 +171,7 @@
|
|||||||
<span v-else class="text-sm text-gray-400 dark:text-dark-500">-</span>
|
<span v-else class="text-sm text-gray-400 dark:text-dark-500">-</span>
|
||||||
</template>
|
</template>
|
||||||
<template #cell-platform_type="{ row }">
|
<template #cell-platform_type="{ row }">
|
||||||
<PlatformTypeBadge :platform="row.platform" :type="row.type" />
|
<PlatformTypeBadge :platform="row.platform" :type="row.type" :plan-type="row.credentials?.plan_type" />
|
||||||
</template>
|
</template>
|
||||||
<template #cell-capacity="{ row }">
|
<template #cell-capacity="{ row }">
|
||||||
<AccountCapacityCell :account="row" />
|
<AccountCapacityCell :account="row" />
|
||||||
@@ -889,6 +889,38 @@ const toggleSelectAllVisible = (event: Event) => {
|
|||||||
toggleVisible(target.checked)
|
toggleVisible(target.checked)
|
||||||
}
|
}
|
||||||
const handleBulkDelete = async () => { if(!confirm(t('common.confirm'))) return; try { await Promise.all(selIds.value.map(id => adminAPI.accounts.delete(id))); clearSelection(); reload() } catch (error) { console.error('Failed to bulk delete accounts:', error) } }
|
const handleBulkDelete = async () => { if(!confirm(t('common.confirm'))) return; try { await Promise.all(selIds.value.map(id => adminAPI.accounts.delete(id))); clearSelection(); reload() } catch (error) { console.error('Failed to bulk delete accounts:', error) } }
|
||||||
|
const handleBulkResetStatus = async () => {
|
||||||
|
if (!confirm(t('common.confirm'))) return
|
||||||
|
try {
|
||||||
|
const result = await adminAPI.accounts.batchClearError(selIds.value)
|
||||||
|
if (result.failed > 0) {
|
||||||
|
appStore.showError(t('admin.accounts.bulkActions.partialSuccess', { success: result.success, failed: result.failed }))
|
||||||
|
} else {
|
||||||
|
appStore.showSuccess(t('admin.accounts.bulkActions.resetStatusSuccess', { count: result.success }))
|
||||||
|
clearSelection()
|
||||||
|
}
|
||||||
|
reload()
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Failed to bulk reset status:', error)
|
||||||
|
appStore.showError(String(error))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const handleBulkRefreshToken = async () => {
|
||||||
|
if (!confirm(t('common.confirm'))) return
|
||||||
|
try {
|
||||||
|
const result = await adminAPI.accounts.batchRefresh(selIds.value)
|
||||||
|
if (result.failed > 0) {
|
||||||
|
appStore.showError(t('admin.accounts.bulkActions.partialSuccess', { success: result.success, failed: result.failed }))
|
||||||
|
} else {
|
||||||
|
appStore.showSuccess(t('admin.accounts.bulkActions.refreshTokenSuccess', { count: result.success }))
|
||||||
|
clearSelection()
|
||||||
|
}
|
||||||
|
reload()
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Failed to bulk refresh token:', error)
|
||||||
|
appStore.showError(String(error))
|
||||||
|
}
|
||||||
|
}
|
||||||
const updateSchedulableInList = (accountIds: number[], schedulable: boolean) => {
|
const updateSchedulableInList = (accountIds: number[], schedulable: boolean) => {
|
||||||
if (accountIds.length === 0) return
|
if (accountIds.length === 0) return
|
||||||
const idSet = new Set(accountIds)
|
const idSet = new Set(accountIds)
|
||||||
|
|||||||
@@ -405,6 +405,117 @@
|
|||||||
</template>
|
</template>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
<!-- Beta Policy Settings -->
|
||||||
|
<div class="card">
|
||||||
|
<div class="border-b border-gray-100 px-6 py-4 dark:border-dark-700">
|
||||||
|
<h2 class="text-lg font-semibold text-gray-900 dark:text-white">
|
||||||
|
{{ t('admin.settings.betaPolicy.title') }}
|
||||||
|
</h2>
|
||||||
|
<p class="mt-1 text-sm text-gray-500 dark:text-gray-400">
|
||||||
|
{{ t('admin.settings.betaPolicy.description') }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<div class="space-y-5 p-6">
|
||||||
|
<!-- Loading State -->
|
||||||
|
<div v-if="betaPolicyLoading" class="flex items-center gap-2 text-gray-500">
|
||||||
|
<div class="h-4 w-4 animate-spin rounded-full border-b-2 border-primary-600"></div>
|
||||||
|
{{ t('common.loading') }}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<template v-else>
|
||||||
|
<!-- Rule Cards -->
|
||||||
|
<div
|
||||||
|
v-for="rule in betaPolicyForm.rules"
|
||||||
|
:key="rule.beta_token"
|
||||||
|
class="rounded-lg border border-gray-200 p-4 dark:border-dark-600"
|
||||||
|
>
|
||||||
|
<div class="mb-3 flex items-center gap-2">
|
||||||
|
<span class="text-sm font-medium text-gray-900 dark:text-white">
|
||||||
|
{{ getBetaDisplayName(rule.beta_token) }}
|
||||||
|
</span>
|
||||||
|
<span class="rounded bg-gray-100 px-2 py-0.5 text-xs text-gray-500 dark:bg-dark-700 dark:text-gray-400">
|
||||||
|
{{ rule.beta_token }}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="grid grid-cols-2 gap-4">
|
||||||
|
<!-- Action -->
|
||||||
|
<div>
|
||||||
|
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
|
||||||
|
{{ t('admin.settings.betaPolicy.action') }}
|
||||||
|
</label>
|
||||||
|
<Select
|
||||||
|
:modelValue="rule.action"
|
||||||
|
@update:modelValue="rule.action = $event as any"
|
||||||
|
:options="betaPolicyActionOptions"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Scope -->
|
||||||
|
<div>
|
||||||
|
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
|
||||||
|
{{ t('admin.settings.betaPolicy.scope') }}
|
||||||
|
</label>
|
||||||
|
<Select
|
||||||
|
:modelValue="rule.scope"
|
||||||
|
@update:modelValue="rule.scope = $event as any"
|
||||||
|
:options="betaPolicyScopeOptions"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Error Message (only when action=block) -->
|
||||||
|
<div v-if="rule.action === 'block'" class="mt-3">
|
||||||
|
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
|
||||||
|
{{ t('admin.settings.betaPolicy.errorMessage') }}
|
||||||
|
</label>
|
||||||
|
<input
|
||||||
|
v-model="rule.error_message"
|
||||||
|
type="text"
|
||||||
|
class="input"
|
||||||
|
:placeholder="t('admin.settings.betaPolicy.errorMessagePlaceholder')"
|
||||||
|
/>
|
||||||
|
<p class="mt-1 text-xs text-gray-400 dark:text-gray-500">
|
||||||
|
{{ t('admin.settings.betaPolicy.errorMessageHint') }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Save Button -->
|
||||||
|
<div class="flex justify-end border-t border-gray-100 pt-4 dark:border-dark-700">
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
@click="saveBetaPolicySettings"
|
||||||
|
:disabled="betaPolicySaving"
|
||||||
|
class="btn btn-primary btn-sm"
|
||||||
|
>
|
||||||
|
<svg
|
||||||
|
v-if="betaPolicySaving"
|
||||||
|
class="mr-1 h-4 w-4 animate-spin"
|
||||||
|
fill="none"
|
||||||
|
viewBox="0 0 24 24"
|
||||||
|
>
|
||||||
|
<circle
|
||||||
|
class="opacity-25"
|
||||||
|
cx="12"
|
||||||
|
cy="12"
|
||||||
|
r="10"
|
||||||
|
stroke="currentColor"
|
||||||
|
stroke-width="4"
|
||||||
|
></circle>
|
||||||
|
<path
|
||||||
|
class="opacity-75"
|
||||||
|
fill="currentColor"
|
||||||
|
d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"
|
||||||
|
></path>
|
||||||
|
</svg>
|
||||||
|
{{ betaPolicySaving ? t('common.saving') : t('common.save') }}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
</div><!-- /Tab: Gateway -->
|
</div><!-- /Tab: Gateway -->
|
||||||
|
|
||||||
<!-- Tab: Security — Registration, Turnstile, LinuxDo -->
|
<!-- Tab: Security — Registration, Turnstile, LinuxDo -->
|
||||||
@@ -1627,6 +1738,18 @@ const rectifierForm = reactive({
|
|||||||
thinking_budget_enabled: true
|
thinking_budget_enabled: true
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Beta Policy 状态
|
||||||
|
const betaPolicyLoading = ref(true)
|
||||||
|
const betaPolicySaving = ref(false)
|
||||||
|
const betaPolicyForm = reactive({
|
||||||
|
rules: [] as Array<{
|
||||||
|
beta_token: string
|
||||||
|
action: 'pass' | 'filter' | 'block'
|
||||||
|
scope: 'all' | 'oauth' | 'apikey'
|
||||||
|
error_message?: string
|
||||||
|
}>
|
||||||
|
})
|
||||||
|
|
||||||
interface DefaultSubscriptionGroupOption {
|
interface DefaultSubscriptionGroupOption {
|
||||||
value: number
|
value: number
|
||||||
label: string
|
label: string
|
||||||
@@ -2165,12 +2288,64 @@ async function saveRectifierSettings() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const betaPolicyActionOptions = computed(() => [
|
||||||
|
{ value: 'pass', label: t('admin.settings.betaPolicy.actionPass') },
|
||||||
|
{ value: 'filter', label: t('admin.settings.betaPolicy.actionFilter') },
|
||||||
|
{ value: 'block', label: t('admin.settings.betaPolicy.actionBlock') }
|
||||||
|
])
|
||||||
|
|
||||||
|
const betaPolicyScopeOptions = computed(() => [
|
||||||
|
{ value: 'all', label: t('admin.settings.betaPolicy.scopeAll') },
|
||||||
|
{ value: 'oauth', label: t('admin.settings.betaPolicy.scopeOAuth') },
|
||||||
|
{ value: 'apikey', label: t('admin.settings.betaPolicy.scopeAPIKey') }
|
||||||
|
])
|
||||||
|
|
||||||
|
// Beta Policy 方法
|
||||||
|
const betaDisplayNames: Record<string, string> = {
|
||||||
|
'fast-mode-2026-02-01': 'Fast Mode',
|
||||||
|
'context-1m-2025-08-07': 'Context 1M'
|
||||||
|
}
|
||||||
|
|
||||||
|
function getBetaDisplayName(token: string): string {
|
||||||
|
return betaDisplayNames[token] || token
|
||||||
|
}
|
||||||
|
|
||||||
|
async function loadBetaPolicySettings() {
|
||||||
|
betaPolicyLoading.value = true
|
||||||
|
try {
|
||||||
|
const settings = await adminAPI.settings.getBetaPolicySettings()
|
||||||
|
betaPolicyForm.rules = settings.rules
|
||||||
|
} catch (error: any) {
|
||||||
|
console.error('Failed to load beta policy settings:', error)
|
||||||
|
} finally {
|
||||||
|
betaPolicyLoading.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function saveBetaPolicySettings() {
|
||||||
|
betaPolicySaving.value = true
|
||||||
|
try {
|
||||||
|
const updated = await adminAPI.settings.updateBetaPolicySettings({
|
||||||
|
rules: betaPolicyForm.rules
|
||||||
|
})
|
||||||
|
betaPolicyForm.rules = updated.rules
|
||||||
|
appStore.showSuccess(t('admin.settings.betaPolicy.saved'))
|
||||||
|
} catch (error: any) {
|
||||||
|
appStore.showError(
|
||||||
|
t('admin.settings.betaPolicy.saveFailed') + ': ' + (error.message || t('common.unknownError'))
|
||||||
|
)
|
||||||
|
} finally {
|
||||||
|
betaPolicySaving.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
onMounted(() => {
|
onMounted(() => {
|
||||||
loadSettings()
|
loadSettings()
|
||||||
loadSubscriptionGroups()
|
loadSubscriptionGroups()
|
||||||
loadAdminApiKey()
|
loadAdminApiKey()
|
||||||
loadStreamTimeoutSettings()
|
loadStreamTimeoutSettings()
|
||||||
loadRectifierSettings()
|
loadRectifierSettings()
|
||||||
|
loadBetaPolicySettings()
|
||||||
})
|
})
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,36 @@
|
|||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<transition name="fade">
|
||||||
|
<div v-if="needsInvitation" class="space-y-4">
|
||||||
|
<p class="text-sm text-gray-700 dark:text-gray-300">
|
||||||
|
{{ t('auth.linuxdo.invitationRequired') }}
|
||||||
|
</p>
|
||||||
|
<div>
|
||||||
|
<input
|
||||||
|
v-model="invitationCode"
|
||||||
|
type="text"
|
||||||
|
class="input w-full"
|
||||||
|
:placeholder="t('auth.invitationCodePlaceholder')"
|
||||||
|
:disabled="isSubmitting"
|
||||||
|
@keyup.enter="handleSubmitInvitation"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<transition name="fade">
|
||||||
|
<p v-if="invitationError" class="text-sm text-red-600 dark:text-red-400">
|
||||||
|
{{ invitationError }}
|
||||||
|
</p>
|
||||||
|
</transition>
|
||||||
|
<button
|
||||||
|
class="btn btn-primary w-full"
|
||||||
|
:disabled="isSubmitting || !invitationCode.trim()"
|
||||||
|
@click="handleSubmitInvitation"
|
||||||
|
>
|
||||||
|
{{ isSubmitting ? t('auth.linuxdo.completing') : t('auth.linuxdo.completeRegistration') }}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</transition>
|
||||||
|
|
||||||
<transition name="fade">
|
<transition name="fade">
|
||||||
<div
|
<div
|
||||||
v-if="errorMessage"
|
v-if="errorMessage"
|
||||||
@@ -41,6 +71,7 @@ import { useI18n } from 'vue-i18n'
|
|||||||
import { AuthLayout } from '@/components/layout'
|
import { AuthLayout } from '@/components/layout'
|
||||||
import Icon from '@/components/icons/Icon.vue'
|
import Icon from '@/components/icons/Icon.vue'
|
||||||
import { useAuthStore, useAppStore } from '@/stores'
|
import { useAuthStore, useAppStore } from '@/stores'
|
||||||
|
import { completeLinuxDoOAuthRegistration } from '@/api/auth'
|
||||||
|
|
||||||
const route = useRoute()
|
const route = useRoute()
|
||||||
const router = useRouter()
|
const router = useRouter()
|
||||||
@@ -52,6 +83,14 @@ const appStore = useAppStore()
|
|||||||
const isProcessing = ref(true)
|
const isProcessing = ref(true)
|
||||||
const errorMessage = ref('')
|
const errorMessage = ref('')
|
||||||
|
|
||||||
|
// Invitation code flow state
|
||||||
|
const needsInvitation = ref(false)
|
||||||
|
const pendingOAuthToken = ref('')
|
||||||
|
const invitationCode = ref('')
|
||||||
|
const isSubmitting = ref(false)
|
||||||
|
const invitationError = ref('')
|
||||||
|
const redirectTo = ref('/dashboard')
|
||||||
|
|
||||||
function parseFragmentParams(): URLSearchParams {
|
function parseFragmentParams(): URLSearchParams {
|
||||||
const raw = typeof window !== 'undefined' ? window.location.hash : ''
|
const raw = typeof window !== 'undefined' ? window.location.hash : ''
|
||||||
const hash = raw.startsWith('#') ? raw.slice(1) : raw
|
const hash = raw.startsWith('#') ? raw.slice(1) : raw
|
||||||
@@ -67,6 +106,34 @@ function sanitizeRedirectPath(path: string | null | undefined): string {
|
|||||||
return path
|
return path
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async function handleSubmitInvitation() {
|
||||||
|
invitationError.value = ''
|
||||||
|
if (!invitationCode.value.trim()) return
|
||||||
|
|
||||||
|
isSubmitting.value = true
|
||||||
|
try {
|
||||||
|
const tokenData = await completeLinuxDoOAuthRegistration(
|
||||||
|
pendingOAuthToken.value,
|
||||||
|
invitationCode.value.trim()
|
||||||
|
)
|
||||||
|
if (tokenData.refresh_token) {
|
||||||
|
localStorage.setItem('refresh_token', tokenData.refresh_token)
|
||||||
|
}
|
||||||
|
if (tokenData.expires_in) {
|
||||||
|
localStorage.setItem('token_expires_at', String(Date.now() + tokenData.expires_in * 1000))
|
||||||
|
}
|
||||||
|
await authStore.setToken(tokenData.access_token)
|
||||||
|
appStore.showSuccess(t('auth.loginSuccess'))
|
||||||
|
await router.replace(redirectTo.value)
|
||||||
|
} catch (e: unknown) {
|
||||||
|
const err = e as { message?: string; response?: { data?: { message?: string } } }
|
||||||
|
invitationError.value =
|
||||||
|
err.response?.data?.message || err.message || t('auth.linuxdo.completeRegistrationFailed')
|
||||||
|
} finally {
|
||||||
|
isSubmitting.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
onMounted(async () => {
|
onMounted(async () => {
|
||||||
const params = parseFragmentParams()
|
const params = parseFragmentParams()
|
||||||
|
|
||||||
@@ -80,6 +147,19 @@ onMounted(async () => {
|
|||||||
const errorDesc = params.get('error_description') || params.get('error_message') || ''
|
const errorDesc = params.get('error_description') || params.get('error_message') || ''
|
||||||
|
|
||||||
if (error) {
|
if (error) {
|
||||||
|
if (error === 'invitation_required') {
|
||||||
|
pendingOAuthToken.value = params.get('pending_oauth_token') || ''
|
||||||
|
redirectTo.value = sanitizeRedirectPath(params.get('redirect'))
|
||||||
|
if (!pendingOAuthToken.value) {
|
||||||
|
errorMessage.value = t('auth.linuxdo.invalidPendingToken')
|
||||||
|
appStore.showError(errorMessage.value)
|
||||||
|
isProcessing.value = false
|
||||||
|
return
|
||||||
|
}
|
||||||
|
needsInvitation.value = true
|
||||||
|
isProcessing.value = false
|
||||||
|
return
|
||||||
|
}
|
||||||
errorMessage.value = errorDesc || error
|
errorMessage.value = errorDesc || error
|
||||||
appStore.showError(errorMessage.value)
|
appStore.showError(errorMessage.value)
|
||||||
isProcessing.value = false
|
isProcessing.value = false
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ import AppLayout from '@/components/layout/AppLayout.vue'
|
|||||||
import Icon from '@/components/icons/Icon.vue'
|
import Icon from '@/components/icons/Icon.vue'
|
||||||
import { buildEmbeddedUrl, detectTheme } from '@/utils/embedded-url'
|
import { buildEmbeddedUrl, detectTheme } from '@/utils/embedded-url'
|
||||||
|
|
||||||
const { t } = useI18n()
|
const { t, locale } = useI18n()
|
||||||
const route = useRoute()
|
const route = useRoute()
|
||||||
const appStore = useAppStore()
|
const appStore = useAppStore()
|
||||||
const authStore = useAuthStore()
|
const authStore = useAuthStore()
|
||||||
@@ -107,6 +107,7 @@ const embeddedUrl = computed(() => {
|
|||||||
authStore.user?.id,
|
authStore.user?.id,
|
||||||
authStore.token,
|
authStore.token,
|
||||||
pageTheme.value,
|
pageTheme.value,
|
||||||
|
locale.value,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ import AppLayout from '@/components/layout/AppLayout.vue'
|
|||||||
import Icon from '@/components/icons/Icon.vue'
|
import Icon from '@/components/icons/Icon.vue'
|
||||||
import { buildEmbeddedUrl, detectTheme } from '@/utils/embedded-url'
|
import { buildEmbeddedUrl, detectTheme } from '@/utils/embedded-url'
|
||||||
|
|
||||||
const { t } = useI18n()
|
const { t, locale } = useI18n()
|
||||||
const appStore = useAppStore()
|
const appStore = useAppStore()
|
||||||
const authStore = useAuthStore()
|
const authStore = useAuthStore()
|
||||||
|
|
||||||
@@ -90,7 +90,7 @@ const purchaseEnabled = computed(() => {
|
|||||||
|
|
||||||
const purchaseUrl = computed(() => {
|
const purchaseUrl = computed(() => {
|
||||||
const baseUrl = (appStore.cachedPublicSettings?.purchase_subscription_url || '').trim()
|
const baseUrl = (appStore.cachedPublicSettings?.purchase_subscription_url || '').trim()
|
||||||
return buildEmbeddedUrl(baseUrl, authStore.user?.id, authStore.token, purchaseTheme.value)
|
return buildEmbeddedUrl(baseUrl, authStore.user?.id, authStore.token, purchaseTheme.value, locale.value)
|
||||||
})
|
})
|
||||||
|
|
||||||
const isValidUrl = computed(() => {
|
const isValidUrl = computed(() => {
|
||||||
|
|||||||
Reference in New Issue
Block a user