Compare commits

..

24 Commits

Author SHA1 Message Date
Wesley Liddick
3cc407bc0e Merge pull request #900 from ischanx/feat/admin-bind-subscription-group
feat: 允许管理员为持有有效订阅的用户绑定订阅类型分组
2026-03-10 11:20:47 +08:00
shaw
00a0a12138 feat: Anthropic平台可配置 anthropic-beta 策略 2026-03-10 11:20:10 +08:00
ischanx
b08767a4f9 fix: avoid admin subscription binding regressions 2026-03-10 10:53:54 +08:00
Wesley Liddick
ac6bde7a98 Merge pull request #872 from StarryKira/fix/oauth-linuxdo-invitation-required
fix: Linux.do OAuth 注册支持邀请码两步流程 (fix #836)
2026-03-10 09:10:35 +08:00
Wesley Liddick
d2d41d68dd Merge pull request #894 from touwaeriol/pr/startup-concurrency-cleanup
feat: cleanup stale concurrency slots on startup
2026-03-10 09:08:33 +08:00
Wesley Liddick
944b7f7617 Merge pull request #904 from james-6-23/fix-pool-mode-retry
fix: OpenAI临时性400错误支持池模式同账号重试 & HelpTooltip层级修复
2026-03-10 09:08:12 +08:00
Wesley Liddick
53825eb073 Merge pull request #903 from touwaeriol/fix/openai-responses-sse-max-line-size-v2
fix: use shared max_line_size config for OpenAI Responses SSE scanner
2026-03-10 09:06:48 +08:00
Wesley Liddick
1a7f49513f Merge pull request #896 from touwaeriol/pr/fix-quota-badge-vertical-layout
fix(frontend): stack quota badges vertically in capacity column
2026-03-10 09:04:53 +08:00
Wesley Liddick
885a2ce7ef Merge pull request #893 from touwaeriol/pr/iframe-lang-passthrough
feat(frontend): pass locale to iframe embedded pages via lang parameter
2026-03-10 09:04:37 +08:00
Wesley Liddick
14ba80a0af Merge pull request #897 from DaydreamCoding/feat/batch-reset-and-openai-jwt
feat: OpenAI 账号信息增强 & 批量操作支持
2026-03-10 09:03:59 +08:00
kyx236
5fa22fdf82 fix: OpenAI临时性400错误支持池模式同账号重试 & HelpTooltip层级修复
1. 识别OpenAI "An error occurred while processing your request" 临时性400错误
   并触发failover,同时在池模式下标记RetryableOnSameAccount,允许同账号重试
2. ForwardAsAnthropic路径同步支持临时性400错误的识别和同账号重试
3. HelpTooltip组件使用Teleport渲染到body,修复在dialog内被裁切的问题
2026-03-10 03:00:58 +08:00
erio
bcaae2eb91 fix: use shared max_line_size config for OpenAI Responses SSE scanner
Two SSE scanners in openai_gateway_messages.go were hardcoded to 1MB
while all other scanners use defaultMaxLineSize (500MB) with config
override. This caused Responses API streams to fail on large payloads.
2026-03-10 02:50:04 +08:00
ischanx
767a41e263 feat: 允许管理员为持有有效订阅的用户绑定订阅类型分组
之前管理员无法通过 API 密钥管理将用户绑定到订阅类型分组(直接返回错误)。
现在改为检查用户是否持有该分组的有效订阅,有则允许绑定,无则拒绝。

- admin_service: 新增 userSubRepo 依赖,替换硬拒绝为订阅校验
- admin_service: 区分 ErrSubscriptionNotFound 和内部错误,避免 DB 故障被误报
- wire_gen/api_contract_test: 同步新增参数
- UserApiKeysModal: 管理员分组下拉不再过滤订阅类型分组

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-10 00:51:43 +08:00
QTom
252d6c5301 feat: 支持批量重置状态和批量刷新令牌
- 提取 refreshSingleAccount 私有方法复用单账号刷新逻辑
- 新增 BatchClearError handler (POST /admin/accounts/batch-clear-error)
- 新增 BatchRefresh handler (POST /admin/accounts/batch-refresh)
- 前端 AccountBulkActionsBar 添加批量重置状态/刷新令牌按钮
- AccountsView 添加 handler 支持 partial success 反馈
- i18n 中英文补充批量操作相关翻译
2026-03-09 21:54:27 +08:00
QTom
7a4e65ad4b feat: 导入账号时 best-effort 从 id_token 提取用户信息
提取 DecodeIDToken(跳过过期校验)供导入场景使用,
ParseIDToken 复用它并保留原有过期检查行为。
导入 OpenAI/Sora OAuth 账号时自动补充缺失的 email、
plan_type、chatgpt_account_id 等字段,不覆盖已有值。
2026-03-09 21:53:46 +08:00
QTom
a582aa89a9 feat: 从 OpenAI JWT 提取 chatgpt_plan_type 并在前端展示
OAuth 授权和 token 刷新时从 id_token 的 OpenAI auth claim 中
提取 chatgpt_plan_type(plus/team/pro/free),存入 credentials,
账号管理页面 PlatformTypeBadge 显示订阅类型。
2026-03-09 21:53:46 +08:00
erio
acefa1da12 fix(frontend): stack quota badges vertically in capacity column
QuotaBadge components (daily/weekly/total) were wrapped in a
horizontal flex container, making them visually inconsistent with
other capacity badges (concurrency, window cost, sessions, RPM)
which stack vertically. Remove the wrapper so all badges align
consistently.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 20:18:12 +08:00
erio
a88698f3fc feat: cleanup stale concurrency slots on startup
When the service restarts, concurrency slots from the old process
remain in Redis, causing phantom occupancy. On startup, scan all
concurrency sorted sets and remove members with non-current process
prefix, then clear orphaned wait queue counters.

Uses Go-side SCAN to discover keys (compatible with Redis client
prefix hooks in tests), then passes them to a Lua script for
atomic member-level cleanup.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 19:55:18 +08:00
erio
ebc6755b33 feat(frontend): pass locale to iframe embedded pages via lang parameter
Embedded pages (purchase subscription, custom pages) now receive the
current user locale through a `lang` URL parameter, allowing iframe
content to match the user's language preference.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 19:38:23 +08:00
shaw
c8eff34388 chore: update readme 2026-03-09 16:25:56 +08:00
shaw
f19b03825b chore: update docs 2026-03-09 16:11:44 +08:00
Elysia
b43ee62947 fix CI/CD Error 2026-03-09 13:13:39 +08:00
Elysia
106b20cdbf fix claudecode review bug 2026-03-09 01:18:49 +08:00
Elysia
c069b3b1e8 fix issue #836 linux.do注册无需邀请码 2026-03-09 00:35:34 +08:00
61 changed files with 2156 additions and 243 deletions

View File

@@ -150,14 +150,14 @@ mkdir -p sub2api-deploy && cd sub2api-deploy
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
# Start services # Start services
docker-compose -f docker-compose.local.yml up -d docker-compose up -d
# View logs # View logs
docker-compose -f docker-compose.local.yml logs -f sub2api docker-compose logs -f sub2api
``` ```
**What the script does:** **What the script does:**
- Downloads `docker-compose.local.yml` and `.env.example` - Downloads `docker-compose.local.yml` (saved as `docker-compose.yml`) and `.env.example`
- Generates secure credentials (JWT_SECRET, TOTP_ENCRYPTION_KEY, POSTGRES_PASSWORD) - Generates secure credentials (JWT_SECRET, TOTP_ENCRYPTION_KEY, POSTGRES_PASSWORD)
- Creates `.env` file with auto-generated secrets - Creates `.env` file with auto-generated secrets
- Creates data directories (uses local directories for easy backup/migration) - Creates data directories (uses local directories for easy backup/migration)
@@ -522,6 +522,28 @@ sub2api/
└── install.sh # One-click installation script └── install.sh # One-click installation script
``` ```
## Disclaimer
> **Please read carefully before using this project:**
>
> :rotating_light: **Terms of Service Risk**: Using this project may violate Anthropic's Terms of Service. Please read Anthropic's user agreement carefully before use. All risks arising from the use of this project are borne solely by the user.
>
> :book: **Disclaimer**: This project is for technical learning and research purposes only. The author assumes no responsibility for account suspension, service interruption, or any other losses caused by the use of this project.
---
## Star History
<a href="https://star-history.com/#Wei-Shaw/sub2api&Date">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date&theme=dark" />
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
</picture>
</a>
---
## License ## License
MIT License MIT License

View File

@@ -154,14 +154,14 @@ mkdir -p sub2api-deploy && cd sub2api-deploy
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
# 启动服务 # 启动服务
docker-compose -f docker-compose.local.yml up -d docker-compose up -d
# 查看日志 # 查看日志
docker-compose -f docker-compose.local.yml logs -f sub2api docker-compose logs -f sub2api
``` ```
**脚本功能:** **脚本功能:**
- 下载 `docker-compose.local.yml` `.env.example` - 下载 `docker-compose.local.yml`(本地保存为 `docker-compose.yml``.env.example`
- 自动生成安全凭证JWT_SECRET、TOTP_ENCRYPTION_KEY、POSTGRES_PASSWORD - 自动生成安全凭证JWT_SECRET、TOTP_ENCRYPTION_KEY、POSTGRES_PASSWORD
- 创建 `.env` 文件并填充自动生成的密钥 - 创建 `.env` 文件并填充自动生成的密钥
- 创建数据目录(使用本地目录,便于备份和迁移) - 创建数据目录(使用本地目录,便于备份和迁移)
@@ -588,6 +588,28 @@ sub2api/
└── install.sh # 一键安装脚本 └── install.sh # 一键安装脚本
``` ```
## 免责声明
> **使用本项目前请仔细阅读:**
>
> :rotating_light: **服务条款风险**: 使用本项目可能违反 Anthropic 的服务条款。请在使用前仔细阅读 Anthropic 的用户协议,使用本项目的一切风险由用户自行承担。
>
> :book: **免责声明**: 本项目仅供技术学习和研究使用,作者不对因使用本项目导致的账户封禁、服务中断或其他损失承担任何责任。
---
## Star History
<a href="https://star-history.com/#Wei-Shaw/sub2api&Date">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date&theme=dark" />
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
</picture>
</a>
---
## 许可证 ## 许可证
MIT License MIT License

View File

@@ -33,7 +33,7 @@ func main() {
}() }()
userRepo := repository.NewUserRepository(client, sqlDB) userRepo := repository.NewUserRepository(client, sqlDB)
authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()

View File

@@ -67,7 +67,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService) authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache) userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
redeemCache := repository.NewRedeemCache(redisClient) redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
@@ -104,7 +104,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
proxyRepository := repository.NewProxyRepository(client, db) proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService) adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService) adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)

View File

@@ -8,6 +8,9 @@ import (
"strings" "strings"
"time" "time"
"log/slog"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -292,6 +295,8 @@ func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest)
} }
} }
enrichCredentialsFromIDToken(&item)
accountInput := &service.CreateAccountInput{ accountInput := &service.CreateAccountInput{
Name: item.Name, Name: item.Name,
Notes: item.Notes, Notes: item.Notes,
@@ -535,6 +540,57 @@ func defaultProxyName(name string) string {
return name return name
} }
// enrichCredentialsFromIDToken performs best-effort extraction of user info fields
// (email, plan_type, chatgpt_account_id, etc.) from id_token in credentials.
// Only applies to OpenAI/Sora OAuth accounts. Skips expired token errors silently.
// Existing credential values are never overwritten — only missing fields are filled.
func enrichCredentialsFromIDToken(item *DataAccount) {
if item.Credentials == nil {
return
}
// Only enrich OpenAI/Sora OAuth accounts
platform := strings.ToLower(strings.TrimSpace(item.Platform))
if platform != service.PlatformOpenAI && platform != service.PlatformSora {
return
}
if strings.ToLower(strings.TrimSpace(item.Type)) != service.AccountTypeOAuth {
return
}
idToken, _ := item.Credentials["id_token"].(string)
if strings.TrimSpace(idToken) == "" {
return
}
// DecodeIDToken skips expiry validation — safe for imported data
claims, err := openai.DecodeIDToken(idToken)
if err != nil {
slog.Debug("import_enrich_id_token_decode_failed", "account", item.Name, "error", err)
return
}
userInfo := claims.GetUserInfo()
if userInfo == nil {
return
}
// Fill missing fields only (never overwrite existing values)
setIfMissing := func(key, value string) {
if value == "" {
return
}
if existing, _ := item.Credentials[key].(string); existing == "" {
item.Credentials[key] = value
}
}
setIfMissing("email", userInfo.Email)
setIfMissing("plan_type", userInfo.PlanType)
setIfMissing("chatgpt_account_id", userInfo.ChatGPTAccountID)
setIfMissing("chatgpt_user_id", userInfo.ChatGPTUserID)
setIfMissing("organization_id", userInfo.OrganizationID)
}
func normalizeProxyStatus(status string) string { func normalizeProxyStatus(status string) string {
normalized := strings.TrimSpace(strings.ToLower(status)) normalized := strings.TrimSpace(strings.ToLower(status))
switch normalized { switch normalized {

View File

@@ -8,6 +8,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"log"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
@@ -18,6 +19,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
@@ -751,52 +753,31 @@ func (h *AccountHandler) PreviewFromCRS(c *gin.Context) {
response.Success(c, result) response.Success(c, result)
} }
// Refresh handles refreshing account credentials // refreshSingleAccount refreshes credentials for a single OAuth account.
// POST /api/v1/admin/accounts/:id/refresh // Returns (updatedAccount, warning, error) where warning is used for Antigravity ProjectIDMissing scenario.
func (h *AccountHandler) Refresh(c *gin.Context) { func (h *AccountHandler) refreshSingleAccount(ctx context.Context, account *service.Account) (*service.Account, string, error) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
// Get account
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.NotFound(c, "Account not found")
return
}
// Only refresh OAuth-based accounts (oauth and setup-token)
if !account.IsOAuth() { if !account.IsOAuth() {
response.BadRequest(c, "Cannot refresh non-OAuth account credentials") return nil, "", infraerrors.BadRequest("NOT_OAUTH", "cannot refresh non-OAuth account")
return
} }
var newCredentials map[string]any var newCredentials map[string]any
if account.IsOpenAI() { if account.IsOpenAI() {
// Use OpenAI OAuth service to refresh token tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(ctx, account)
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
if err != nil { if err != nil {
response.ErrorFrom(c, err) return nil, "", err
return
} }
// Build new credentials from token info
newCredentials = h.openaiOAuthService.BuildAccountCredentials(tokenInfo) newCredentials = h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
// Preserve non-token settings from existing credentials
for k, v := range account.Credentials { for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists { if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v newCredentials[k] = v
} }
} }
} else if account.Platform == service.PlatformGemini { } else if account.Platform == service.PlatformGemini {
tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(c.Request.Context(), account) tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(ctx, account)
if err != nil { if err != nil {
response.InternalError(c, "Failed to refresh credentials: "+err.Error()) return nil, "", fmt.Errorf("failed to refresh credentials: %w", err)
return
} }
newCredentials = h.geminiOAuthService.BuildAccountCredentials(tokenInfo) newCredentials = h.geminiOAuthService.BuildAccountCredentials(tokenInfo)
@@ -806,10 +787,9 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
} }
} }
} else if account.Platform == service.PlatformAntigravity { } else if account.Platform == service.PlatformAntigravity {
tokenInfo, err := h.antigravityOAuthService.RefreshAccountToken(c.Request.Context(), account) tokenInfo, err := h.antigravityOAuthService.RefreshAccountToken(ctx, account)
if err != nil { if err != nil {
response.ErrorFrom(c, err) return nil, "", err
return
} }
newCredentials = h.antigravityOAuthService.BuildAccountCredentials(tokenInfo) newCredentials = h.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
@@ -828,37 +808,27 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
} }
// 如果 project_id 获取失败,更新凭证但不标记为 error // 如果 project_id 获取失败,更新凭证但不标记为 error
// LoadCodeAssist 失败可能是临时网络问题,给它机会在下次自动刷新时重试
if tokenInfo.ProjectIDMissing { if tokenInfo.ProjectIDMissing {
// 先更新凭证token 本身刷新成功了) updatedAccount, updateErr := h.adminService.UpdateAccount(ctx, account.ID, &service.UpdateAccountInput{
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
Credentials: newCredentials, Credentials: newCredentials,
}) })
if updateErr != nil { if updateErr != nil {
response.InternalError(c, "Failed to update credentials: "+updateErr.Error()) return nil, "", fmt.Errorf("failed to update credentials: %w", updateErr)
return
} }
// 不标记为 error只返回警告信息 return updatedAccount, "missing_project_id_temporary", nil
response.Success(c, gin.H{
"message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)",
"warning": "missing_project_id_temporary",
})
return
} }
// 成功获取到 project_id如果之前是 missing_project_id 错误则清除 // 成功获取到 project_id如果之前是 missing_project_id 错误则清除
if account.Status == service.StatusError && strings.Contains(account.ErrorMessage, "missing_project_id:") { if account.Status == service.StatusError && strings.Contains(account.ErrorMessage, "missing_project_id:") {
if _, clearErr := h.adminService.ClearAccountError(c.Request.Context(), accountID); clearErr != nil { if _, clearErr := h.adminService.ClearAccountError(ctx, account.ID); clearErr != nil {
response.InternalError(c, "Failed to clear account error: "+clearErr.Error()) return nil, "", fmt.Errorf("failed to clear account error: %w", clearErr)
return
} }
} }
} else { } else {
// Use Anthropic/Claude OAuth service to refresh token // Use Anthropic/Claude OAuth service to refresh token
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account) tokenInfo, err := h.oauthService.RefreshAccountToken(ctx, account)
if err != nil { if err != nil {
response.ErrorFrom(c, err) return nil, "", err
return
} }
// Copy existing credentials to preserve non-token settings (e.g., intercept_warmup_requests) // Copy existing credentials to preserve non-token settings (e.g., intercept_warmup_requests)
@@ -880,20 +850,51 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
} }
} }
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{ updatedAccount, err := h.adminService.UpdateAccount(ctx, account.ID, &service.UpdateAccountInput{
Credentials: newCredentials, Credentials: newCredentials,
}) })
if err != nil {
return nil, "", err
}
// 刷新成功后,清除 token 缓存,确保下次请求使用新 token
if h.tokenCacheInvalidator != nil {
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(ctx, updatedAccount); invalidateErr != nil {
log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", updatedAccount.ID, invalidateErr)
}
}
return updatedAccount, "", nil
}
// Refresh handles refreshing account credentials
// POST /api/v1/admin/accounts/:id/refresh
func (h *AccountHandler) Refresh(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
// Get account
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.NotFound(c, "Account not found")
return
}
updatedAccount, warning, err := h.refreshSingleAccount(c.Request.Context(), account)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
// 刷新成功后,清除 token 缓存,确保下次请求使用新 token if warning == "missing_project_id_temporary" {
if h.tokenCacheInvalidator != nil { response.Success(c, gin.H{
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), updatedAccount); invalidateErr != nil { "message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)",
// 缓存失效失败只记录日志,不影响主流程 "warning": "missing_project_id_temporary",
_ = c.Error(invalidateErr) })
} return
} }
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), updatedAccount)) response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), updatedAccount))
@@ -949,14 +950,175 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
// 这解决了管理员重置账号状态后,旧的失效 token 仍在缓存中导致立即再次 401 的问题 // 这解决了管理员重置账号状态后,旧的失效 token 仍在缓存中导致立即再次 401 的问题
if h.tokenCacheInvalidator != nil && account.IsOAuth() { if h.tokenCacheInvalidator != nil && account.IsOAuth() {
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), account); invalidateErr != nil { if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), account); invalidateErr != nil {
// 缓存失效失败只记录日志,不影响主流程 log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", accountID, invalidateErr)
_ = c.Error(invalidateErr)
} }
} }
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
} }
// BatchClearError handles batch clearing account errors
// POST /api/v1/admin/accounts/batch-clear-error
func (h *AccountHandler) BatchClearError(c *gin.Context) {
var req struct {
AccountIDs []int64 `json:"account_ids"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if len(req.AccountIDs) == 0 {
response.BadRequest(c, "account_ids is required")
return
}
ctx := c.Request.Context()
const maxConcurrency = 10
g, gctx := errgroup.WithContext(ctx)
g.SetLimit(maxConcurrency)
var mu sync.Mutex
var successCount, failedCount int
var errors []gin.H
// 注意:所有 goroutine 必须 return nil避免 errgroup cancel 其他并发任务
for _, id := range req.AccountIDs {
accountID := id // 闭包捕获
g.Go(func() error {
account, err := h.adminService.ClearAccountError(gctx, accountID)
if err != nil {
mu.Lock()
failedCount++
errors = append(errors, gin.H{
"account_id": accountID,
"error": err.Error(),
})
mu.Unlock()
return nil
}
// 清除错误后,同时清除 token 缓存
if h.tokenCacheInvalidator != nil && account.IsOAuth() {
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(gctx, account); invalidateErr != nil {
log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", accountID, invalidateErr)
}
}
mu.Lock()
successCount++
mu.Unlock()
return nil
})
}
if err := g.Wait(); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{
"total": len(req.AccountIDs),
"success": successCount,
"failed": failedCount,
"errors": errors,
})
}
// BatchRefresh handles batch refreshing account credentials
// POST /api/v1/admin/accounts/batch-refresh
func (h *AccountHandler) BatchRefresh(c *gin.Context) {
var req struct {
AccountIDs []int64 `json:"account_ids"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if len(req.AccountIDs) == 0 {
response.BadRequest(c, "account_ids is required")
return
}
ctx := c.Request.Context()
accounts, err := h.adminService.GetAccountsByIDs(ctx, req.AccountIDs)
if err != nil {
response.ErrorFrom(c, err)
return
}
// 建立已获取账号的 ID 集合,检测缺失的 ID
foundIDs := make(map[int64]bool, len(accounts))
for _, acc := range accounts {
if acc != nil {
foundIDs[acc.ID] = true
}
}
const maxConcurrency = 10
g, gctx := errgroup.WithContext(ctx)
g.SetLimit(maxConcurrency)
var mu sync.Mutex
var successCount, failedCount int
var errors []gin.H
var warnings []gin.H
// 将不存在的账号 ID 标记为失败
for _, id := range req.AccountIDs {
if !foundIDs[id] {
failedCount++
errors = append(errors, gin.H{
"account_id": id,
"error": "account not found",
})
}
}
// 注意:所有 goroutine 必须 return nil避免 errgroup cancel 其他并发任务
for _, account := range accounts {
acc := account // 闭包捕获
if acc == nil {
continue
}
g.Go(func() error {
_, warning, err := h.refreshSingleAccount(gctx, acc)
mu.Lock()
if err != nil {
failedCount++
errors = append(errors, gin.H{
"account_id": acc.ID,
"error": err.Error(),
})
} else {
successCount++
if warning != "" {
warnings = append(warnings, gin.H{
"account_id": acc.ID,
"warning": warning,
})
}
}
mu.Unlock()
return nil
})
}
if err := g.Wait(); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{
"total": len(req.AccountIDs),
"success": successCount,
"failed": failedCount,
"errors": errors,
"warnings": warnings,
})
}
// BatchCreate handles batch creating accounts // BatchCreate handles batch creating accounts
// POST /api/v1/admin/accounts/batch // POST /api/v1/admin/accounts/batch
func (h *AccountHandler) BatchCreate(c *gin.Context) { func (h *AccountHandler) BatchCreate(c *gin.Context) {

View File

@@ -1405,6 +1405,61 @@ func (h *SettingHandler) UpdateRectifierSettings(c *gin.Context) {
}) })
} }
// GetBetaPolicySettings 获取 Beta 策略配置
// GET /api/v1/admin/settings/beta-policy
func (h *SettingHandler) GetBetaPolicySettings(c *gin.Context) {
settings, err := h.settingService.GetBetaPolicySettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
rules := make([]dto.BetaPolicyRule, len(settings.Rules))
for i, r := range settings.Rules {
rules[i] = dto.BetaPolicyRule(r)
}
response.Success(c, dto.BetaPolicySettings{Rules: rules})
}
// UpdateBetaPolicySettingsRequest 更新 Beta 策略配置请求
type UpdateBetaPolicySettingsRequest struct {
Rules []dto.BetaPolicyRule `json:"rules"`
}
// UpdateBetaPolicySettings 更新 Beta 策略配置
// PUT /api/v1/admin/settings/beta-policy
func (h *SettingHandler) UpdateBetaPolicySettings(c *gin.Context) {
var req UpdateBetaPolicySettingsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
rules := make([]service.BetaPolicyRule, len(req.Rules))
for i, r := range req.Rules {
rules[i] = service.BetaPolicyRule(r)
}
settings := &service.BetaPolicySettings{Rules: rules}
if err := h.settingService.SetBetaPolicySettings(c.Request.Context(), settings); err != nil {
response.BadRequest(c, err.Error())
return
}
// Re-fetch to return updated settings
updated, err := h.settingService.GetBetaPolicySettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
outRules := make([]dto.BetaPolicyRule, len(updated.Rules))
for i, r := range updated.Rules {
outRules[i] = dto.BetaPolicyRule(r)
}
response.Success(c, dto.BetaPolicySettings{Rules: outRules})
}
// UpdateStreamTimeoutSettingsRequest 更新流超时配置请求 // UpdateStreamTimeoutSettingsRequest 更新流超时配置请求
type UpdateStreamTimeoutSettingsRequest struct { type UpdateStreamTimeoutSettingsRequest struct {
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`

View File

@@ -211,8 +211,22 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
email = linuxDoSyntheticEmail(subject) email = linuxDoSyntheticEmail(subject)
} }
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username) // 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
if err != nil { if err != nil {
if errors.Is(err, service.ErrOAuthInvitationRequired) {
pendingToken, tokenErr := h.authService.CreatePendingOAuthToken(email, username)
if tokenErr != nil {
redirectOAuthError(c, frontendCallback, "login_failed", "service_error", "")
return
}
fragment := url.Values{}
fragment.Set("error", "invitation_required")
fragment.Set("pending_oauth_token", pendingToken)
fragment.Set("redirect", redirectTo)
redirectWithFragment(c, frontendCallback, fragment)
return
}
// 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。 // 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
return return
@@ -227,6 +241,41 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
redirectWithFragment(c, frontendCallback, fragment) redirectWithFragment(c, frontendCallback, fragment)
} }
type completeLinuxDoOAuthRequest struct {
PendingOAuthToken string `json:"pending_oauth_token" binding:"required"`
InvitationCode string `json:"invitation_code" binding:"required"`
}
// CompleteLinuxDoOAuthRegistration completes a pending OAuth registration by validating
// the invitation code and creating the user account.
// POST /api/v1/auth/oauth/linuxdo/complete-registration
func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
var req completeLinuxDoOAuthRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "INVALID_REQUEST", "message": err.Error()})
return
}
email, username, err := h.authService.VerifyPendingOAuthToken(req.PendingOAuthToken)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "INVALID_TOKEN", "message": "invalid or expired registration token"})
return
}
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
if err != nil {
response.ErrorFrom(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"access_token": tokenPair.AccessToken,
"refresh_token": tokenPair.RefreshToken,
"expires_in": tokenPair.ExpiresIn,
"token_type": "Bearer",
})
}
func (h *AuthHandler) getLinuxDoOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) { func (h *AuthHandler) getLinuxDoOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) {
if h != nil && h.settingSvc != nil { if h != nil && h.settingSvc != nil {
return h.settingSvc.GetLinuxDoConnectOAuthConfig(ctx) return h.settingSvc.GetLinuxDoConnectOAuthConfig(ctx)

View File

@@ -168,6 +168,19 @@ type RectifierSettings struct {
ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"` ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"`
} }
// BetaPolicyRule Beta 策略规则 DTO
type BetaPolicyRule struct {
BetaToken string `json:"beta_token"`
Action string `json:"action"`
Scope string `json:"scope"`
ErrorMessage string `json:"error_message,omitempty"`
}
// BetaPolicySettings Beta 策略配置 DTO
type BetaPolicySettings struct {
Rules []BetaPolicyRule `json:"rules"`
}
// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem. // ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem.
// Returns empty slice on empty/invalid input. // Returns empty slice on empty/invalid input.
func ParseCustomMenuItems(raw string) []CustomMenuItem { func ParseCustomMenuItems(raw string) []CustomMenuItem {

View File

@@ -652,6 +652,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
accountReleaseFunc() accountReleaseFunc()
} }
if err != nil { if err != nil {
// Beta policy block: return 400 immediately, no failover
var betaBlockedErr *service.BetaBlockedError
if errors.As(err, &betaBlockedErr) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", betaBlockedErr.Message)
return
}
var promptTooLongErr *service.PromptTooLongError var promptTooLongErr *service.PromptTooLongError
if errors.As(err, &promptTooLongErr) { if errors.As(err, &promptTooLongErr) {
reqLog.Warn("gateway.prompt_too_long_from_antigravity", reqLog.Warn("gateway.prompt_too_long_from_antigravity",

View File

@@ -127,6 +127,7 @@ func (f *fakeConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, acc
return result, nil return result, nil
} }
func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil } func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil }
func (f *fakeConcurrencyCache) CleanupStaleProcessSlots(context.Context, string) error { return nil }
func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) { func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) {
t.Helper() t.Helper()

View File

@@ -89,6 +89,10 @@ func (m *concurrencyCacheMock) CleanupExpiredAccountSlots(ctx context.Context, a
return nil return nil
} }
func (m *concurrencyCacheMock) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
return nil
}
func TestConcurrencyHelper_TryAcquireUserSlot(t *testing.T) { func TestConcurrencyHelper_TryAcquireUserSlot(t *testing.T) {
cache := &concurrencyCacheMock{ cache := &concurrencyCacheMock{
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {

View File

@@ -120,6 +120,10 @@ func (s *helperConcurrencyCacheStub) CleanupExpiredAccountSlots(ctx context.Cont
return nil return nil
} }
func (s *helperConcurrencyCacheStub) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
return nil
}
func newHelperTestContext(method, path string) (*gin.Context, *httptest.ResponseRecorder) { func newHelperTestContext(method, path string) (*gin.Context, *httptest.ResponseRecorder) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()

View File

@@ -16,7 +16,7 @@ const (
// DroppedBetas 是转发时需要从 anthropic-beta header 中移除的 beta token 列表。 // DroppedBetas 是转发时需要从 anthropic-beta header 中移除的 beta token 列表。
// 这些 token 是客户端特有的,不应透传给上游 API。 // 这些 token 是客户端特有的,不应透传给上游 API。
var DroppedBetas = []string{BetaFastMode} var DroppedBetas = []string{}
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header // DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming

View File

@@ -268,6 +268,7 @@ type IDTokenClaims struct {
type OpenAIAuthClaims struct { type OpenAIAuthClaims struct {
ChatGPTAccountID string `json:"chatgpt_account_id"` ChatGPTAccountID string `json:"chatgpt_account_id"`
ChatGPTUserID string `json:"chatgpt_user_id"` ChatGPTUserID string `json:"chatgpt_user_id"`
ChatGPTPlanType string `json:"chatgpt_plan_type"`
UserID string `json:"user_id"` UserID string `json:"user_id"`
Organizations []OrganizationClaim `json:"organizations"` Organizations []OrganizationClaim `json:"organizations"`
} }
@@ -325,12 +326,9 @@ func (r *RefreshTokenRequest) ToFormData() string {
return params.Encode() return params.Encode()
} }
// ParseIDToken parses the ID Token JWT and extracts claims. // DecodeIDToken decodes the ID Token JWT payload without validating expiration.
// 注意:当前仅解码 payload 并校验 exp未验证 JWT 签名。 // Use this for best-effort extraction (e.g., during data import) where the token may be expired.
// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名: func DecodeIDToken(idToken string) (*IDTokenClaims, error) {
//
// https://auth.openai.com/.well-known/jwks.json
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
parts := strings.Split(idToken, ".") parts := strings.Split(idToken, ".")
if len(parts) != 3 { if len(parts) != 3 {
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts)) return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
@@ -360,6 +358,20 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
return nil, fmt.Errorf("failed to parse JWT claims: %w", err) return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
} }
return &claims, nil
}
// ParseIDToken parses the ID Token JWT and extracts claims.
// 注意:当前仅解码 payload 并校验 exp未验证 JWT 签名。
// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名:
//
// https://auth.openai.com/.well-known/jwks.json
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
claims, err := DecodeIDToken(idToken)
if err != nil {
return nil, err
}
// 校验 ID Token 是否已过期(允许 2 分钟时钟偏差,防止因服务器时钟略有差异误判刚颁发的令牌) // 校验 ID Token 是否已过期(允许 2 分钟时钟偏差,防止因服务器时钟略有差异误判刚颁发的令牌)
const clockSkewTolerance = 120 // 秒 const clockSkewTolerance = 120 // 秒
now := time.Now().Unix() now := time.Now().Unix()
@@ -367,7 +379,7 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
return nil, fmt.Errorf("id_token has expired (exp: %d, now: %d, skew_tolerance: %ds)", claims.Exp, now, clockSkewTolerance) return nil, fmt.Errorf("id_token has expired (exp: %d, now: %d, skew_tolerance: %ds)", claims.Exp, now, clockSkewTolerance)
} }
return &claims, nil return claims, nil
} }
// UserInfo represents user information extracted from ID Token claims. // UserInfo represents user information extracted from ID Token claims.
@@ -375,6 +387,7 @@ type UserInfo struct {
Email string Email string
ChatGPTAccountID string ChatGPTAccountID string
ChatGPTUserID string ChatGPTUserID string
PlanType string
UserID string UserID string
OrganizationID string OrganizationID string
Organizations []OrganizationClaim Organizations []OrganizationClaim
@@ -389,6 +402,7 @@ func (c *IDTokenClaims) GetUserInfo() *UserInfo {
if c.OpenAIAuth != nil { if c.OpenAIAuth != nil {
info.ChatGPTAccountID = c.OpenAIAuth.ChatGPTAccountID info.ChatGPTAccountID = c.OpenAIAuth.ChatGPTAccountID
info.ChatGPTUserID = c.OpenAIAuth.ChatGPTUserID info.ChatGPTUserID = c.OpenAIAuth.ChatGPTUserID
info.PlanType = c.OpenAIAuth.ChatGPTPlanType
info.UserID = c.OpenAIAuth.UserID info.UserID = c.OpenAIAuth.UserID
info.Organizations = c.OpenAIAuth.Organizations info.Organizations = c.OpenAIAuth.Organizations

View File

@@ -147,17 +147,47 @@ var (
return 1 return 1
`) `)
// cleanupExpiredSlotsScript - remove expired slots // cleanupExpiredSlotsScript 清理单个账号/用户有序集合中过期槽位
// KEYS[1] = concurrency:account:{accountID} // KEYS[1] = 有序集合键
// ARGV[1] = TTL (seconds) // ARGV[1] = TTL(秒)
cleanupExpiredSlotsScript = redis.NewScript(` cleanupExpiredSlotsScript = redis.NewScript(`
local key = KEYS[1] local key = KEYS[1]
local ttl = tonumber(ARGV[1]) local ttl = tonumber(ARGV[1])
local timeResult = redis.call('TIME') local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1]) local now = tonumber(timeResult[1])
local expireBefore = now - ttl local expireBefore = now - ttl
return redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore) redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
`) if redis.call('ZCARD', key) == 0 then
redis.call('DEL', key)
else
redis.call('EXPIRE', key, ttl)
end
return 1
`)
// startupCleanupScript 清理非当前进程前缀的槽位成员。
// KEYS 是有序集合键列表ARGV[1] 是当前进程前缀ARGV[2] 是槽位 TTL。
// 遍历每个 KEYS[i],移除前缀不匹配的成员,清空后删 key否则刷新 EXPIRE。
startupCleanupScript = redis.NewScript(`
local activePrefix = ARGV[1]
local slotTTL = tonumber(ARGV[2])
local removed = 0
for i = 1, #KEYS do
local key = KEYS[i]
local members = redis.call('ZRANGE', key, 0, -1)
for _, member in ipairs(members) do
if string.sub(member, 1, string.len(activePrefix)) ~= activePrefix then
removed = removed + redis.call('ZREM', key, member)
end
end
if redis.call('ZCARD', key) == 0 then
redis.call('DEL', key)
else
redis.call('EXPIRE', key, slotTTL)
end
end
return removed
`)
) )
type concurrencyCache struct { type concurrencyCache struct {
@@ -463,3 +493,72 @@ func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accou
_, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result() _, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result()
return err return err
} }
func (c *concurrencyCache) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
if activeRequestPrefix == "" {
return nil
}
// 1. 清理有序集合中非当前进程前缀的成员
slotPatterns := []string{accountSlotKeyPrefix + "*", userSlotKeyPrefix + "*"}
for _, pattern := range slotPatterns {
if err := c.cleanupSlotsByPattern(ctx, pattern, activeRequestPrefix); err != nil {
return err
}
}
// 2. 删除所有等待队列计数器(重启后计数器失效)
waitPatterns := []string{accountWaitKeyPrefix + "*", waitQueueKeyPrefix + "*"}
for _, pattern := range waitPatterns {
if err := c.deleteKeysByPattern(ctx, pattern); err != nil {
return err
}
}
return nil
}
// cleanupSlotsByPattern 扫描匹配 pattern 的有序集合键,批量调用 Lua 脚本清理非当前进程成员。
func (c *concurrencyCache) cleanupSlotsByPattern(ctx context.Context, pattern, activePrefix string) error {
const scanCount = 200
var cursor uint64
for {
keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, scanCount).Result()
if err != nil {
return fmt.Errorf("scan %s: %w", pattern, err)
}
if len(keys) > 0 {
_, err := startupCleanupScript.Run(ctx, c.rdb, keys, activePrefix, c.slotTTLSeconds).Result()
if err != nil {
return fmt.Errorf("cleanup slots %s: %w", pattern, err)
}
}
cursor = nextCursor
if cursor == 0 {
break
}
}
return nil
}
// deleteKeysByPattern 扫描匹配 pattern 的键并删除。
func (c *concurrencyCache) deleteKeysByPattern(ctx context.Context, pattern string) error {
const scanCount = 200
var cursor uint64
for {
keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, scanCount).Result()
if err != nil {
return fmt.Errorf("scan %s: %w", pattern, err)
}
if len(keys) > 0 {
if err := c.rdb.Del(ctx, keys...).Err(); err != nil {
return fmt.Errorf("del %s: %w", pattern, err)
}
}
cursor = nextCursor
if cursor == 0 {
break
}
}
return nil
}

View File

@@ -25,6 +25,10 @@ type ConcurrencyCacheSuite struct {
cache service.ConcurrencyCache cache service.ConcurrencyCache
} }
func TestConcurrencyCacheSuite(t *testing.T) {
suite.Run(t, new(ConcurrencyCacheSuite))
}
func (s *ConcurrencyCacheSuite) SetupTest() { func (s *ConcurrencyCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest() s.IntegrationRedisSuite.SetupTest()
s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds())) s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds()))
@@ -247,17 +251,41 @@ func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {
require.Equal(s.T(), 1, val, "expected account wait count 1") require.Equal(s.T(), 1, val, "expected account wait count 1")
} }
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_DecrementNoNegative() { func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots() {
accountID := int64(301) accountID := int64(901)
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID) userID := int64(902)
accountKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
userKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
userWaitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
accountWaitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount on non-existent key") now := time.Now().Unix()
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountKey,
redis.Z{Score: float64(now), Member: "oldproc-1"},
redis.Z{Score: float64(now), Member: "keep-1"},
).Err())
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, userKey,
redis.Z{Score: float64(now), Member: "oldproc-2"},
redis.Z{Score: float64(now), Member: "keep-2"},
).Err())
require.NoError(s.T(), s.rdb.Set(s.ctx, userWaitKey, 3, time.Minute).Err())
require.NoError(s.T(), s.rdb.Set(s.ctx, accountWaitKey, 2, time.Minute).Err())
val, err := s.rdb.Get(s.ctx, waitKey).Int() require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "keep-"))
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey") accountMembers, err := s.rdb.ZRange(s.ctx, accountKey, 0, -1).Result()
} require.NoError(s.T(), err)
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative account wait count after decrement on empty") require.Equal(s.T(), []string{"keep-1"}, accountMembers)
userMembers, err := s.rdb.ZRange(s.ctx, userKey, 0, -1).Result()
require.NoError(s.T(), err)
require.Equal(s.T(), []string{"keep-2"}, userMembers)
_, err = s.rdb.Get(s.ctx, userWaitKey).Result()
require.True(s.T(), errors.Is(err, redis.Nil))
_, err = s.rdb.Get(s.ctx, accountWaitKey).Result()
require.True(s.T(), errors.Is(err, redis.Nil))
} }
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() { func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
@@ -407,6 +435,53 @@ func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots_NoExpired() {
require.Equal(s.T(), 2, cur) require.Equal(s.T(), 2, cur)
} }
func TestConcurrencyCacheSuite(t *testing.T) { func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots_RemovesOldPrefixesAndWaitCounters() {
suite.Run(t, new(ConcurrencyCacheSuite)) accountID := int64(901)
userID := int64(902)
accountSlotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
userSlotKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
userWaitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
accountWaitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
now := float64(time.Now().Unix())
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountSlotKey,
redis.Z{Score: now, Member: "oldproc-1"},
redis.Z{Score: now, Member: "activeproc-1"},
).Err())
require.NoError(s.T(), s.rdb.Expire(s.ctx, accountSlotKey, testSlotTTL).Err())
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, userSlotKey,
redis.Z{Score: now, Member: "oldproc-2"},
redis.Z{Score: now, Member: "activeproc-2"},
).Err())
require.NoError(s.T(), s.rdb.Expire(s.ctx, userSlotKey, testSlotTTL).Err())
require.NoError(s.T(), s.rdb.Set(s.ctx, userWaitKey, 3, testSlotTTL).Err())
require.NoError(s.T(), s.rdb.Set(s.ctx, accountWaitKey, 2, testSlotTTL).Err())
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "activeproc-"))
accountMembers, err := s.rdb.ZRange(s.ctx, accountSlotKey, 0, -1).Result()
require.NoError(s.T(), err)
require.Equal(s.T(), []string{"activeproc-1"}, accountMembers)
userMembers, err := s.rdb.ZRange(s.ctx, userSlotKey, 0, -1).Result()
require.NoError(s.T(), err)
require.Equal(s.T(), []string{"activeproc-2"}, userMembers)
_, err = s.rdb.Get(s.ctx, userWaitKey).Result()
require.ErrorIs(s.T(), err, redis.Nil)
_, err = s.rdb.Get(s.ctx, accountWaitKey).Result()
require.ErrorIs(s.T(), err, redis.Nil)
}
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots_DeletesEmptySlotKeys() {
accountID := int64(903)
accountSlotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountSlotKey, redis.Z{Score: float64(time.Now().Unix()), Member: "oldproc-1"}).Err())
require.NoError(s.T(), s.rdb.Expire(s.ctx, accountSlotKey, testSlotTTL).Err())
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "activeproc-"))
exists, err := s.rdb.Exists(s.ctx, accountSlotKey).Result()
require.NoError(s.T(), err)
require.EqualValues(s.T(), 0, exists)
} }

View File

@@ -645,7 +645,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo() settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg) settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil) adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)

View File

@@ -19,7 +19,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}} cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil) authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
admin := &service.User{ admin := &service.User{
ID: 1, ID: 1,

View File

@@ -40,7 +40,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer
cfg.JWT.AccessTokenExpireMinutes = 60 cfg.JWT.AccessTokenExpireMinutes = 60
userRepo := &stubJWTUserRepo{users: users} userRepo := &stubJWTUserRepo{users: users}
authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
userSvc := service.NewUserService(userRepo, nil, nil) userSvc := service.NewUserService(userRepo, nil, nil)
mw := NewJWTAuthMiddleware(authSvc, userSvc) mw := NewJWTAuthMiddleware(authSvc, userSvc)

View File

@@ -264,6 +264,8 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials) accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials)
accounts.POST("/batch-refresh-tier", h.Admin.Account.BatchRefreshTier) accounts.POST("/batch-refresh-tier", h.Admin.Account.BatchRefreshTier)
accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate) accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate)
accounts.POST("/batch-clear-error", h.Admin.Account.BatchClearError)
accounts.POST("/batch-refresh", h.Admin.Account.BatchRefresh)
// Antigravity 默认模型映射 // Antigravity 默认模型映射
accounts.GET("/antigravity/default-model-mapping", h.Admin.Account.GetAntigravityDefaultModelMapping) accounts.GET("/antigravity/default-model-mapping", h.Admin.Account.GetAntigravityDefaultModelMapping)
@@ -396,6 +398,9 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
// 请求整流器配置 // 请求整流器配置
adminSettings.GET("/rectifier", h.Admin.Setting.GetRectifierSettings) adminSettings.GET("/rectifier", h.Admin.Setting.GetRectifierSettings)
adminSettings.PUT("/rectifier", h.Admin.Setting.UpdateRectifierSettings) adminSettings.PUT("/rectifier", h.Admin.Setting.UpdateRectifierSettings)
// Beta 策略配置
adminSettings.GET("/beta-policy", h.Admin.Setting.GetBetaPolicySettings)
adminSettings.PUT("/beta-policy", h.Admin.Setting.UpdateBetaPolicySettings)
// Sora S3 存储配置 // Sora S3 存储配置
adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings) adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings)
adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings) adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings)

View File

@@ -61,6 +61,12 @@ func RegisterAuthRoutes(
}), h.Auth.ResetPassword) }), h.Auth.ResetPassword)
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart) auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback) auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
auth.POST("/oauth/linuxdo/complete-registration",
rateLimiter.LimitWithOptions("oauth-linuxdo-complete", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CompleteLinuxDoOAuthRegistration,
)
} }
// 公开设置(无需认证) // 公开设置(无需认证)

View File

@@ -432,6 +432,7 @@ type adminServiceImpl struct {
entClient *dbent.Client // 用于开启数据库事务 entClient *dbent.Client // 用于开启数据库事务
settingService *SettingService settingService *SettingService
defaultSubAssigner DefaultSubscriptionAssigner defaultSubAssigner DefaultSubscriptionAssigner
userSubRepo UserSubscriptionRepository
} }
type userGroupRateBatchReader interface { type userGroupRateBatchReader interface {
@@ -459,6 +460,7 @@ func NewAdminService(
entClient *dbent.Client, entClient *dbent.Client,
settingService *SettingService, settingService *SettingService,
defaultSubAssigner DefaultSubscriptionAssigner, defaultSubAssigner DefaultSubscriptionAssigner,
userSubRepo UserSubscriptionRepository,
) AdminService { ) AdminService {
return &adminServiceImpl{ return &adminServiceImpl{
userRepo: userRepo, userRepo: userRepo,
@@ -476,6 +478,7 @@ func NewAdminService(
entClient: entClient, entClient: entClient,
settingService: settingService, settingService: settingService,
defaultSubAssigner: defaultSubAssigner, defaultSubAssigner: defaultSubAssigner,
userSubRepo: userSubRepo,
} }
} }
@@ -1277,9 +1280,17 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
if group.Status != StatusActive { if group.Status != StatusActive {
return nil, infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active") return nil, infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active")
} }
// 订阅类型分组:不允许通过此 API 直接绑定,需通过订阅管理流程 // 订阅类型分组:用户须持有该分组的有效订阅才可绑定
if group.IsSubscriptionType() { if group.IsSubscriptionType() {
return nil, infraerrors.BadRequest("SUBSCRIPTION_GROUP_NOT_ALLOWED", "subscription groups must be managed through the subscription workflow") if s.userSubRepo == nil {
return nil, infraerrors.InternalServer("SUBSCRIPTION_REPOSITORY_UNAVAILABLE", "subscription repository is not configured")
}
if _, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, apiKey.UserID, *groupID); err != nil {
if errors.Is(err, ErrSubscriptionNotFound) {
return nil, infraerrors.BadRequest("SUBSCRIPTION_REQUIRED", "user does not have an active subscription for this group")
}
return nil, err
}
} }
gid := *groupID gid := *groupID
@@ -1287,7 +1298,7 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
apiKey.Group = group apiKey.Group = group
// 专属标准分组:使用事务保证「添加分组权限」与「更新 API Key」的原子性 // 专属标准分组:使用事务保证「添加分组权限」与「更新 API Key」的原子性
if group.IsExclusive { if group.IsExclusive && !group.IsSubscriptionType() {
opCtx := ctx opCtx := ctx
var tx *dbent.Tx var tx *dbent.Tx
if s.entClient == nil { if s.entClient == nil {

View File

@@ -32,28 +32,44 @@ func (s *userRepoStubForGroupUpdate) AddGroupToAllowedGroups(_ context.Context,
return s.addGroupErr return s.addGroupErr
} }
func (s *userRepoStubForGroupUpdate) Create(context.Context, *User) error { panic("unexpected") } func (s *userRepoStubForGroupUpdate) Create(context.Context, *User) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) GetByID(context.Context, int64) (*User, error) { panic("unexpected") } func (s *userRepoStubForGroupUpdate) GetByID(context.Context, int64) (*User, error) {
func (s *userRepoStubForGroupUpdate) GetByEmail(context.Context, string) (*User, error) { panic("unexpected") } panic("unexpected")
func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, error) { panic("unexpected") } }
func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") } func (s *userRepoStubForGroupUpdate) GetByEmail(context.Context, string) (*User, error) {
func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") } panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
panic("unexpected") panic("unexpected")
} }
func (s *userRepoStubForGroupUpdate) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) { func (s *userRepoStubForGroupUpdate) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
panic("unexpected") panic("unexpected")
} }
func (s *userRepoStubForGroupUpdate) UpdateBalance(context.Context, int64, float64) error { panic("unexpected") } func (s *userRepoStubForGroupUpdate) UpdateBalance(context.Context, int64, float64) error {
func (s *userRepoStubForGroupUpdate) DeductBalance(context.Context, int64, float64) error { panic("unexpected") } panic("unexpected")
func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error { panic("unexpected") } }
func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) { panic("unexpected") } func (s *userRepoStubForGroupUpdate) DeductBalance(context.Context, int64, float64) error {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { func (s *userRepoStubForGroupUpdate) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
panic("unexpected") panic("unexpected")
} }
func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error { panic("unexpected") } func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error {
func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") } panic("unexpected")
func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") } }
func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") }
// apiKeyRepoStubForGroupUpdate implements APIKeyRepository for AdminUpdateAPIKeyGroupID tests. // apiKeyRepoStubForGroupUpdate implements APIKeyRepository for AdminUpdateAPIKeyGroupID tests.
type apiKeyRepoStubForGroupUpdate struct { type apiKeyRepoStubForGroupUpdate struct {
@@ -194,6 +210,29 @@ func (s *groupRepoStubForGroupUpdate) UpdateSortOrders(context.Context, []GroupS
panic("unexpected") panic("unexpected")
} }
type userSubRepoStubForGroupUpdate struct {
userSubRepoNoop
getActiveSub *UserSubscription
getActiveErr error
called bool
calledUserID int64
calledGroupID int64
}
func (s *userSubRepoStubForGroupUpdate) GetActiveByUserIDAndGroupID(_ context.Context, userID, groupID int64) (*UserSubscription, error) {
s.called = true
s.calledUserID = userID
s.calledGroupID = groupID
if s.getActiveErr != nil {
return nil, s.getActiveErr
}
if s.getActiveSub == nil {
return nil, ErrSubscriptionNotFound
}
clone := *s.getActiveSub
return &clone, nil
}
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Tests // Tests
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
@@ -386,14 +425,49 @@ func TestAdminService_AdminUpdateAPIKeyGroupID_NonExclusiveGroup_NoAllowedGroupU
func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_Blocked(t *testing.T) { func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_Blocked(t *testing.T) {
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil} existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}} groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeSubscription}}
userRepo := &userRepoStubForGroupUpdate{}
userSubRepo := &userSubRepoStubForGroupUpdate{getActiveErr: ErrSubscriptionNotFound}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, userSubRepo: userSubRepo}
// 无有效订阅时应拒绝绑定
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.Error(t, err)
require.Equal(t, "SUBSCRIPTION_REQUIRED", infraerrors.Reason(err))
require.True(t, userSubRepo.called)
require.Equal(t, int64(42), userSubRepo.calledUserID)
require.Equal(t, int64(10), userSubRepo.calledGroupID)
require.False(t, userRepo.addGroupCalled)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_RequiresRepo(t *testing.T) {
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeSubscription}}
userRepo := &userRepoStubForGroupUpdate{} userRepo := &userRepoStubForGroupUpdate{}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo} svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo}
// 订阅类型分组应被阻止绑定
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10)) _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.Error(t, err) require.Error(t, err)
require.Equal(t, "SUBSCRIPTION_GROUP_NOT_ALLOWED", infraerrors.Reason(err)) require.Equal(t, "SUBSCRIPTION_REPOSITORY_UNAVAILABLE", infraerrors.Reason(err))
require.False(t, userRepo.addGroupCalled)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_AllowsActiveSubscription(t *testing.T) {
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}}
userRepo := &userRepoStubForGroupUpdate{}
userSubRepo := &userSubRepoStubForGroupUpdate{
getActiveSub: &UserSubscription{ID: 99, UserID: 42, GroupID: 10},
}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, userSubRepo: userSubRepo}
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.NoError(t, err)
require.True(t, userSubRepo.called)
require.NotNil(t, got.APIKey.GroupID)
require.Equal(t, int64(10), *got.APIKey.GroupID)
require.False(t, userRepo.addGroupCalled) require.False(t, userRepo.addGroupCalled)
} }

View File

@@ -12,6 +12,7 @@ import (
"strings" "strings"
"time" "time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
@@ -21,24 +22,25 @@ import (
) )
var ( var (
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password") ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active") ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists") ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved") ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token") ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired") ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
ErrAccessTokenExpired = infraerrors.Unauthorized("ACCESS_TOKEN_EXPIRED", "access token has expired") ErrAccessTokenExpired = infraerrors.Unauthorized("ACCESS_TOKEN_EXPIRED", "access token has expired")
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large") ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked") ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
ErrRefreshTokenInvalid = infraerrors.Unauthorized("REFRESH_TOKEN_INVALID", "invalid refresh token") ErrRefreshTokenInvalid = infraerrors.Unauthorized("REFRESH_TOKEN_INVALID", "invalid refresh token")
ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired") ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired")
ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused") ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused")
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required") ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
ErrEmailSuffixNotAllowed = infraerrors.BadRequest("EMAIL_SUFFIX_NOT_ALLOWED", "email suffix is not allowed") ErrEmailSuffixNotAllowed = infraerrors.BadRequest("EMAIL_SUFFIX_NOT_ALLOWED", "email suffix is not allowed")
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable") ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required") ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required")
ErrInvitationCodeInvalid = infraerrors.BadRequest("INVITATION_CODE_INVALID", "invalid or used invitation code") ErrInvitationCodeInvalid = infraerrors.BadRequest("INVITATION_CODE_INVALID", "invalid or used invitation code")
ErrOAuthInvitationRequired = infraerrors.Forbidden("OAUTH_INVITATION_REQUIRED", "invitation code required to complete oauth registration")
) )
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。 // maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
@@ -58,6 +60,7 @@ type JWTClaims struct {
// AuthService 认证服务 // AuthService 认证服务
type AuthService struct { type AuthService struct {
entClient *dbent.Client
userRepo UserRepository userRepo UserRepository
redeemRepo RedeemCodeRepository redeemRepo RedeemCodeRepository
refreshTokenCache RefreshTokenCache refreshTokenCache RefreshTokenCache
@@ -76,6 +79,7 @@ type DefaultSubscriptionAssigner interface {
// NewAuthService 创建认证服务实例 // NewAuthService 创建认证服务实例
func NewAuthService( func NewAuthService(
entClient *dbent.Client,
userRepo UserRepository, userRepo UserRepository,
redeemRepo RedeemCodeRepository, redeemRepo RedeemCodeRepository,
refreshTokenCache RefreshTokenCache, refreshTokenCache RefreshTokenCache,
@@ -88,6 +92,7 @@ func NewAuthService(
defaultSubAssigner DefaultSubscriptionAssigner, defaultSubAssigner DefaultSubscriptionAssigner,
) *AuthService { ) *AuthService {
return &AuthService{ return &AuthService{
entClient: entClient,
userRepo: userRepo, userRepo: userRepo,
redeemRepo: redeemRepo, redeemRepo: redeemRepo,
refreshTokenCache: refreshTokenCache, refreshTokenCache: refreshTokenCache,
@@ -523,9 +528,10 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
return token, user, nil return token, user, nil
} }
// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair // LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair
// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token // 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username string) (*TokenPair, *User, error) { // invitationCode 仅在邀请码注册模式下新用户注册时使用;已有账号登录时忽略。
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode string) (*TokenPair, *User, error) {
// 检查 refreshTokenCache 是否可用 // 检查 refreshTokenCache 是否可用
if s.refreshTokenCache == nil { if s.refreshTokenCache == nil {
return nil, nil, errors.New("refresh token cache not configured") return nil, nil, errors.New("refresh token cache not configured")
@@ -552,6 +558,22 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return nil, nil, ErrRegDisabled return nil, nil, ErrRegDisabled
} }
// 检查是否需要邀请码
var invitationRedeemCode *RedeemCode
if s.settingService != nil && s.settingService.IsInvitationCodeEnabled(ctx) {
if invitationCode == "" {
return nil, nil, ErrOAuthInvitationRequired
}
redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
if err != nil {
return nil, nil, ErrInvitationCodeInvalid
}
if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
return nil, nil, ErrInvitationCodeInvalid
}
invitationRedeemCode = redeemCode
}
randomPassword, err := randomHexString(32) randomPassword, err := randomHexString(32)
if err != nil { if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to generate random password for oauth signup: %v", err) logger.LegacyPrintf("service.auth", "[Auth] Failed to generate random password for oauth signup: %v", err)
@@ -579,20 +601,58 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
Status: StatusActive, Status: StatusActive,
} }
if err := s.userRepo.Create(ctx, newUser); err != nil { if s.entClient != nil && invitationRedeemCode != nil {
if errors.Is(err, ErrEmailExists) { tx, err := s.entClient.Tx(ctx)
user, err = s.userRepo.GetByEmail(ctx, email) if err != nil {
if err != nil { logger.LegacyPrintf("service.auth", "[Auth] Failed to begin transaction for oauth registration: %v", err)
logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err) return nil, nil, ErrServiceUnavailable
}
defer func() { _ = tx.Rollback() }()
txCtx := dbent.NewTxContext(ctx, tx)
if err := s.userRepo.Create(txCtx, newUser); err != nil {
if errors.Is(err, ErrEmailExists) {
user, err = s.userRepo.GetByEmail(ctx, email)
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err)
return nil, nil, ErrServiceUnavailable
}
} else {
logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err)
return nil, nil, ErrServiceUnavailable return nil, nil, ErrServiceUnavailable
} }
} else { } else {
logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err) if err := s.redeemRepo.Use(txCtx, invitationRedeemCode.ID, newUser.ID); err != nil {
return nil, nil, ErrServiceUnavailable return nil, nil, ErrInvitationCodeInvalid
}
if err := tx.Commit(); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to commit oauth registration transaction: %v", err)
return nil, nil, ErrServiceUnavailable
}
user = newUser
s.assignDefaultSubscriptions(ctx, user.ID)
} }
} else { } else {
user = newUser if err := s.userRepo.Create(ctx, newUser); err != nil {
s.assignDefaultSubscriptions(ctx, user.ID) if errors.Is(err, ErrEmailExists) {
user, err = s.userRepo.GetByEmail(ctx, email)
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err)
return nil, nil, ErrServiceUnavailable
}
} else {
logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err)
return nil, nil, ErrServiceUnavailable
}
} else {
user = newUser
s.assignDefaultSubscriptions(ctx, user.ID)
if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
return nil, nil, ErrInvitationCodeInvalid
}
}
}
} }
} else { } else {
logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err) logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err)
@@ -618,6 +678,63 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return tokenPair, user, nil return tokenPair, user, nil
} }
// pendingOAuthTokenTTL is the validity period for pending OAuth tokens.
const pendingOAuthTokenTTL = 10 * time.Minute
// pendingOAuthPurpose is the purpose claim value for pending OAuth registration tokens.
const pendingOAuthPurpose = "pending_oauth_registration"
type pendingOAuthClaims struct {
Email string `json:"email"`
Username string `json:"username"`
Purpose string `json:"purpose"`
jwt.RegisteredClaims
}
// CreatePendingOAuthToken generates a short-lived JWT that carries the OAuth identity
// while waiting for the user to supply an invitation code.
func (s *AuthService) CreatePendingOAuthToken(email, username string) (string, error) {
now := time.Now()
claims := &pendingOAuthClaims{
Email: email,
Username: username,
Purpose: pendingOAuthPurpose,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(pendingOAuthTokenTTL)),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString([]byte(s.cfg.JWT.Secret))
}
// VerifyPendingOAuthToken validates a pending OAuth token and returns the embedded identity.
// Returns ErrInvalidToken when the token is invalid or expired.
func (s *AuthService) VerifyPendingOAuthToken(tokenStr string) (email, username string, err error) {
if len(tokenStr) > maxTokenLength {
return "", "", ErrInvalidToken
}
parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name}))
token, parseErr := parser.ParseWithClaims(tokenStr, &pendingOAuthClaims{}, func(t *jwt.Token) (any, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
}
return []byte(s.cfg.JWT.Secret), nil
})
if parseErr != nil {
return "", "", ErrInvalidToken
}
claims, ok := token.Claims.(*pendingOAuthClaims)
if !ok || !token.Valid {
return "", "", ErrInvalidToken
}
if claims.Purpose != pendingOAuthPurpose {
return "", "", ErrInvalidToken
}
return claims.Email, claims.Username, nil
}
func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) { func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) {
if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 { if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
return return

View File

@@ -0,0 +1,146 @@
//go:build unit
package service
import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/require"
)
func newAuthServiceForPendingOAuthTest() *AuthService {
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret-pending-oauth",
ExpireHour: 1,
},
}
return NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
}
// TestVerifyPendingOAuthToken_ValidToken 验证正常签发的 pending token 可以被成功解析。
func TestVerifyPendingOAuthToken_ValidToken(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
token, err := svc.CreatePendingOAuthToken("user@example.com", "alice")
require.NoError(t, err)
require.NotEmpty(t, token)
email, username, err := svc.VerifyPendingOAuthToken(token)
require.NoError(t, err)
require.Equal(t, "user@example.com", email)
require.Equal(t, "alice", username)
}
// TestVerifyPendingOAuthToken_RegularJWTRejected 用普通 access token 尝试验证,应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_RegularJWTRejected(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
// 签发一个普通 access tokenJWTClaims无 Purpose 字段)
accessToken, err := svc.GenerateToken(&User{
ID: 1,
Email: "user@example.com",
Role: RoleUser,
})
require.NoError(t, err)
_, _, err = svc.VerifyPendingOAuthToken(accessToken)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_WrongPurpose 手动构造 purpose 字段不匹配的 JWT应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_WrongPurpose(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
now := time.Now()
claims := &pendingOAuthClaims{
Email: "user@example.com",
Username: "alice",
Purpose: "some_other_purpose",
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
},
}
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
require.NoError(t, err)
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_MissingPurpose 手动构造无 purpose 字段的 JWT模拟旧 token应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_MissingPurpose(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
now := time.Now()
claims := &pendingOAuthClaims{
Email: "user@example.com",
Username: "alice",
Purpose: "", // 旧 token 无此字段,反序列化后为零值
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
},
}
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
require.NoError(t, err)
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_ExpiredToken 过期 token 应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_ExpiredToken(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
past := time.Now().Add(-1 * time.Hour)
claims := &pendingOAuthClaims{
Email: "user@example.com",
Username: "alice",
Purpose: pendingOAuthPurpose,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(past),
IssuedAt: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
NotBefore: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
},
}
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
require.NoError(t, err)
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_WrongSecret 不同密钥签发的 token 应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_WrongSecret(t *testing.T) {
other := NewAuthService(nil, nil, nil, nil, &config.Config{
JWT: config.JWTConfig{Secret: "other-secret"},
}, nil, nil, nil, nil, nil, nil)
token, err := other.CreatePendingOAuthToken("user@example.com", "alice")
require.NoError(t, err)
svc := newAuthServiceForPendingOAuthTest()
_, _, err = svc.VerifyPendingOAuthToken(token)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_TooLong 超长 token 应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_TooLong(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
giant := make([]byte, maxTokenLength+1)
for i := range giant {
giant[i] = 'a'
}
_, _, err := svc.VerifyPendingOAuthToken(string(giant))
require.ErrorIs(t, err, ErrInvalidToken)
}

View File

@@ -130,6 +130,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
} }
return NewAuthService( return NewAuthService(
nil, // entClient
repo, repo,
nil, // redeemRepo nil, // redeemRepo
nil, // refreshTokenCache nil, // refreshTokenCache

View File

@@ -43,6 +43,7 @@ func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier
turnstileService := NewTurnstileService(settingService, verifier) turnstileService := NewTurnstileService(settingService, verifier)
return NewAuthService( return NewAuthService(
nil, // entClient
&userRepoStub{}, &userRepoStub{},
nil, // redeemRepo nil, // redeemRepo
nil, // refreshTokenCache nil, // refreshTokenCache

View File

@@ -43,6 +43,9 @@ type ConcurrencyCache interface {
// 清理过期槽位(后台任务) // 清理过期槽位(后台任务)
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
// 启动时清理旧进程遗留槽位与等待计数
CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error
} }
var ( var (
@@ -59,13 +62,22 @@ func initRequestIDPrefix() string {
return "r" + strconv.FormatUint(fallback, 36) return "r" + strconv.FormatUint(fallback, 36)
} }
// generateRequestID generates a unique request ID for concurrency slot tracking. func RequestIDPrefix() string {
// Format: {process_random_prefix}-{base36_counter} return requestIDPrefix
}
func generateRequestID() string { func generateRequestID() string {
seq := requestIDCounter.Add(1) seq := requestIDCounter.Add(1)
return requestIDPrefix + "-" + strconv.FormatUint(seq, 36) return requestIDPrefix + "-" + strconv.FormatUint(seq, 36)
} }
func (s *ConcurrencyService) CleanupStaleProcessSlots(ctx context.Context) error {
if s == nil || s.cache == nil {
return nil
}
return s.cache.CleanupStaleProcessSlots(ctx, RequestIDPrefix())
}
const ( const (
// Default extra wait slots beyond concurrency limit // Default extra wait slots beyond concurrency limit
defaultExtraWaitSlots = 20 defaultExtraWaitSlots = 20

View File

@@ -91,6 +91,32 @@ func (c *stubConcurrencyCacheForTest) CleanupExpiredAccountSlots(_ context.Conte
return c.cleanupErr return c.cleanupErr
} }
func (c *stubConcurrencyCacheForTest) CleanupStaleProcessSlots(_ context.Context, _ string) error {
return c.cleanupErr
}
type trackingConcurrencyCache struct {
stubConcurrencyCacheForTest
cleanupPrefix string
}
func (c *trackingConcurrencyCache) CleanupStaleProcessSlots(_ context.Context, prefix string) error {
c.cleanupPrefix = prefix
return c.cleanupErr
}
func TestCleanupStaleProcessSlots_NilCache(t *testing.T) {
svc := &ConcurrencyService{cache: nil}
require.NoError(t, svc.CleanupStaleProcessSlots(context.Background()))
}
func TestCleanupStaleProcessSlots_DelegatesPrefix(t *testing.T) {
cache := &trackingConcurrencyCache{}
svc := NewConcurrencyService(cache)
require.NoError(t, svc.CleanupStaleProcessSlots(context.Background()))
require.Equal(t, RequestIDPrefix(), cache.cleanupPrefix)
}
func TestAcquireAccountSlot_Success(t *testing.T) { func TestAcquireAccountSlot_Success(t *testing.T) {
cache := &stubConcurrencyCacheForTest{acquireResult: true} cache := &stubConcurrencyCacheForTest{acquireResult: true}
svc := NewConcurrencyService(cache) svc := NewConcurrencyService(cache)

View File

@@ -182,6 +182,13 @@ const (
// SettingKeyRectifierSettings stores JSON config for rectifier settings (thinking signature + budget). // SettingKeyRectifierSettings stores JSON config for rectifier settings (thinking signature + budget).
SettingKeyRectifierSettings = "rectifier_settings" SettingKeyRectifierSettings = "rectifier_settings"
// =========================
// Beta Policy Settings
// =========================
// SettingKeyBetaPolicySettings stores JSON config for beta policy rules.
SettingKeyBetaPolicySettings = "beta_policy_settings"
// ========================= // =========================
// Sora S3 存储配置 // Sora S3 存储配置
// ========================= // =========================

View File

@@ -86,10 +86,10 @@ func TestStripBetaTokens(t *testing.T) {
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
}, },
{ {
name: "DroppedBetas removes fast-mode only", name: "DroppedBetas is empty (filtering moved to configurable beta policy)",
header: "oauth-2025-04-20,context-1m-2025-08-07,fast-mode-2026-02-01,interleaved-thinking-2025-05-14", header: "oauth-2025-04-20,context-1m-2025-08-07,fast-mode-2026-02-01,interleaved-thinking-2025-05-14",
tokens: claude.DroppedBetas, tokens: claude.DroppedBetas,
want: "oauth-2025-04-20,context-1m-2025-08-07,interleaved-thinking-2025-05-14", want: "oauth-2025-04-20,context-1m-2025-08-07,fast-mode-2026-02-01,interleaved-thinking-2025-05-14",
}, },
} }
@@ -114,25 +114,23 @@ func TestMergeAnthropicBetaDropping_Context1M(t *testing.T) {
func TestMergeAnthropicBetaDropping_DroppedBetas(t *testing.T) { func TestMergeAnthropicBetaDropping_DroppedBetas(t *testing.T) {
required := []string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"} required := []string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"}
incoming := "context-1m-2025-08-07,fast-mode-2026-02-01,foo-beta,oauth-2025-04-20" incoming := "context-1m-2025-08-07,fast-mode-2026-02-01,foo-beta,oauth-2025-04-20"
// DroppedBetas is now empty — filtering moved to configurable beta policy.
// Without a policy filter set, nothing gets dropped from the static set.
drop := droppedBetaSet() drop := droppedBetaSet()
got := mergeAnthropicBetaDropping(required, incoming, drop) got := mergeAnthropicBetaDropping(required, incoming, drop)
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,context-1m-2025-08-07,foo-beta", got) require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,context-1m-2025-08-07,fast-mode-2026-02-01,foo-beta", got)
require.Contains(t, got, "context-1m-2025-08-07") require.Contains(t, got, "context-1m-2025-08-07")
require.NotContains(t, got, "fast-mode-2026-02-01") require.Contains(t, got, "fast-mode-2026-02-01")
} }
func TestDroppedBetaSet(t *testing.T) { func TestDroppedBetaSet(t *testing.T) {
// Base set contains DroppedBetas // Base set contains DroppedBetas (now empty — filtering moved to configurable beta policy)
base := droppedBetaSet() base := droppedBetaSet()
require.NotContains(t, base, claude.BetaContext1M)
require.Contains(t, base, claude.BetaFastMode)
require.Len(t, base, len(claude.DroppedBetas)) require.Len(t, base, len(claude.DroppedBetas))
// With extra tokens // With extra tokens
extended := droppedBetaSet(claude.BetaClaudeCode) extended := droppedBetaSet(claude.BetaClaudeCode)
require.NotContains(t, extended, claude.BetaContext1M)
require.Contains(t, extended, claude.BetaFastMode)
require.Contains(t, extended, claude.BetaClaudeCode) require.Contains(t, extended, claude.BetaClaudeCode)
require.Len(t, extended, len(claude.DroppedBetas)+1) require.Len(t, extended, len(claude.DroppedBetas)+1)
} }

View File

@@ -1986,6 +1986,10 @@ func (m *mockConcurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, a
return nil return nil
} }
func (m *mockConcurrencyCache) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
return nil
}
func (m *mockConcurrencyCache) GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error) { func (m *mockConcurrencyCache) GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error) {
result := make(map[int64]*UserLoadInfo, len(users)) result := make(map[int64]*UserLoadInfo, len(users))
for _, user := range users { for _, user := range users {

View File

@@ -3948,6 +3948,20 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody, passthroughModel, parsed.Stream, startTime) return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody, passthroughModel, parsed.Stream, startTime)
} }
// Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest.
// Always overwrite the cache to prevent stale values from a previous retry with a different account.
if account.Platform == PlatformAnthropic && c != nil {
policy := s.evaluateBetaPolicy(ctx, c.GetHeader("anthropic-beta"), account)
if policy.blockErr != nil {
return nil, policy.blockErr
}
filterSet := policy.filterSet
if filterSet == nil {
filterSet = map[string]struct{}{}
}
c.Set(betaPolicyFilterSetKey, filterSet)
}
body := parsed.Body body := parsed.Body
reqModel := parsed.Model reqModel := parsed.Model
reqStream := parsed.Stream reqStream := parsed.Stream
@@ -5133,6 +5147,11 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
applyClaudeOAuthHeaderDefaults(req, reqStream) applyClaudeOAuthHeaderDefaults(req, reqStream)
} }
// Build effective drop set: merge static defaults with dynamic beta policy filter rules
policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account)
effectiveDropSet := mergeDropSets(policyFilterSet)
effectiveDropWithClaudeCodeSet := mergeDropSets(policyFilterSet, claude.BetaClaudeCode)
// 处理 anthropic-beta headerOAuth 账号需要包含 oauth beta // 处理 anthropic-beta headerOAuth 账号需要包含 oauth beta
if tokenType == "oauth" { if tokenType == "oauth" {
if mimicClaudeCode { if mimicClaudeCode {
@@ -5146,17 +5165,22 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// messages requests typically use only oauth + interleaved-thinking. // messages requests typically use only oauth + interleaved-thinking.
// Also drop claude-code beta if a downstream client added it. // Also drop claude-code beta if a downstream client added it.
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking} requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, droppedBetasWithClaudeCodeSet)) req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropWithClaudeCodeSet))
} else { } else {
// Claude Code 客户端:尽量透传原始 header仅补齐 oauth beta // Claude Code 客户端:尽量透传原始 header仅补齐 oauth beta
clientBetaHeader := req.Header.Get("anthropic-beta") clientBetaHeader := req.Header.Get("anthropic-beta")
req.Header.Set("anthropic-beta", stripBetaTokensWithSet(s.getBetaHeader(modelID, clientBetaHeader), defaultDroppedBetasSet)) req.Header.Set("anthropic-beta", stripBetaTokensWithSet(s.getBetaHeader(modelID, clientBetaHeader), effectiveDropSet))
} }
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { } else {
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭) // API-key accounts: apply beta policy filter to strip controlled tokens
if requestNeedsBetaFeatures(body) { if existingBeta := req.Header.Get("anthropic-beta"); existingBeta != "" {
if beta := defaultAPIKeyBetaHeader(body); beta != "" { req.Header.Set("anthropic-beta", stripBetaTokensWithSet(existingBeta, effectiveDropSet))
req.Header.Set("anthropic-beta", beta) } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey {
// API-key仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
} }
} }
} }
@@ -5334,6 +5358,104 @@ func stripBetaTokensWithSet(header string, drop map[string]struct{}) string {
return strings.Join(out, ",") return strings.Join(out, ",")
} }
// BetaBlockedError indicates a request was blocked by a beta policy rule.
type BetaBlockedError struct {
Message string
}
func (e *BetaBlockedError) Error() string { return e.Message }
// betaPolicyResult holds the evaluated result of beta policy rules for a single request.
type betaPolicyResult struct {
blockErr *BetaBlockedError // non-nil if a block rule matched
filterSet map[string]struct{} // tokens to filter (may be nil)
}
// evaluateBetaPolicy loads settings once and evaluates all rules against the given request.
func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader string, account *Account) betaPolicyResult {
if s.settingService == nil {
return betaPolicyResult{}
}
settings, err := s.settingService.GetBetaPolicySettings(ctx)
if err != nil || settings == nil {
return betaPolicyResult{}
}
isOAuth := account.IsOAuth()
var result betaPolicyResult
for _, rule := range settings.Rules {
if !betaPolicyScopeMatches(rule.Scope, isOAuth) {
continue
}
switch rule.Action {
case BetaPolicyActionBlock:
if result.blockErr == nil && betaHeader != "" && containsBetaToken(betaHeader, rule.BetaToken) {
msg := rule.ErrorMessage
if msg == "" {
msg = "beta feature " + rule.BetaToken + " is not allowed"
}
result.blockErr = &BetaBlockedError{Message: msg}
}
case BetaPolicyActionFilter:
if result.filterSet == nil {
result.filterSet = make(map[string]struct{})
}
result.filterSet[rule.BetaToken] = struct{}{}
}
}
return result
}
// mergeDropSets merges the static defaultDroppedBetasSet with dynamic policy filter tokens.
// Returns defaultDroppedBetasSet directly when policySet is empty (zero allocation).
func mergeDropSets(policySet map[string]struct{}, extra ...string) map[string]struct{} {
if len(policySet) == 0 && len(extra) == 0 {
return defaultDroppedBetasSet
}
m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(policySet)+len(extra))
for t := range defaultDroppedBetasSet {
m[t] = struct{}{}
}
for t := range policySet {
m[t] = struct{}{}
}
for _, t := range extra {
m[t] = struct{}{}
}
return m
}
// betaPolicyFilterSetKey is the gin.Context key for caching the policy filter set within a request.
const betaPolicyFilterSetKey = "betaPolicyFilterSet"
// getBetaPolicyFilterSet returns the beta policy filter set, using the gin context cache if available.
// In the /v1/messages path, Forward() evaluates the policy first and caches the result;
// buildUpstreamRequest reuses it (zero extra DB calls). In the count_tokens path, this
// evaluates on demand (one DB call).
func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Context, account *Account) map[string]struct{} {
if c != nil {
if v, ok := c.Get(betaPolicyFilterSetKey); ok {
if fs, ok := v.(map[string]struct{}); ok {
return fs
}
}
}
return s.evaluateBetaPolicy(ctx, "", account).filterSet
}
// betaPolicyScopeMatches checks whether a rule's scope matches the current account type.
func betaPolicyScopeMatches(scope string, isOAuth bool) bool {
switch scope {
case BetaPolicyScopeAll:
return true
case BetaPolicyScopeOAuth:
return isOAuth
case BetaPolicyScopeAPIKey:
return !isOAuth
default:
return true // unknown scope → match all (fail-open)
}
}
// droppedBetaSet returns claude.DroppedBetas as a set, with optional extra tokens. // droppedBetaSet returns claude.DroppedBetas as a set, with optional extra tokens.
func droppedBetaSet(extra ...string) map[string]struct{} { func droppedBetaSet(extra ...string) map[string]struct{} {
m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(extra)) m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(extra))
@@ -5370,10 +5492,7 @@ func buildBetaTokenSet(tokens []string) map[string]struct{} {
return m return m
} }
var ( var defaultDroppedBetasSet = buildBetaTokenSet(claude.DroppedBetas)
defaultDroppedBetasSet = buildBetaTokenSet(claude.DroppedBetas)
droppedBetasWithClaudeCodeSet = droppedBetaSet(claude.BetaClaudeCode)
)
// applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers. // applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers.
// This mirrors opencode-anthropic-auth behavior: do not trust downstream // This mirrors opencode-anthropic-auth behavior: do not trust downstream
@@ -7311,6 +7430,9 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
applyClaudeOAuthHeaderDefaults(req, false) applyClaudeOAuthHeaderDefaults(req, false)
} }
// Build effective drop set for count_tokens: merge static defaults with dynamic beta policy filter rules
ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account))
// OAuth 账号:处理 anthropic-beta header // OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" { if tokenType == "oauth" {
if mimicClaudeCode { if mimicClaudeCode {
@@ -7318,8 +7440,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
incomingBeta := req.Header.Get("anthropic-beta") incomingBeta := req.Header.Get("anthropic-beta")
requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting} requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting}
drop := droppedBetaSet() req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, ctEffectiveDropSet))
req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop))
} else { } else {
clientBetaHeader := req.Header.Get("anthropic-beta") clientBetaHeader := req.Header.Get("anthropic-beta")
if clientBetaHeader == "" { if clientBetaHeader == "" {
@@ -7329,14 +7450,19 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
if !strings.Contains(beta, claude.BetaTokenCounting) { if !strings.Contains(beta, claude.BetaTokenCounting) {
beta = beta + "," + claude.BetaTokenCounting beta = beta + "," + claude.BetaTokenCounting
} }
req.Header.Set("anthropic-beta", stripBetaTokensWithSet(beta, defaultDroppedBetasSet)) req.Header.Set("anthropic-beta", stripBetaTokensWithSet(beta, ctEffectiveDropSet))
} }
} }
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { } else {
// API-key:与 messages 同步的按需 beta 注入(默认关闭) // API-key accounts: apply beta policy filter to strip controlled tokens
if requestNeedsBetaFeatures(body) { if existingBeta := req.Header.Get("anthropic-beta"); existingBeta != "" {
if beta := defaultAPIKeyBetaHeader(body); beta != "" { req.Header.Set("anthropic-beta", stripBetaTokensWithSet(existingBeta, ctEffectiveDropSet))
req.Header.Set("anthropic-beta", beta) } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey {
// API-key与 messages 同步的按需 beta 注入(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
} }
} }
} }

View File

@@ -2,6 +2,7 @@ package service
import ( import (
"bufio" "bufio"
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
@@ -140,12 +141,13 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
// 8. Handle error response with failover // 8. Handle error response with failover
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
if s.shouldFailoverUpstreamError(resp.StatusCode) { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) _ = resp.Body.Close()
_ = resp.Body.Close() resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
upstreamDetail := "" upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
@@ -167,7 +169,11 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
if s.rateLimitService != nil { if s.rateLimitService != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
} }
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody),
}
} }
// Non-failover error: return Anthropic-formatted error to client // Non-failover error: return Anthropic-formatted error to client
return s.handleAnthropicErrorResponse(resp, c, account) return s.handleAnthropicErrorResponse(resp, c, account)
@@ -279,7 +285,11 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
requestID := resp.Header.Get("x-request-id") requestID := resp.Header.Get("x-request-id")
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
var finalResponse *apicompat.ResponsesResponse var finalResponse *apicompat.ResponsesResponse
var usage OpenAIUsage var usage OpenAIUsage
@@ -378,7 +388,11 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
firstChunk := true firstChunk := true
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
// resultWithUsage builds the final result snapshot. // resultWithUsage builds the final result snapshot.
resultWithUsage := func() *OpenAIForwardResult { resultWithUsage := func() *OpenAIForwardResult {

View File

@@ -911,6 +911,36 @@ func isOpenAIInstructionsRequiredError(upstreamStatusCode int, upstreamMsg strin
return false return false
} }
func isOpenAITransientProcessingError(upstreamStatusCode int, upstreamMsg string, upstreamBody []byte) bool {
if upstreamStatusCode != http.StatusBadRequest {
return false
}
match := func(text string) bool {
lower := strings.ToLower(strings.TrimSpace(text))
if lower == "" {
return false
}
if strings.Contains(lower, "an error occurred while processing your request") {
return true
}
return strings.Contains(lower, "you can retry your request") &&
strings.Contains(lower, "help.openai.com") &&
strings.Contains(lower, "request id")
}
if match(upstreamMsg) {
return true
}
if len(upstreamBody) == 0 {
return false
}
if match(gjson.GetBytes(upstreamBody, "error.message").String()) {
return true
}
return match(string(upstreamBody))
}
// ExtractSessionID extracts the raw session ID from headers or body without hashing. // ExtractSessionID extracts the raw session ID from headers or body without hashing.
// Used by ForwardAsAnthropic to pass as prompt_cache_key for upstream cache. // Used by ForwardAsAnthropic to pass as prompt_cache_key for upstream cache.
func (s *OpenAIGatewayService) ExtractSessionID(c *gin.Context, body []byte) string { func (s *OpenAIGatewayService) ExtractSessionID(c *gin.Context, body []byte) string {
@@ -1518,6 +1548,13 @@ func (s *OpenAIGatewayService) shouldFailoverUpstreamError(statusCode int) bool
} }
} }
func (s *OpenAIGatewayService) shouldFailoverOpenAIUpstreamResponse(statusCode int, upstreamMsg string, upstreamBody []byte) bool {
if s.shouldFailoverUpstreamError(statusCode) {
return true
}
return isOpenAITransientProcessingError(statusCode, upstreamMsg, upstreamBody)
}
func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) { func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
@@ -2016,13 +2053,13 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
// Handle error response // Handle error response
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
if s.shouldFailoverUpstreamError(resp.StatusCode) { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) _ = resp.Body.Close()
_ = resp.Body.Close() resp.Body = io.NopCloser(bytes.NewReader(respBody))
resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
upstreamDetail := "" upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
@@ -2046,7 +2083,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
return nil, &UpstreamFailoverError{ return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
ResponseBody: respBody, ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
} }
} }
return s.handleErrorResponse(ctx, resp, c, account, body) return s.handleErrorResponse(ctx, resp, c, account, body)

View File

@@ -211,6 +211,26 @@ func TestLogOpenAIInstructionsRequiredDebug_NonTargetErrorSkipped(t *testing.T)
require.False(t, logSink.ContainsMessage("OpenAI 上游返回 Instructions are required已记录请求详情用于排查")) require.False(t, logSink.ContainsMessage("OpenAI 上游返回 Instructions are required已记录请求详情用于排查"))
} }
func TestIsOpenAITransientProcessingError(t *testing.T) {
require.True(t, isOpenAITransientProcessingError(
http.StatusBadRequest,
"An error occurred while processing your request.",
nil,
))
require.True(t, isOpenAITransientProcessingError(
http.StatusBadRequest,
"",
[]byte(`{"error":{"message":"An error occurred while processing your request. You can retry your request, or contact us through our help center at help.openai.com if the error persists. Please include the request ID req_123 in your message."}}`),
))
require.False(t, isOpenAITransientProcessingError(
http.StatusBadRequest,
"Missing required parameter: 'instructions'",
[]byte(`{"error":{"message":"Missing required parameter: 'instructions'"}}`),
))
}
func TestOpenAIGatewayService_Forward_LogsInstructionsRequiredDetails(t *testing.T) { func TestOpenAIGatewayService_Forward_LogsInstructionsRequiredDetails(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
logSink, restore := captureStructuredLog(t) logSink, restore := captureStructuredLog(t)
@@ -264,3 +284,51 @@ func TestOpenAIGatewayService_Forward_LogsInstructionsRequiredDetails(t *testing
require.True(t, logSink.ContainsField("request_body_size")) require.True(t, logSink.ContainsField("request_body_size"))
require.False(t, logSink.ContainsField("request_body_preview")) require.False(t, logSink.ContainsField("request_body_preview"))
} }
func TestOpenAIGatewayService_Forward_TransientProcessingErrorTriggersFailover(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
c.Request.Header.Set("Content-Type", "application/json")
upstream := &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusBadRequest,
Header: http.Header{
"Content-Type": []string{"application/json"},
"x-request-id": []string{"rid-processing-400"},
},
Body: io.NopCloser(strings.NewReader(`{"error":{"message":"An error occurred while processing your request. You can retry your request, or contact us through our help center at help.openai.com if the error persists. Please include the request ID req_123 in your message.","type":"invalid_request_error"}}`)),
},
}
svc := &OpenAIGatewayService{
cfg: &config.Config{
Gateway: config.GatewayConfig{ForceCodexCLI: false},
},
httpUpstream: upstream,
}
account := &Account{
ID: 1001,
Name: "codex max套餐",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{"api_key": "sk-test"},
Status: StatusActive,
Schedulable: true,
RateMultiplier: f64p(1),
}
body := []byte(`{"model":"gpt-5.1-codex","stream":false,"input":[{"type":"text","text":"hello"}]}`)
_, err := svc.Forward(context.Background(), c, account, body)
require.Error(t, err)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr)
require.Equal(t, http.StatusBadRequest, failoverErr.StatusCode)
require.Contains(t, string(failoverErr.ResponseBody), "An error occurred while processing your request")
require.False(t, c.Writer.Written(), "service 层应返回 failover 错误给上层换号,而不是直接向客户端写响应")
}

View File

@@ -130,6 +130,7 @@ type OpenAITokenInfo struct {
ChatGPTAccountID string `json:"chatgpt_account_id,omitempty"` ChatGPTAccountID string `json:"chatgpt_account_id,omitempty"`
ChatGPTUserID string `json:"chatgpt_user_id,omitempty"` ChatGPTUserID string `json:"chatgpt_user_id,omitempty"`
OrganizationID string `json:"organization_id,omitempty"` OrganizationID string `json:"organization_id,omitempty"`
PlanType string `json:"plan_type,omitempty"`
} }
// ExchangeCode exchanges authorization code for tokens // ExchangeCode exchanges authorization code for tokens
@@ -202,6 +203,7 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
tokenInfo.ChatGPTAccountID = userInfo.ChatGPTAccountID tokenInfo.ChatGPTAccountID = userInfo.ChatGPTAccountID
tokenInfo.ChatGPTUserID = userInfo.ChatGPTUserID tokenInfo.ChatGPTUserID = userInfo.ChatGPTUserID
tokenInfo.OrganizationID = userInfo.OrganizationID tokenInfo.OrganizationID = userInfo.OrganizationID
tokenInfo.PlanType = userInfo.PlanType
} }
return tokenInfo, nil return tokenInfo, nil
@@ -246,6 +248,7 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre
tokenInfo.ChatGPTAccountID = userInfo.ChatGPTAccountID tokenInfo.ChatGPTAccountID = userInfo.ChatGPTAccountID
tokenInfo.ChatGPTUserID = userInfo.ChatGPTUserID tokenInfo.ChatGPTUserID = userInfo.ChatGPTUserID
tokenInfo.OrganizationID = userInfo.OrganizationID tokenInfo.OrganizationID = userInfo.OrganizationID
tokenInfo.PlanType = userInfo.PlanType
} }
return tokenInfo, nil return tokenInfo, nil
@@ -510,6 +513,9 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo)
if tokenInfo.OrganizationID != "" { if tokenInfo.OrganizationID != "" {
creds["organization_id"] = tokenInfo.OrganizationID creds["organization_id"] = tokenInfo.OrganizationID
} }
if tokenInfo.PlanType != "" {
creds["plan_type"] = tokenInfo.PlanType
}
if strings.TrimSpace(tokenInfo.ClientID) != "" { if strings.TrimSpace(tokenInfo.ClientID) != "" {
creds["client_id"] = strings.TrimSpace(tokenInfo.ClientID) creds["client_id"] = strings.TrimSpace(tokenInfo.ClientID)
} }

View File

@@ -1247,6 +1247,60 @@ func (s *SettingService) IsBudgetRectifierEnabled(ctx context.Context) bool {
return settings.Enabled && settings.ThinkingBudgetEnabled return settings.Enabled && settings.ThinkingBudgetEnabled
} }
// GetBetaPolicySettings 获取 Beta 策略配置
func (s *SettingService) GetBetaPolicySettings(ctx context.Context) (*BetaPolicySettings, error) {
value, err := s.settingRepo.GetValue(ctx, SettingKeyBetaPolicySettings)
if err != nil {
if errors.Is(err, ErrSettingNotFound) {
return DefaultBetaPolicySettings(), nil
}
return nil, fmt.Errorf("get beta policy settings: %w", err)
}
if value == "" {
return DefaultBetaPolicySettings(), nil
}
var settings BetaPolicySettings
if err := json.Unmarshal([]byte(value), &settings); err != nil {
return DefaultBetaPolicySettings(), nil
}
return &settings, nil
}
// SetBetaPolicySettings 设置 Beta 策略配置
func (s *SettingService) SetBetaPolicySettings(ctx context.Context, settings *BetaPolicySettings) error {
if settings == nil {
return fmt.Errorf("settings cannot be nil")
}
validActions := map[string]bool{
BetaPolicyActionPass: true, BetaPolicyActionFilter: true, BetaPolicyActionBlock: true,
}
validScopes := map[string]bool{
BetaPolicyScopeAll: true, BetaPolicyScopeOAuth: true, BetaPolicyScopeAPIKey: true,
}
for i, rule := range settings.Rules {
if rule.BetaToken == "" {
return fmt.Errorf("rule[%d]: beta_token cannot be empty", i)
}
if !validActions[rule.Action] {
return fmt.Errorf("rule[%d]: invalid action %q", i, rule.Action)
}
if !validScopes[rule.Scope] {
return fmt.Errorf("rule[%d]: invalid scope %q", i, rule.Scope)
}
}
data, err := json.Marshal(settings)
if err != nil {
return fmt.Errorf("marshal beta policy settings: %w", err)
}
return s.settingRepo.Set(ctx, SettingKeyBetaPolicySettings, string(data))
}
// SetStreamTimeoutSettings 设置流超时处理配置 // SetStreamTimeoutSettings 设置流超时处理配置
func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings *StreamTimeoutSettings) error { func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings *StreamTimeoutSettings) error {
if settings == nil { if settings == nil {

View File

@@ -191,3 +191,45 @@ func DefaultRectifierSettings() *RectifierSettings {
ThinkingBudgetEnabled: true, ThinkingBudgetEnabled: true,
} }
} }
// Beta Policy 策略常量
const (
BetaPolicyActionPass = "pass" // 透传,不做任何处理
BetaPolicyActionFilter = "filter" // 过滤,从 beta header 中移除该 token
BetaPolicyActionBlock = "block" // 拦截,直接返回错误
BetaPolicyScopeAll = "all" // 所有账号类型
BetaPolicyScopeOAuth = "oauth" // 仅 OAuth 账号
BetaPolicyScopeAPIKey = "apikey" // 仅 API Key 账号
)
// BetaPolicyRule 单条 Beta 策略规则
type BetaPolicyRule struct {
BetaToken string `json:"beta_token"` // beta token 值
Action string `json:"action"` // "pass" | "filter" | "block"
Scope string `json:"scope"` // "all" | "oauth" | "apikey"
ErrorMessage string `json:"error_message,omitempty"` // 自定义错误消息 (action=block 时生效)
}
// BetaPolicySettings Beta 策略配置
type BetaPolicySettings struct {
Rules []BetaPolicyRule `json:"rules"`
}
// DefaultBetaPolicySettings 返回默认的 Beta 策略配置
func DefaultBetaPolicySettings() *BetaPolicySettings {
return &BetaPolicySettings{
Rules: []BetaPolicyRule{
{
BetaToken: "fast-mode-2026-02-01",
Action: BetaPolicyActionFilter,
Scope: BetaPolicyScopeAll,
},
{
BetaToken: "context-1m-2025-08-07",
Action: BetaPolicyActionFilter,
Scope: BetaPolicyScopeAll,
},
},
}
}

View File

@@ -105,6 +105,9 @@ func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWh
// ProvideConcurrencyService creates ConcurrencyService and starts slot cleanup worker. // ProvideConcurrencyService creates ConcurrencyService and starts slot cleanup worker.
func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountRepository, cfg *config.Config) *ConcurrencyService { func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountRepository, cfg *config.Config) *ConcurrencyService {
svc := NewConcurrencyService(cache) svc := NewConcurrencyService(cache)
if err := svc.CleanupStaleProcessSlots(context.Background()); err != nil {
logger.LegacyPrintf("service.concurrency", "Warning: startup cleanup stale process slots failed: %v", err)
}
if cfg != nil { if cfg != nil {
svc.StartSlotCleanupWorker(accountRepo, cfg.Gateway.Scheduling.SlotCleanupInterval) svc.StartSlotCleanupWorker(accountRepo, cfg.Gateway.Scheduling.SlotCleanupInterval)
} }

View File

@@ -76,6 +76,9 @@ func (c StubConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, acco
func (c StubConcurrencyCache) CleanupExpiredAccountSlots(_ context.Context, _ int64) error { func (c StubConcurrencyCache) CleanupExpiredAccountSlots(_ context.Context, _ int64) error {
return nil return nil
} }
func (c StubConcurrencyCache) CleanupStaleProcessSlots(_ context.Context, _ string) error {
return nil
}
// ============================================================ // ============================================================
// StubGatewayCache — service.GatewayCache 的空实现 // StubGatewayCache — service.GatewayCache 的空实现

View File

@@ -8,7 +8,7 @@
# - Creates necessary data directories # - Creates necessary data directories
# #
# After running this script, you can start services with: # After running this script, you can start services with:
# docker-compose -f docker-compose.local.yml up -d # docker-compose up -d
# ============================================================================= # =============================================================================
set -e set -e
@@ -65,7 +65,7 @@ main() {
fi fi
# Check if deployment already exists # Check if deployment already exists
if [ -f "docker-compose.local.yml" ] && [ -f ".env" ]; then if [ -f "docker-compose.yml" ] && [ -f ".env" ]; then
print_warning "Deployment files already exist in current directory." print_warning "Deployment files already exist in current directory."
read -p "Overwrite existing files? (y/N): " -r read -p "Overwrite existing files? (y/N): " -r
echo echo
@@ -75,17 +75,17 @@ main() {
fi fi
fi fi
# Download docker-compose.local.yml # Download docker-compose.local.yml and save as docker-compose.yml
print_info "Downloading docker-compose.local.yml..." print_info "Downloading docker-compose.yml..."
if command_exists curl; then if command_exists curl; then
curl -sSL "${GITHUB_RAW_URL}/docker-compose.local.yml" -o docker-compose.local.yml curl -sSL "${GITHUB_RAW_URL}/docker-compose.local.yml" -o docker-compose.yml
elif command_exists wget; then elif command_exists wget; then
wget -q "${GITHUB_RAW_URL}/docker-compose.local.yml" -O docker-compose.local.yml wget -q "${GITHUB_RAW_URL}/docker-compose.local.yml" -O docker-compose.yml
else else
print_error "Neither curl nor wget is installed. Please install one of them." print_error "Neither curl nor wget is installed. Please install one of them."
exit 1 exit 1
fi fi
print_success "Downloaded docker-compose.local.yml" print_success "Downloaded docker-compose.yml"
# Download .env.example # Download .env.example
print_info "Downloading .env.example..." print_info "Downloading .env.example..."
@@ -144,7 +144,7 @@ main() {
print_warning "Please keep them secure and do not share publicly!" print_warning "Please keep them secure and do not share publicly!"
echo "" echo ""
echo "Directory structure:" echo "Directory structure:"
echo " docker-compose.local.yml - Docker Compose configuration" echo " docker-compose.yml - Docker Compose configuration"
echo " .env - Environment variables (generated secrets)" echo " .env - Environment variables (generated secrets)"
echo " .env.example - Example template (for reference)" echo " .env.example - Example template (for reference)"
echo " data/ - Application data (will be created on first run)" echo " data/ - Application data (will be created on first run)"
@@ -154,10 +154,10 @@ main() {
echo "Next steps:" echo "Next steps:"
echo " 1. (Optional) Edit .env to customize configuration" echo " 1. (Optional) Edit .env to customize configuration"
echo " 2. Start services:" echo " 2. Start services:"
echo " docker-compose -f docker-compose.local.yml up -d" echo " docker-compose up -d"
echo "" echo ""
echo " 3. View logs:" echo " 3. View logs:"
echo " docker-compose -f docker-compose.local.yml logs -f sub2api" echo " docker-compose logs -f sub2api"
echo "" echo ""
echo " 4. Access Web UI:" echo " 4. Access Web UI:"
echo " http://localhost:8080" echo " http://localhost:8080"

View File

@@ -99,16 +99,17 @@ curl -X POST "${BASE}/api/v1/admin/users/123/balance" \
}' }'
``` ```
### 4) 购买页 URL Query 透传iframe / 新窗口一致) ### 4) 购买页 / 自定义页面 URL Query 透传iframe / 新窗口一致)
当 Sub2API 打开 `purchase_subscription_url` 时,会统一追加: 当 Sub2API 打开 `purchase_subscription_url` 或用户侧自定义页面 iframe URL 时,会统一追加:
- `user_id` - `user_id`
- `token` - `token`
- `theme``light` / `dark` - `theme``light` / `dark`
- `lang`(例如 `zh` / `en`,用于向嵌入页传递当前界面语言)
- `ui_mode`(固定 `embedded` - `ui_mode`(固定 `embedded`
示例: 示例:
```text ```text
https://pay.example.com/pay?user_id=123&token=<jwt>&theme=light&ui_mode=embedded https://pay.example.com/pay?user_id=123&token=<jwt>&theme=light&lang=zh&ui_mode=embedded
``` ```
### 5) 失败处理建议 ### 5) 失败处理建议
@@ -218,16 +219,17 @@ curl -X POST "${BASE}/api/v1/admin/users/123/balance" \
}' }'
``` ```
### 4) Purchase URL query forwarding (iframe and new tab) ### 4) Purchase / Custom Page URL query forwarding (iframe and new tab)
When Sub2API opens `purchase_subscription_url`, it appends: When Sub2API opens `purchase_subscription_url` or a user-facing custom page iframe URL, it appends:
- `user_id` - `user_id`
- `token` - `token`
- `theme` (`light` / `dark`) - `theme` (`light` / `dark`)
- `lang` (for example `zh` / `en`, used to pass the current UI language to the embedded page)
- `ui_mode` (fixed: `embedded`) - `ui_mode` (fixed: `embedded`)
Example: Example:
```text ```text
https://pay.example.com/pay?user_id=123&token=<jwt>&theme=light&ui_mode=embedded https://pay.example.com/pay?user_id=123&token=<jwt>&theme=light&lang=zh&ui_mode=embedded
``` ```
### 5) Failure handling recommendations ### 5) Failure handling recommendations

View File

@@ -581,6 +581,43 @@ export async function validateSoraSessionToken(
return data return data
} }
/**
* Batch operation result type
*/
export interface BatchOperationResult {
total: number
success: number
failed: number
errors?: Array<{ account_id: number; error: string }>
warnings?: Array<{ account_id: number; warning: string }>
}
/**
* Batch clear account errors
* @param accountIds - Array of account IDs
* @returns Batch operation result
*/
export async function batchClearError(accountIds: number[]): Promise<BatchOperationResult> {
const { data } = await apiClient.post<BatchOperationResult>('/admin/accounts/batch-clear-error', {
account_ids: accountIds
})
return data
}
/**
* Batch refresh account credentials
* @param accountIds - Array of account IDs
* @returns Batch operation result
*/
export async function batchRefresh(accountIds: number[]): Promise<BatchOperationResult> {
const { data } = await apiClient.post<BatchOperationResult>('/admin/accounts/batch-refresh', {
account_ids: accountIds,
}, {
timeout: 120000 // 120s timeout for large batch refreshes
})
return data
}
export const accountsAPI = { export const accountsAPI = {
list, list,
listWithEtag, listWithEtag,
@@ -615,7 +652,9 @@ export const accountsAPI = {
syncFromCrs, syncFromCrs,
exportData, exportData,
importData, importData,
getAntigravityDefaultModelMapping getAntigravityDefaultModelMapping,
batchClearError,
batchRefresh
} }
export default accountsAPI export default accountsAPI

View File

@@ -308,6 +308,49 @@ export async function updateRectifierSettings(
return data return data
} }
// ==================== Beta Policy Settings ====================
/**
* Beta policy rule interface
*/
export interface BetaPolicyRule {
beta_token: string
action: 'pass' | 'filter' | 'block'
scope: 'all' | 'oauth' | 'apikey'
error_message?: string
}
/**
* Beta policy settings interface
*/
export interface BetaPolicySettings {
rules: BetaPolicyRule[]
}
/**
* Get beta policy settings
* @returns Beta policy settings
*/
export async function getBetaPolicySettings(): Promise<BetaPolicySettings> {
const { data } = await apiClient.get<BetaPolicySettings>('/admin/settings/beta-policy')
return data
}
/**
* Update beta policy settings
* @param settings - Beta policy settings to update
* @returns Updated settings
*/
export async function updateBetaPolicySettings(
settings: BetaPolicySettings
): Promise<BetaPolicySettings> {
const { data } = await apiClient.put<BetaPolicySettings>(
'/admin/settings/beta-policy',
settings
)
return data
}
// ==================== Sora S3 Settings ==================== // ==================== Sora S3 Settings ====================
export interface SoraS3Settings { export interface SoraS3Settings {
@@ -456,6 +499,8 @@ export const settingsAPI = {
updateStreamTimeoutSettings, updateStreamTimeoutSettings,
getRectifierSettings, getRectifierSettings,
updateRectifierSettings, updateRectifierSettings,
getBetaPolicySettings,
updateBetaPolicySettings,
getSoraS3Settings, getSoraS3Settings,
updateSoraS3Settings, updateSoraS3Settings,
testSoraS3Connection, testSoraS3Connection,

View File

@@ -335,6 +335,28 @@ export async function resetPassword(request: ResetPasswordRequest): Promise<Rese
return data return data
} }
/**
* Complete LinuxDo OAuth registration by supplying an invitation code
* @param pendingOAuthToken - Short-lived JWT from the OAuth callback
* @param invitationCode - Invitation code entered by the user
* @returns Token pair on success
*/
export async function completeLinuxDoOAuthRegistration(
pendingOAuthToken: string,
invitationCode: string
): Promise<{ access_token: string; refresh_token: string; expires_in: number; token_type: string }> {
const { data } = await apiClient.post<{
access_token: string
refresh_token: string
expires_in: number
token_type: string
}>('/auth/oauth/linuxdo/complete-registration', {
pending_oauth_token: pendingOAuthToken,
invitation_code: invitationCode
})
return data
}
export const authAPI = { export const authAPI = {
login, login,
login2FA, login2FA,
@@ -357,7 +379,8 @@ export const authAPI = {
forgotPassword, forgotPassword,
resetPassword, resetPassword,
refreshToken, refreshToken,
revokeAllSessions revokeAllSessions,
completeLinuxDoOAuthRegistration
} }
export default authAPI export default authAPI

View File

@@ -73,11 +73,9 @@
</div> </div>
<!-- API Key 账号配额限制 --> <!-- API Key 账号配额限制 -->
<div v-if="showDailyQuota || showWeeklyQuota || showTotalQuota" class="flex items-center gap-1"> <QuotaBadge v-if="showDailyQuota" :used="account.quota_daily_used ?? 0" :limit="account.quota_daily_limit!" label="D" />
<QuotaBadge v-if="showDailyQuota" :used="account.quota_daily_used ?? 0" :limit="account.quota_daily_limit!" label="D" /> <QuotaBadge v-if="showWeeklyQuota" :used="account.quota_weekly_used ?? 0" :limit="account.quota_weekly_limit!" label="W" />
<QuotaBadge v-if="showWeeklyQuota" :used="account.quota_weekly_used ?? 0" :limit="account.quota_weekly_limit!" label="W" /> <QuotaBadge v-if="showTotalQuota" :used="account.quota_used ?? 0" :limit="account.quota_limit!" />
<QuotaBadge v-if="showTotalQuota" :used="account.quota_used ?? 0" :limit="account.quota_limit!" />
</div>
</div> </div>
</template> </template>

View File

@@ -20,6 +20,8 @@
</div> </div>
<div class="flex gap-2"> <div class="flex gap-2">
<button @click="$emit('delete')" class="btn btn-danger btn-sm">{{ t('admin.accounts.bulkActions.delete') }}</button> <button @click="$emit('delete')" class="btn btn-danger btn-sm">{{ t('admin.accounts.bulkActions.delete') }}</button>
<button @click="$emit('reset-status')" class="btn btn-secondary btn-sm">{{ t('admin.accounts.bulkActions.resetStatus') }}</button>
<button @click="$emit('refresh-token')" class="btn btn-secondary btn-sm">{{ t('admin.accounts.bulkActions.refreshToken') }}</button>
<button @click="$emit('toggle-schedulable', true)" class="btn btn-success btn-sm">{{ t('admin.accounts.bulkActions.enableScheduling') }}</button> <button @click="$emit('toggle-schedulable', true)" class="btn btn-success btn-sm">{{ t('admin.accounts.bulkActions.enableScheduling') }}</button>
<button @click="$emit('toggle-schedulable', false)" class="btn btn-warning btn-sm">{{ t('admin.accounts.bulkActions.disableScheduling') }}</button> <button @click="$emit('toggle-schedulable', false)" class="btn btn-warning btn-sm">{{ t('admin.accounts.bulkActions.disableScheduling') }}</button>
<button @click="$emit('edit')" class="btn btn-primary btn-sm">{{ t('admin.accounts.bulkActions.edit') }}</button> <button @click="$emit('edit')" class="btn btn-primary btn-sm">{{ t('admin.accounts.bulkActions.edit') }}</button>
@@ -29,5 +31,5 @@
<script setup lang="ts"> <script setup lang="ts">
import { useI18n } from 'vue-i18n' import { useI18n } from 'vue-i18n'
defineProps(['selectedIds']); defineEmits(['delete', 'edit', 'clear', 'select-page', 'toggle-schedulable']); const { t } = useI18n() defineProps(['selectedIds']); defineEmits(['delete', 'edit', 'clear', 'select-page', 'toggle-schedulable', 'reset-status', 'refresh-token']); const { t } = useI18n()
</script> </script>

View File

@@ -162,8 +162,7 @@ const load = async () => {
const loadGroups = async () => { const loadGroups = async () => {
try { try {
const groups = await adminAPI.groups.getAll() const groups = await adminAPI.groups.getAll()
// 过滤掉订阅类型分组(需通过订阅管理流程绑定) allGroups.value = groups
allGroups.value = groups.filter((g) => g.subscription_type !== 'subscription')
} catch (error) { } catch (error) {
console.error('Failed to load groups:', error) console.error('Failed to load groups:', error)
} }

View File

@@ -1,18 +1,40 @@
<script setup lang="ts"> <script setup lang="ts">
import { ref } from 'vue' import { ref, useTemplateRef, nextTick } from 'vue'
defineProps<{ defineProps<{
content?: string content?: string
}>() }>()
const show = ref(false) const show = ref(false)
const triggerRef = useTemplateRef<HTMLElement>('trigger')
const tooltipStyle = ref({ top: '0px', left: '0px' })
function onEnter() {
show.value = true
nextTick(updatePosition)
}
function onLeave() {
show.value = false
}
function updatePosition() {
const el = triggerRef.value
if (!el) return
const rect = el.getBoundingClientRect()
tooltipStyle.value = {
top: `${rect.top + window.scrollY}px`,
left: `${rect.left + rect.width / 2 + window.scrollX}px`,
}
}
</script> </script>
<template> <template>
<div <div
ref="trigger"
class="group relative ml-1 inline-flex items-center align-middle" class="group relative ml-1 inline-flex items-center align-middle"
@mouseenter="show = true" @mouseenter="onEnter"
@mouseleave="show = false" @mouseleave="onLeave"
> >
<!-- Trigger Icon --> <!-- Trigger Icon -->
<slot name="trigger"> <slot name="trigger">
@@ -31,14 +53,16 @@ const show = ref(false)
</svg> </svg>
</slot> </slot>
<!-- Popover Content --> <!-- Teleport to body to escape modal overflow clipping -->
<div <Teleport to="body">
v-show="show" <div
class="absolute bottom-full left-1/2 z-50 mb-2 w-64 -translate-x-1/2 rounded-lg bg-gray-900 p-3 text-xs leading-relaxed text-white shadow-xl ring-1 ring-white/10 opacity-0 transition-opacity duration-200 group-hover:opacity-100 dark:bg-gray-800" v-show="show"
> class="fixed z-[99999] w-64 -translate-x-1/2 -translate-y-full rounded-lg bg-gray-900 p-3 text-xs leading-relaxed text-white shadow-xl ring-1 ring-white/10 dark:bg-gray-800"
<slot>{{ content }}</slot> :style="{ top: `calc(${tooltipStyle.top} - 8px)`, left: tooltipStyle.left }"
<div class="absolute -bottom-1 left-1/2 h-2 w-2 -translate-x-1/2 rotate-45 bg-gray-900 dark:bg-gray-800"></div> >
</div> <slot>{{ content }}</slot>
<div class="absolute -bottom-1 left-1/2 h-2 w-2 -translate-x-1/2 rotate-45 bg-gray-900 dark:bg-gray-800"></div>
</div>
</Teleport>
</div> </div>
</template> </template>

View File

@@ -28,6 +28,10 @@
<Icon v-else name="key" size="xs" /> <Icon v-else name="key" size="xs" />
<span>{{ typeLabel }}</span> <span>{{ typeLabel }}</span>
</span> </span>
<!-- Plan type part (optional) -->
<span v-if="planLabel" :class="['inline-flex items-center gap-1 px-1.5 py-1 border-l border-white/20', typeClass]">
<span>{{ planLabel }}</span>
</span>
</div> </div>
</template> </template>
@@ -40,6 +44,7 @@ import Icon from '@/components/icons/Icon.vue'
interface Props { interface Props {
platform: AccountPlatform platform: AccountPlatform
type: AccountType type: AccountType
planType?: string
} }
const props = defineProps<Props>() const props = defineProps<Props>()
@@ -65,6 +70,24 @@ const typeLabel = computed(() => {
} }
}) })
const planLabel = computed(() => {
if (!props.planType) return ''
const lower = props.planType.toLowerCase()
switch (lower) {
case 'plus':
return 'Plus'
case 'team':
return 'Team'
case 'chatgptpro':
case 'pro':
return 'Pro'
case 'free':
return 'Free'
default:
return props.planType
}
})
const platformClass = computed(() => { const platformClass = computed(() => {
if (props.platform === 'anthropic') { if (props.platform === 'anthropic') {
return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400' return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400'

View File

@@ -434,7 +434,12 @@ export default {
callbackProcessing: 'Completing login, please wait...', callbackProcessing: 'Completing login, please wait...',
callbackHint: 'If you are not redirected automatically, go back to the login page and try again.', callbackHint: 'If you are not redirected automatically, go back to the login page and try again.',
callbackMissingToken: 'Missing login token, please try again.', callbackMissingToken: 'Missing login token, please try again.',
backToLogin: 'Back to Login' backToLogin: 'Back to Login',
invitationRequired: 'This Linux.do account is not yet registered. The site requires an invitation code — please enter one to complete registration.',
invalidPendingToken: 'The registration token has expired. Please sign in with Linux.do again.',
completeRegistration: 'Complete Registration',
completing: 'Completing registration…',
completeRegistrationFailed: 'Registration failed. Please check your invitation code and try again.'
}, },
oauth: { oauth: {
code: 'Code', code: 'Code',
@@ -1836,7 +1841,12 @@ export default {
edit: 'Bulk Edit', edit: 'Bulk Edit',
delete: 'Bulk Delete', delete: 'Bulk Delete',
enableScheduling: 'Enable Scheduling', enableScheduling: 'Enable Scheduling',
disableScheduling: 'Disable Scheduling' disableScheduling: 'Disable Scheduling',
resetStatus: 'Reset Status',
refreshToken: 'Refresh Token',
resetStatusSuccess: 'Successfully reset {count} account(s) status',
refreshTokenSuccess: 'Successfully refreshed {count} account(s) token',
partialSuccess: 'Partially completed: {success} succeeded, {failed} failed'
}, },
bulkEdit: { bulkEdit: {
title: 'Bulk Edit Accounts', title: 'Bulk Edit Accounts',
@@ -4033,6 +4043,23 @@ export default {
saved: 'Rectifier settings saved', saved: 'Rectifier settings saved',
saveFailed: 'Failed to save rectifier settings' saveFailed: 'Failed to save rectifier settings'
}, },
betaPolicy: {
title: 'Beta Policy',
description: 'How to handle Beta features when configuring the forwarding of Anthropic API requests. Applicable only to the /v1/messages endpoint.',
action: 'Action',
actionPass: 'Pass (transparent)',
actionFilter: 'Filter (remove)',
actionBlock: 'Block (reject)',
scope: 'Scope',
scopeAll: 'All accounts',
scopeOAuth: 'OAuth only',
scopeAPIKey: 'API Key only',
errorMessage: 'Error message',
errorMessagePlaceholder: 'Custom error message when blocked',
errorMessageHint: 'Leave empty for default message',
saved: 'Beta policy settings saved',
saveFailed: 'Failed to save beta policy settings'
},
saveSettings: 'Save Settings', saveSettings: 'Save Settings',
saving: 'Saving...', saving: 'Saving...',
settingsSaved: 'Settings saved successfully', settingsSaved: 'Settings saved successfully',

View File

@@ -433,7 +433,12 @@ export default {
callbackProcessing: '正在验证登录信息,请稍候...', callbackProcessing: '正在验证登录信息,请稍候...',
callbackHint: '如果页面未自动跳转,请返回登录页重试。', callbackHint: '如果页面未自动跳转,请返回登录页重试。',
callbackMissingToken: '登录信息缺失,请返回重试。', callbackMissingToken: '登录信息缺失,请返回重试。',
backToLogin: '返回登录' backToLogin: '返回登录',
invitationRequired: '该 Linux.do 账号尚未注册,站点已开启邀请码注册,请输入邀请码以完成注册。',
invalidPendingToken: '注册凭证已失效,请重新使用 Linux.do 登录。',
completeRegistration: '完成注册',
completing: '正在完成注册...',
completeRegistrationFailed: '注册失败,请检查邀请码后重试。'
}, },
oauth: { oauth: {
code: '授权码', code: '授权码',
@@ -1983,7 +1988,12 @@ export default {
edit: '批量编辑账号', edit: '批量编辑账号',
delete: '批量删除', delete: '批量删除',
enableScheduling: '批量启用调度', enableScheduling: '批量启用调度',
disableScheduling: '批量停止调度' disableScheduling: '批量停止调度',
resetStatus: '批量重置状态',
refreshToken: '批量刷新令牌',
resetStatusSuccess: '已成功重置 {count} 个账号状态',
refreshTokenSuccess: '已成功刷新 {count} 个账号令牌',
partialSuccess: '操作部分完成:{success} 成功,{failed} 失败'
}, },
bulkEdit: { bulkEdit: {
title: '批量编辑账号', title: '批量编辑账号',
@@ -4206,6 +4216,23 @@ export default {
saved: '整流器设置保存成功', saved: '整流器设置保存成功',
saveFailed: '保存整流器设置失败' saveFailed: '保存整流器设置失败'
}, },
betaPolicy: {
title: 'Beta 策略',
description: '配置转发 Anthropic API 请求时如何处理 Beta 特性。仅适用于 /v1/messages 接口。',
action: '处理方式',
actionPass: '透传(不处理)',
actionFilter: '过滤(移除)',
actionBlock: '拦截(拒绝请求)',
scope: '生效范围',
scopeAll: '全部账号',
scopeOAuth: '仅 OAuth 账号',
scopeAPIKey: '仅 API Key 账号',
errorMessage: '错误消息',
errorMessagePlaceholder: '拦截时返回的自定义错误消息',
errorMessageHint: '留空则使用默认错误消息',
saved: 'Beta 策略设置保存成功',
saveFailed: '保存 Beta 策略设置失败'
},
saveSettings: '保存设置', saveSettings: '保存设置',
saving: '保存中...', saving: '保存中...',
settingsSaved: '设置保存成功', settingsSaved: '设置保存成功',

View File

@@ -0,0 +1,67 @@
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import { buildEmbeddedUrl, detectTheme } from '../embedded-url'
describe('embedded-url', () => {
const originalLocation = window.location
beforeEach(() => {
Object.defineProperty(window, 'location', {
value: {
origin: 'https://app.example.com',
href: 'https://app.example.com/user/purchase',
},
writable: true,
configurable: true,
})
})
afterEach(() => {
Object.defineProperty(window, 'location', {
value: originalLocation,
writable: true,
configurable: true,
})
document.documentElement.classList.remove('dark')
vi.restoreAllMocks()
})
it('adds embedded query parameters including locale and source context', () => {
const result = buildEmbeddedUrl(
'https://pay.example.com/checkout?plan=pro',
42,
'token-123',
'dark',
'zh-CN',
)
const url = new URL(result)
expect(url.searchParams.get('plan')).toBe('pro')
expect(url.searchParams.get('user_id')).toBe('42')
expect(url.searchParams.get('token')).toBe('token-123')
expect(url.searchParams.get('theme')).toBe('dark')
expect(url.searchParams.get('lang')).toBe('zh-CN')
expect(url.searchParams.get('ui_mode')).toBe('embedded')
expect(url.searchParams.get('src_host')).toBe('https://app.example.com')
expect(url.searchParams.get('src_url')).toBe('https://app.example.com/user/purchase')
})
it('omits optional params when they are empty', () => {
const result = buildEmbeddedUrl('https://pay.example.com/checkout', undefined, '', 'light')
const url = new URL(result)
expect(url.searchParams.get('theme')).toBe('light')
expect(url.searchParams.get('ui_mode')).toBe('embedded')
expect(url.searchParams.has('user_id')).toBe(false)
expect(url.searchParams.has('token')).toBe(false)
expect(url.searchParams.has('lang')).toBe(false)
})
it('returns original string for invalid url input', () => {
expect(buildEmbeddedUrl('not a url', 1, 'token')).toBe('not a url')
})
it('detects dark mode from document root class', () => {
document.documentElement.classList.add('dark')
expect(detectTheme()).toBe('dark')
})
})

View File

@@ -1,12 +1,13 @@
/** /**
* Shared URL builder for iframe-embedded pages. * Shared URL builder for iframe-embedded pages.
* Used by PurchaseSubscriptionView and CustomPageView to build consistent URLs * Used by PurchaseSubscriptionView and CustomPageView to build consistent URLs
* with user_id, token, theme, ui_mode, src_host, and src parameters. * with user_id, token, theme, lang, ui_mode, src_host, and src parameters.
*/ */
const EMBEDDED_USER_ID_QUERY_KEY = 'user_id' const EMBEDDED_USER_ID_QUERY_KEY = 'user_id'
const EMBEDDED_AUTH_TOKEN_QUERY_KEY = 'token' const EMBEDDED_AUTH_TOKEN_QUERY_KEY = 'token'
const EMBEDDED_THEME_QUERY_KEY = 'theme' const EMBEDDED_THEME_QUERY_KEY = 'theme'
const EMBEDDED_LANG_QUERY_KEY = 'lang'
const EMBEDDED_UI_MODE_QUERY_KEY = 'ui_mode' const EMBEDDED_UI_MODE_QUERY_KEY = 'ui_mode'
const EMBEDDED_UI_MODE_VALUE = 'embedded' const EMBEDDED_UI_MODE_VALUE = 'embedded'
const EMBEDDED_SRC_HOST_QUERY_KEY = 'src_host' const EMBEDDED_SRC_HOST_QUERY_KEY = 'src_host'
@@ -17,6 +18,7 @@ export function buildEmbeddedUrl(
userId?: number, userId?: number,
authToken?: string | null, authToken?: string | null,
theme: 'light' | 'dark' = 'light', theme: 'light' | 'dark' = 'light',
lang?: string,
): string { ): string {
if (!baseUrl) return baseUrl if (!baseUrl) return baseUrl
try { try {
@@ -28,6 +30,9 @@ export function buildEmbeddedUrl(
url.searchParams.set(EMBEDDED_AUTH_TOKEN_QUERY_KEY, authToken) url.searchParams.set(EMBEDDED_AUTH_TOKEN_QUERY_KEY, authToken)
} }
url.searchParams.set(EMBEDDED_THEME_QUERY_KEY, theme) url.searchParams.set(EMBEDDED_THEME_QUERY_KEY, theme)
if (lang) {
url.searchParams.set(EMBEDDED_LANG_QUERY_KEY, lang)
}
url.searchParams.set(EMBEDDED_UI_MODE_QUERY_KEY, EMBEDDED_UI_MODE_VALUE) url.searchParams.set(EMBEDDED_UI_MODE_QUERY_KEY, EMBEDDED_UI_MODE_VALUE)
// Source tracking: let the embedded page know where it's being loaded from // Source tracking: let the embedded page know where it's being loaded from
if (typeof window !== 'undefined') { if (typeof window !== 'undefined') {

View File

@@ -131,7 +131,7 @@
</div> </div>
</template> </template>
<template #table> <template #table>
<AccountBulkActionsBar :selected-ids="selIds" @delete="handleBulkDelete" @edit="showBulkEdit = true" @clear="clearSelection" @select-page="selectPage" @toggle-schedulable="handleBulkToggleSchedulable" /> <AccountBulkActionsBar :selected-ids="selIds" @delete="handleBulkDelete" @reset-status="handleBulkResetStatus" @refresh-token="handleBulkRefreshToken" @edit="showBulkEdit = true" @clear="clearSelection" @select-page="selectPage" @toggle-schedulable="handleBulkToggleSchedulable" />
<div ref="accountTableRef" class="flex min-h-0 flex-1 flex-col overflow-hidden"> <div ref="accountTableRef" class="flex min-h-0 flex-1 flex-col overflow-hidden">
<DataTable <DataTable
:columns="cols" :columns="cols"
@@ -171,7 +171,7 @@
<span v-else class="text-sm text-gray-400 dark:text-dark-500">-</span> <span v-else class="text-sm text-gray-400 dark:text-dark-500">-</span>
</template> </template>
<template #cell-platform_type="{ row }"> <template #cell-platform_type="{ row }">
<PlatformTypeBadge :platform="row.platform" :type="row.type" /> <PlatformTypeBadge :platform="row.platform" :type="row.type" :plan-type="row.credentials?.plan_type" />
</template> </template>
<template #cell-capacity="{ row }"> <template #cell-capacity="{ row }">
<AccountCapacityCell :account="row" /> <AccountCapacityCell :account="row" />
@@ -889,6 +889,38 @@ const toggleSelectAllVisible = (event: Event) => {
toggleVisible(target.checked) toggleVisible(target.checked)
} }
const handleBulkDelete = async () => { if(!confirm(t('common.confirm'))) return; try { await Promise.all(selIds.value.map(id => adminAPI.accounts.delete(id))); clearSelection(); reload() } catch (error) { console.error('Failed to bulk delete accounts:', error) } } const handleBulkDelete = async () => { if(!confirm(t('common.confirm'))) return; try { await Promise.all(selIds.value.map(id => adminAPI.accounts.delete(id))); clearSelection(); reload() } catch (error) { console.error('Failed to bulk delete accounts:', error) } }
const handleBulkResetStatus = async () => {
if (!confirm(t('common.confirm'))) return
try {
const result = await adminAPI.accounts.batchClearError(selIds.value)
if (result.failed > 0) {
appStore.showError(t('admin.accounts.bulkActions.partialSuccess', { success: result.success, failed: result.failed }))
} else {
appStore.showSuccess(t('admin.accounts.bulkActions.resetStatusSuccess', { count: result.success }))
clearSelection()
}
reload()
} catch (error) {
console.error('Failed to bulk reset status:', error)
appStore.showError(String(error))
}
}
const handleBulkRefreshToken = async () => {
if (!confirm(t('common.confirm'))) return
try {
const result = await adminAPI.accounts.batchRefresh(selIds.value)
if (result.failed > 0) {
appStore.showError(t('admin.accounts.bulkActions.partialSuccess', { success: result.success, failed: result.failed }))
} else {
appStore.showSuccess(t('admin.accounts.bulkActions.refreshTokenSuccess', { count: result.success }))
clearSelection()
}
reload()
} catch (error) {
console.error('Failed to bulk refresh token:', error)
appStore.showError(String(error))
}
}
const updateSchedulableInList = (accountIds: number[], schedulable: boolean) => { const updateSchedulableInList = (accountIds: number[], schedulable: boolean) => {
if (accountIds.length === 0) return if (accountIds.length === 0) return
const idSet = new Set(accountIds) const idSet = new Set(accountIds)

View File

@@ -405,6 +405,117 @@
</template> </template>
</div> </div>
</div> </div>
<!-- Beta Policy Settings -->
<div class="card">
<div class="border-b border-gray-100 px-6 py-4 dark:border-dark-700">
<h2 class="text-lg font-semibold text-gray-900 dark:text-white">
{{ t('admin.settings.betaPolicy.title') }}
</h2>
<p class="mt-1 text-sm text-gray-500 dark:text-gray-400">
{{ t('admin.settings.betaPolicy.description') }}
</p>
</div>
<div class="space-y-5 p-6">
<!-- Loading State -->
<div v-if="betaPolicyLoading" class="flex items-center gap-2 text-gray-500">
<div class="h-4 w-4 animate-spin rounded-full border-b-2 border-primary-600"></div>
{{ t('common.loading') }}
</div>
<template v-else>
<!-- Rule Cards -->
<div
v-for="rule in betaPolicyForm.rules"
:key="rule.beta_token"
class="rounded-lg border border-gray-200 p-4 dark:border-dark-600"
>
<div class="mb-3 flex items-center gap-2">
<span class="text-sm font-medium text-gray-900 dark:text-white">
{{ getBetaDisplayName(rule.beta_token) }}
</span>
<span class="rounded bg-gray-100 px-2 py-0.5 text-xs text-gray-500 dark:bg-dark-700 dark:text-gray-400">
{{ rule.beta_token }}
</span>
</div>
<div class="grid grid-cols-2 gap-4">
<!-- Action -->
<div>
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
{{ t('admin.settings.betaPolicy.action') }}
</label>
<Select
:modelValue="rule.action"
@update:modelValue="rule.action = $event as any"
:options="betaPolicyActionOptions"
/>
</div>
<!-- Scope -->
<div>
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
{{ t('admin.settings.betaPolicy.scope') }}
</label>
<Select
:modelValue="rule.scope"
@update:modelValue="rule.scope = $event as any"
:options="betaPolicyScopeOptions"
/>
</div>
</div>
<!-- Error Message (only when action=block) -->
<div v-if="rule.action === 'block'" class="mt-3">
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
{{ t('admin.settings.betaPolicy.errorMessage') }}
</label>
<input
v-model="rule.error_message"
type="text"
class="input"
:placeholder="t('admin.settings.betaPolicy.errorMessagePlaceholder')"
/>
<p class="mt-1 text-xs text-gray-400 dark:text-gray-500">
{{ t('admin.settings.betaPolicy.errorMessageHint') }}
</p>
</div>
</div>
<!-- Save Button -->
<div class="flex justify-end border-t border-gray-100 pt-4 dark:border-dark-700">
<button
type="button"
@click="saveBetaPolicySettings"
:disabled="betaPolicySaving"
class="btn btn-primary btn-sm"
>
<svg
v-if="betaPolicySaving"
class="mr-1 h-4 w-4 animate-spin"
fill="none"
viewBox="0 0 24 24"
>
<circle
class="opacity-25"
cx="12"
cy="12"
r="10"
stroke="currentColor"
stroke-width="4"
></circle>
<path
class="opacity-75"
fill="currentColor"
d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"
></path>
</svg>
{{ betaPolicySaving ? t('common.saving') : t('common.save') }}
</button>
</div>
</template>
</div>
</div>
</div><!-- /Tab: Gateway --> </div><!-- /Tab: Gateway -->
<!-- Tab: Security Registration, Turnstile, LinuxDo --> <!-- Tab: Security Registration, Turnstile, LinuxDo -->
@@ -1627,6 +1738,18 @@ const rectifierForm = reactive({
thinking_budget_enabled: true thinking_budget_enabled: true
}) })
// Beta Policy 状态
const betaPolicyLoading = ref(true)
const betaPolicySaving = ref(false)
const betaPolicyForm = reactive({
rules: [] as Array<{
beta_token: string
action: 'pass' | 'filter' | 'block'
scope: 'all' | 'oauth' | 'apikey'
error_message?: string
}>
})
interface DefaultSubscriptionGroupOption { interface DefaultSubscriptionGroupOption {
value: number value: number
label: string label: string
@@ -2165,12 +2288,64 @@ async function saveRectifierSettings() {
} }
} }
const betaPolicyActionOptions = computed(() => [
{ value: 'pass', label: t('admin.settings.betaPolicy.actionPass') },
{ value: 'filter', label: t('admin.settings.betaPolicy.actionFilter') },
{ value: 'block', label: t('admin.settings.betaPolicy.actionBlock') }
])
const betaPolicyScopeOptions = computed(() => [
{ value: 'all', label: t('admin.settings.betaPolicy.scopeAll') },
{ value: 'oauth', label: t('admin.settings.betaPolicy.scopeOAuth') },
{ value: 'apikey', label: t('admin.settings.betaPolicy.scopeAPIKey') }
])
// Beta Policy 方法
const betaDisplayNames: Record<string, string> = {
'fast-mode-2026-02-01': 'Fast Mode',
'context-1m-2025-08-07': 'Context 1M'
}
function getBetaDisplayName(token: string): string {
return betaDisplayNames[token] || token
}
async function loadBetaPolicySettings() {
betaPolicyLoading.value = true
try {
const settings = await adminAPI.settings.getBetaPolicySettings()
betaPolicyForm.rules = settings.rules
} catch (error: any) {
console.error('Failed to load beta policy settings:', error)
} finally {
betaPolicyLoading.value = false
}
}
async function saveBetaPolicySettings() {
betaPolicySaving.value = true
try {
const updated = await adminAPI.settings.updateBetaPolicySettings({
rules: betaPolicyForm.rules
})
betaPolicyForm.rules = updated.rules
appStore.showSuccess(t('admin.settings.betaPolicy.saved'))
} catch (error: any) {
appStore.showError(
t('admin.settings.betaPolicy.saveFailed') + ': ' + (error.message || t('common.unknownError'))
)
} finally {
betaPolicySaving.value = false
}
}
onMounted(() => { onMounted(() => {
loadSettings() loadSettings()
loadSubscriptionGroups() loadSubscriptionGroups()
loadAdminApiKey() loadAdminApiKey()
loadStreamTimeoutSettings() loadStreamTimeoutSettings()
loadRectifierSettings() loadRectifierSettings()
loadBetaPolicySettings()
}) })
</script> </script>

View File

@@ -10,6 +10,36 @@
</p> </p>
</div> </div>
<transition name="fade">
<div v-if="needsInvitation" class="space-y-4">
<p class="text-sm text-gray-700 dark:text-gray-300">
{{ t('auth.linuxdo.invitationRequired') }}
</p>
<div>
<input
v-model="invitationCode"
type="text"
class="input w-full"
:placeholder="t('auth.invitationCodePlaceholder')"
:disabled="isSubmitting"
@keyup.enter="handleSubmitInvitation"
/>
</div>
<transition name="fade">
<p v-if="invitationError" class="text-sm text-red-600 dark:text-red-400">
{{ invitationError }}
</p>
</transition>
<button
class="btn btn-primary w-full"
:disabled="isSubmitting || !invitationCode.trim()"
@click="handleSubmitInvitation"
>
{{ isSubmitting ? t('auth.linuxdo.completing') : t('auth.linuxdo.completeRegistration') }}
</button>
</div>
</transition>
<transition name="fade"> <transition name="fade">
<div <div
v-if="errorMessage" v-if="errorMessage"
@@ -41,6 +71,7 @@ import { useI18n } from 'vue-i18n'
import { AuthLayout } from '@/components/layout' import { AuthLayout } from '@/components/layout'
import Icon from '@/components/icons/Icon.vue' import Icon from '@/components/icons/Icon.vue'
import { useAuthStore, useAppStore } from '@/stores' import { useAuthStore, useAppStore } from '@/stores'
import { completeLinuxDoOAuthRegistration } from '@/api/auth'
const route = useRoute() const route = useRoute()
const router = useRouter() const router = useRouter()
@@ -52,6 +83,14 @@ const appStore = useAppStore()
const isProcessing = ref(true) const isProcessing = ref(true)
const errorMessage = ref('') const errorMessage = ref('')
// Invitation code flow state
const needsInvitation = ref(false)
const pendingOAuthToken = ref('')
const invitationCode = ref('')
const isSubmitting = ref(false)
const invitationError = ref('')
const redirectTo = ref('/dashboard')
function parseFragmentParams(): URLSearchParams { function parseFragmentParams(): URLSearchParams {
const raw = typeof window !== 'undefined' ? window.location.hash : '' const raw = typeof window !== 'undefined' ? window.location.hash : ''
const hash = raw.startsWith('#') ? raw.slice(1) : raw const hash = raw.startsWith('#') ? raw.slice(1) : raw
@@ -67,6 +106,34 @@ function sanitizeRedirectPath(path: string | null | undefined): string {
return path return path
} }
async function handleSubmitInvitation() {
invitationError.value = ''
if (!invitationCode.value.trim()) return
isSubmitting.value = true
try {
const tokenData = await completeLinuxDoOAuthRegistration(
pendingOAuthToken.value,
invitationCode.value.trim()
)
if (tokenData.refresh_token) {
localStorage.setItem('refresh_token', tokenData.refresh_token)
}
if (tokenData.expires_in) {
localStorage.setItem('token_expires_at', String(Date.now() + tokenData.expires_in * 1000))
}
await authStore.setToken(tokenData.access_token)
appStore.showSuccess(t('auth.loginSuccess'))
await router.replace(redirectTo.value)
} catch (e: unknown) {
const err = e as { message?: string; response?: { data?: { message?: string } } }
invitationError.value =
err.response?.data?.message || err.message || t('auth.linuxdo.completeRegistrationFailed')
} finally {
isSubmitting.value = false
}
}
onMounted(async () => { onMounted(async () => {
const params = parseFragmentParams() const params = parseFragmentParams()
@@ -80,6 +147,19 @@ onMounted(async () => {
const errorDesc = params.get('error_description') || params.get('error_message') || '' const errorDesc = params.get('error_description') || params.get('error_message') || ''
if (error) { if (error) {
if (error === 'invitation_required') {
pendingOAuthToken.value = params.get('pending_oauth_token') || ''
redirectTo.value = sanitizeRedirectPath(params.get('redirect'))
if (!pendingOAuthToken.value) {
errorMessage.value = t('auth.linuxdo.invalidPendingToken')
appStore.showError(errorMessage.value)
isProcessing.value = false
return
}
needsInvitation.value = true
isProcessing.value = false
return
}
errorMessage.value = errorDesc || error errorMessage.value = errorDesc || error
appStore.showError(errorMessage.value) appStore.showError(errorMessage.value)
isProcessing.value = false isProcessing.value = false

View File

@@ -75,7 +75,7 @@ import AppLayout from '@/components/layout/AppLayout.vue'
import Icon from '@/components/icons/Icon.vue' import Icon from '@/components/icons/Icon.vue'
import { buildEmbeddedUrl, detectTheme } from '@/utils/embedded-url' import { buildEmbeddedUrl, detectTheme } from '@/utils/embedded-url'
const { t } = useI18n() const { t, locale } = useI18n()
const route = useRoute() const route = useRoute()
const appStore = useAppStore() const appStore = useAppStore()
const authStore = useAuthStore() const authStore = useAuthStore()
@@ -107,6 +107,7 @@ const embeddedUrl = computed(() => {
authStore.user?.id, authStore.user?.id,
authStore.token, authStore.token,
pageTheme.value, pageTheme.value,
locale.value,
) )
}) })

View File

@@ -76,7 +76,7 @@ import AppLayout from '@/components/layout/AppLayout.vue'
import Icon from '@/components/icons/Icon.vue' import Icon from '@/components/icons/Icon.vue'
import { buildEmbeddedUrl, detectTheme } from '@/utils/embedded-url' import { buildEmbeddedUrl, detectTheme } from '@/utils/embedded-url'
const { t } = useI18n() const { t, locale } = useI18n()
const appStore = useAppStore() const appStore = useAppStore()
const authStore = useAuthStore() const authStore = useAuthStore()
@@ -90,7 +90,7 @@ const purchaseEnabled = computed(() => {
const purchaseUrl = computed(() => { const purchaseUrl = computed(() => {
const baseUrl = (appStore.cachedPublicSettings?.purchase_subscription_url || '').trim() const baseUrl = (appStore.cachedPublicSettings?.purchase_subscription_url || '').trim()
return buildEmbeddedUrl(baseUrl, authStore.user?.id, authStore.token, purchaseTheme.value) return buildEmbeddedUrl(baseUrl, authStore.user?.id, authStore.token, purchaseTheme.value, locale.value)
}) })
const isValidUrl = computed(() => { const isValidUrl = computed(() => {