Merge tag 'v0.1.90' into merge/upstream-v0.1.90

注册邮箱域名白名单策略上线,后台大数据场景性能大幅优化。

- 注册邮箱域名白名单:支持管理员配置允许注册的邮箱域名策略
- Keys 页面表单筛选:用户 /keys 页面支持按条件筛选 API Key
- Settings 页面分 Tab 拆分:管理后台设置页面按功能模块分 Tab 展示

- 后台大数据场景加载性能优化:仪表盘/用户/账号/Ops 页面大数据集加载显著提速
- Usage 大表分页优化:默认避免全量 COUNT(*),大幅降低分页查询耗时
- 消除重复的 normalizeAccountIDList,补充新增组件的单元测试
- 清理无用文件和过时文档,精简项目结构
- EmailVerifyView 硬编码英文字符串替换为 i18n 调用

- 修复 Anthropic 平台无限流重置时间的 429 误标记账号限流问题
- 修复自定义菜单页面管理员视角菜单不生效问题
- 修复 Ops 错误详情弹窗未展示真实上游 payload 的问题
- 修复充值/订阅菜单 icon 显示问题

# Conflicts:
#	.gitignore
#	backend/cmd/server/VERSION
#	backend/ent/group.go
#	backend/ent/runtime/runtime.go
#	backend/ent/schema/group.go
#	backend/go.sum
#	backend/internal/handler/admin/account_handler.go
#	backend/internal/handler/admin/dashboard_handler.go
#	backend/internal/pkg/usagestats/usage_log_types.go
#	backend/internal/repository/group_repo.go
#	backend/internal/repository/usage_log_repo.go
#	backend/internal/server/middleware/security_headers.go
#	backend/internal/server/router.go
#	backend/internal/service/account_usage_service.go
#	backend/internal/service/admin_service_bulk_update_test.go
#	backend/internal/service/dashboard_service.go
#	backend/internal/service/gateway_service.go
#	frontend/src/api/admin/dashboard.ts
#	frontend/src/components/account/BulkEditAccountModal.vue
#	frontend/src/components/charts/GroupDistributionChart.vue
#	frontend/src/components/layout/AppSidebar.vue
#	frontend/src/i18n/locales/en.ts
#	frontend/src/i18n/locales/zh.ts
#	frontend/src/views/admin/GroupsView.vue
#	frontend/src/views/admin/SettingsView.vue
#	frontend/src/views/admin/UsageView.vue
#	frontend/src/views/user/PurchaseSubscriptionView.vue
This commit is contained in:
erio
2026-03-04 19:58:38 +08:00
461 changed files with 63392 additions and 6617 deletions

View File

@@ -86,6 +86,15 @@ func TestAPIContracts(t *testing.T) {
"last_used_at": null,
"quota": 0,
"quota_used": 0,
"rate_limit_5h": 0,
"rate_limit_1d": 0,
"rate_limit_7d": 0,
"usage_5h": 0,
"usage_1d": 0,
"usage_7d": 0,
"window_5h_start": null,
"window_1d_start": null,
"window_7d_start": null,
"expires_at": null,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
@@ -126,6 +135,15 @@ func TestAPIContracts(t *testing.T) {
"last_used_at": null,
"quota": 0,
"quota_used": 0,
"rate_limit_5h": 0,
"rate_limit_1d": 0,
"rate_limit_7d": 0,
"usage_5h": 0,
"usage_1d": 0,
"usage_7d": 0,
"window_5h_start": null,
"window_1d_start": null,
"window_7d_start": null,
"expires_at": null,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
@@ -186,11 +204,12 @@ func TestAPIContracts(t *testing.T) {
"image_price_1k": null,
"image_price_2k": null,
"image_price_4k": null,
"sora_image_price_360": null,
"sora_image_price_540": null,
"sora_video_price_per_request": null,
"sora_video_price_per_request_hd": null,
"claude_code_only": false,
"sora_image_price_360": null,
"sora_image_price_540": null,
"sora_storage_quota_bytes": 0,
"sora_video_price_per_request": null,
"sora_video_price_per_request_hd": null,
"claude_code_only": false,
"fallback_group_id": null,
"fallback_group_id_on_invalid_request": null,
"created_at": "2025-01-02T03:04:05Z",
@@ -384,10 +403,12 @@ func TestAPIContracts(t *testing.T) {
"user_id": 1,
"api_key_id": 100,
"account_id": 200,
"request_id": "req_123",
"model": "claude-3",
"group_id": null,
"subscription_id": null,
"request_id": "req_123",
"model": "claude-3",
"request_type": "stream",
"openai_ws_mode": false,
"group_id": null,
"subscription_id": null,
"input_tokens": 10,
"output_tokens": 20,
"cache_creation_tokens": 1,
@@ -425,9 +446,10 @@ func TestAPIContracts(t *testing.T) {
setup: func(t *testing.T, deps *contractDeps) {
t.Helper()
deps.settingRepo.SetAll(map[string]string{
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyEmailVerifyEnabled: "false",
service.SettingKeyPromoCodeEnabled: "true",
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyEmailVerifyEnabled: "false",
service.SettingKeyRegistrationEmailSuffixWhitelist: "[]",
service.SettingKeyPromoCodeEnabled: "true",
service.SettingKeySMTPHost: "smtp.example.com",
service.SettingKeySMTPPort: "587",
@@ -466,6 +488,7 @@ func TestAPIContracts(t *testing.T) {
"data": {
"registration_enabled": true,
"email_verify_enabled": false,
"registration_email_suffix_whitelist": [],
"promo_code_enabled": true,
"password_reset_enabled": false,
"totp_enabled": false,
@@ -496,18 +519,23 @@ func TestAPIContracts(t *testing.T) {
"doc_url": "https://docs.example.com",
"default_concurrency": 5,
"default_balance": 1.25,
"default_subscriptions": [],
"enable_model_fallback": false,
"fallback_model_anthropic": "claude-3-5-sonnet-20241022",
"fallback_model_antigravity": "gemini-2.5-pro",
"fallback_model_gemini": "gemini-2.5-pro",
"fallback_model_openai": "gpt-4o",
"enable_identity_patch": true,
"identity_patch_prompt": "",
"invitation_code_enabled": false,
"home_content": "",
"fallback_model_openai": "gpt-4o",
"enable_identity_patch": true,
"identity_patch_prompt": "",
"sora_client_enabled": false,
"invitation_code_enabled": false,
"home_content": "",
"hide_ccs_import_button": false,
"purchase_subscription_enabled": false,
"purchase_subscription_url": ""
"purchase_subscription_url": "",
"min_claude_code_version": "",
"allow_ungrouped_key_scheduling": false,
"custom_menu_items": []
}
}`,
},
@@ -615,12 +643,12 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil)
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
jwtAuth := func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
@@ -775,6 +803,10 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
return 0, errors.New("not implemented")
}
func (r *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
return errors.New("not implemented")
}
func (r *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
return errors.New("not implemented")
}
@@ -1016,6 +1048,14 @@ func (s *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Conte
return nil, errors.New("not implemented")
}
func (s *stubAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
return nil, errors.New("not implemented")
}
func (s *stubAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
return nil, errors.New("not implemented")
}
func (s *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return errors.New("not implemented")
}
@@ -1373,7 +1413,7 @@ func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
return nil
}
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
ids := make([]int64, 0, len(r.byID))
for id := range r.byID {
if r.byID[id].UserID == userID {
@@ -1487,6 +1527,16 @@ func (r *stubApiKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt ti
return nil
}
func (r *stubApiKeyRepo) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
return nil
}
func (r *stubApiKeyRepo) ResetRateLimitWindows(ctx context.Context, id int64) error {
return nil
}
func (r *stubApiKeyRepo) GetRateLimitData(ctx context.Context, id int64) (*service.APIKeyRateLimitData, error) {
return nil, nil
}
type stubUsageLogRepo struct {
userLogs map[int64][]service.UsageLog
}
@@ -1555,11 +1605,15 @@ func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.D
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
return nil, errors.New("not implemented")
}

View File

@@ -19,7 +19,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil)
authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
admin := &service.User{
ID: 1,
@@ -181,6 +181,10 @@ func (s *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
panic("unexpected RemoveGroupFromAllowedGroups call")
}
func (s *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
panic("unexpected AddGroupToAllowedGroups call")
}
func (s *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}

View File

@@ -19,8 +19,16 @@ func NewAPIKeyAuthMiddleware(apiKeyService *service.APIKeyService, subscriptionS
}
// apiKeyAuthWithSubscription API Key认证中间件支持订阅验证
//
// 中间件职责分为两层:
// - 鉴权Authentication验证 Key 有效性、用户状态、IP 限制 —— 始终执行
// - 计费执行Billing Enforcement过期/配额/订阅/余额检查 —— skipBilling 时整块跳过
//
// /v1/usage 端点只需鉴权,不需要计费执行(允许过期/配额耗尽的 Key 查询自身用量)。
func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) {
// ── 1. 提取 API Key ──────────────────────────────────────────
queryKey := strings.TrimSpace(c.Query("key"))
queryApiKey := strings.TrimSpace(c.Query("api_key"))
if queryKey != "" || queryApiKey != "" {
@@ -56,7 +64,8 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
return
}
// 从数据库验证API key
// ── 2. 验证 Key 存在 ─────────────────────────────────────────
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
if err != nil {
if errors.Is(err, service.ErrAPIKeyNotFound) {
@@ -67,29 +76,13 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
return
}
// 检查API key是否激活
if !apiKey.IsActive() {
// Provide more specific error message based on status
switch apiKey.Status {
case service.StatusAPIKeyQuotaExhausted:
AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完")
case service.StatusAPIKeyExpired:
AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期")
default:
AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled")
}
return
}
// ── 3. 基础鉴权(始终执行) ─────────────────────────────────
// 检查API Key是否过期即使状态是active也要检查时间
if apiKey.IsExpired() {
AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期")
return
}
// 检查API Key配额是否耗尽
if apiKey.IsQuotaExhausted() {
AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完")
// disabled / 未知状态 → 无条件拦截expired 和 quota_exhausted 留给计费阶段
if !apiKey.IsActive() &&
apiKey.Status != service.StatusAPIKeyExpired &&
apiKey.Status != service.StatusAPIKeyQuotaExhausted {
AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled")
return
}
@@ -97,7 +90,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
// 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制
if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 {
clientIP := ip.GetTrustedClientIP(c)
allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist)
allowed, _ := ip.CheckIPRestrictionWithCompiledRules(clientIP, apiKey.CompiledIPWhitelist, apiKey.CompiledIPBlacklist)
if !allowed {
AbortWithError(c, 403, "ACCESS_DENIED", "Access denied")
return
@@ -116,8 +109,9 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
return
}
// ── 4. SimpleMode → early return ─────────────────────────────
if cfg.RunMode == config.RunModeSimple {
// 简易模式:跳过余额和订阅检查,但仍需设置必要的上下文
c.Set(string(ContextKeyAPIKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID,
@@ -130,54 +124,89 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
return
}
// 判断计费方式:订阅模式 vs 余额模式
// ── 5. 加载订阅(订阅模式时始终加载) ───────────────────────
// skipBilling: /v1/usage 只需鉴权,跳过所有计费执行
skipBilling := c.Request.URL.Path == "/v1/usage"
var subscription *service.UserSubscription
isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
if isSubscriptionType && subscriptionService != nil {
// 订阅模式获取订阅L1 缓存 + singleflight
subscription, err := subscriptionService.GetActiveSubscription(
sub, subErr := subscriptionService.GetActiveSubscription(
c.Request.Context(),
apiKey.User.ID,
apiKey.Group.ID,
)
if err != nil {
AbortWithError(c, 403, "SUBSCRIPTION_NOT_FOUND", "No active subscription found for this group")
return
}
// 合并验证 + 限额检查(纯内存操作)
needsMaintenance, err := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group)
if err != nil {
code := "SUBSCRIPTION_INVALID"
status := 403
if errors.Is(err, service.ErrDailyLimitExceeded) ||
errors.Is(err, service.ErrWeeklyLimitExceeded) ||
errors.Is(err, service.ErrMonthlyLimitExceeded) {
code = "USAGE_LIMIT_EXCEEDED"
status = 429
if subErr != nil {
if !skipBilling {
AbortWithError(c, 403, "SUBSCRIPTION_NOT_FOUND", "No active subscription found for this group")
return
}
AbortWithError(c, status, code, err.Error())
return
}
// 将订阅信息存入上下文
c.Set(string(ContextKeySubscription), subscription)
// 窗口维护异步化(不阻塞请求)
// 传递独立拷贝,避免与 handler 读取 context 中的 subscription 产生 data race
if needsMaintenance {
maintenanceCopy := *subscription
subscriptionService.DoWindowMaintenance(&maintenanceCopy)
}
} else {
// 余额模式:检查用户余额
if apiKey.User.Balance <= 0 {
AbortWithError(c, 403, "INSUFFICIENT_BALANCE", "Insufficient account balance")
return
// skipBilling: 订阅不存在也放行handler 会返回可用的数据
} else {
subscription = sub
}
}
// 将API key和用户信息存入上下文
// ── 6. 计费执行skipBilling 时整块跳过) ────────────────────
if !skipBilling {
// Key 状态检查
switch apiKey.Status {
case service.StatusAPIKeyQuotaExhausted:
AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完")
return
case service.StatusAPIKeyExpired:
AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期")
return
}
// 运行时过期/配额检查(即使状态是 active也要检查时间和用量
if apiKey.IsExpired() {
AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期")
return
}
if apiKey.IsQuotaExhausted() {
AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完")
return
}
// 订阅模式:验证订阅限额
if subscription != nil {
needsMaintenance, validateErr := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group)
if validateErr != nil {
code := "SUBSCRIPTION_INVALID"
status := 403
if errors.Is(validateErr, service.ErrDailyLimitExceeded) ||
errors.Is(validateErr, service.ErrWeeklyLimitExceeded) ||
errors.Is(validateErr, service.ErrMonthlyLimitExceeded) {
code = "USAGE_LIMIT_EXCEEDED"
status = 429
}
AbortWithError(c, status, code, validateErr.Error())
return
}
// 窗口维护异步化(不阻塞请求)
if needsMaintenance {
maintenanceCopy := *subscription
subscriptionService.DoWindowMaintenance(&maintenanceCopy)
}
} else {
// 非订阅模式 或 订阅模式但 subscriptionService 未注入:回退到余额检查
if apiKey.User.Balance <= 0 {
AbortWithError(c, 403, "INSUFFICIENT_BALANCE", "Insufficient account balance")
return
}
}
}
// ── 7. 设置上下文 → Next ─────────────────────────────────────
if subscription != nil {
c.Set(string(ContextKeySubscription), subscription)
}
c.Set(string(ContextKeyAPIKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID,

View File

@@ -80,17 +80,25 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
abortWithGoogleError(c, 403, "No active subscription found for this group")
return
}
if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil {
abortWithGoogleError(c, 403, err.Error())
return
}
_ = subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription)
_ = subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription)
if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil {
abortWithGoogleError(c, 429, err.Error())
needsMaintenance, err := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group)
if err != nil {
status := 403
if errors.Is(err, service.ErrDailyLimitExceeded) ||
errors.Is(err, service.ErrWeeklyLimitExceeded) ||
errors.Is(err, service.ErrMonthlyLimitExceeded) {
status = 429
}
abortWithGoogleError(c, status, err.Error())
return
}
c.Set(string(ContextKeySubscription), subscription)
if needsMaintenance {
maintenanceCopy := *subscription
subscriptionService.DoWindowMaintenance(&maintenanceCopy)
}
} else {
if apiKey.User.Balance <= 0 {
abortWithGoogleError(c, 403, "Insufficient account balance")

View File

@@ -23,6 +23,15 @@ type fakeAPIKeyRepo struct {
updateLastUsed func(ctx context.Context, id int64, usedAt time.Time) error
}
type fakeGoogleSubscriptionRepo struct {
getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error)
updateStatus func(ctx context.Context, subscriptionID int64, status string) error
activateWindow func(ctx context.Context, id int64, start time.Time) error
resetDaily func(ctx context.Context, id int64, start time.Time) error
resetWeekly func(ctx context.Context, id int64, start time.Time) error
resetMonthly func(ctx context.Context, id int64, start time.Time) error
}
func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented")
}
@@ -47,7 +56,7 @@ func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
func (f fakeAPIKeyRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (f fakeAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
func (f fakeAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (f fakeAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
@@ -86,6 +95,94 @@ func (f fakeAPIKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt tim
}
return nil
}
func (f fakeAPIKeyRepo) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
return nil
}
func (f fakeAPIKeyRepo) ResetRateLimitWindows(ctx context.Context, id int64) error {
return nil
}
func (f fakeAPIKeyRepo) GetRateLimitData(ctx context.Context, id int64) (*service.APIKeyRateLimitData, error) {
return &service.APIKeyRateLimitData{}, nil
}
func (f fakeGoogleSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error {
return errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
if f.getActive != nil {
return f.getActive(ctx, userID, groupID)
}
return nil, errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) Update(ctx context.Context, sub *service.UserSubscription) error {
return errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
return false, errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error {
return errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error {
if f.updateStatus != nil {
return f.updateStatus(ctx, subscriptionID, status)
}
return errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error {
return errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) ActivateWindows(ctx context.Context, id int64, start time.Time) error {
if f.activateWindow != nil {
return f.activateWindow(ctx, id, start)
}
return errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) ResetDailyUsage(ctx context.Context, id int64, start time.Time) error {
if f.resetDaily != nil {
return f.resetDaily(ctx, id, start)
}
return errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) ResetWeeklyUsage(ctx context.Context, id int64, start time.Time) error {
if f.resetWeekly != nil {
return f.resetWeekly(ctx, id, start)
}
return errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) ResetMonthlyUsage(ctx context.Context, id int64, start time.Time) error {
if f.resetMonthly != nil {
return f.resetMonthly(ctx, id, start)
}
return errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
return errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
return 0, errors.New("not implemented")
}
type googleErrorResponse struct {
Error struct {
@@ -505,3 +602,85 @@ func TestApiKeyAuthWithSubscriptionGoogle_TouchesLastUsedInStandardMode(t *testi
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, 1, touchCalls)
}
func TestApiKeyAuthWithSubscriptionGoogle_SubscriptionLimitExceededReturns429(t *testing.T) {
gin.SetMode(gin.TestMode)
limit := 1.0
group := &service.Group{
ID: 77,
Name: "gemini-sub",
Status: service.StatusActive,
Platform: service.PlatformGemini,
Hydrated: true,
SubscriptionType: service.SubscriptionTypeSubscription,
DailyLimitUSD: &limit,
}
user := &service.User{
ID: 999,
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
Concurrency: 3,
}
apiKey := &service.APIKey{
ID: 501,
UserID: user.ID,
Key: "google-sub-limit",
Status: service.StatusActive,
User: user,
Group: group,
}
apiKey.GroupID = &group.ID
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
if key != apiKey.Key {
return nil, service.ErrAPIKeyNotFound
}
clone := *apiKey
return &clone, nil
},
})
now := time.Now()
sub := &service.UserSubscription{
ID: 601,
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: now.Add(24 * time.Hour),
DailyWindowStart: &now,
DailyUsageUSD: 10,
}
subscriptionService := service.NewSubscriptionService(nil, fakeGoogleSubscriptionRepo{
getActive: func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
if userID != user.ID || groupID != group.ID {
return nil, service.ErrSubscriptionNotFound
}
clone := *sub
return &clone, nil
},
updateStatus: func(ctx context.Context, subscriptionID int64, status string) error { return nil },
activateWindow: func(ctx context.Context, id int64, start time.Time) error { return nil },
resetDaily: func(ctx context.Context, id int64, start time.Time) error { return nil },
resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil },
resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil },
}, nil, nil, &config.Config{RunMode: config.RunModeStandard})
r := gin.New()
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, &config.Config{RunMode: config.RunModeStandard}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
req.Header.Set("x-goog-api-key", apiKey.Key)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusTooManyRequests, rec.Code)
var resp googleErrorResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, http.StatusTooManyRequests, resp.Error.Code)
require.Equal(t, "RESOURCE_EXHAUSTED", resp.Error.Status)
require.Contains(t, resp.Error.Message, "daily usage limit exceeded")
}

View File

@@ -537,7 +537,7 @@ func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
@@ -588,6 +588,16 @@ func (r *stubApiKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt ti
return nil
}
func (r *stubApiKeyRepo) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
return nil
}
func (r *stubApiKeyRepo) ResetRateLimitWindows(ctx context.Context, id int64) error {
return nil
}
func (r *stubApiKeyRepo) GetRateLimitData(ctx context.Context, id int64) (*service.APIKeyRateLimitData, error) {
return nil, nil
}
type stubUserSubscriptionRepo struct {
getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error)
updateStatus func(ctx context.Context, subscriptionID int64, status string) error

View File

@@ -40,7 +40,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer
cfg.JWT.AccessTokenExpireMinutes = 60
userRepo := &stubJWTUserRepo{users: users}
authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil)
authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
userSvc := service.NewUserService(userRepo, nil, nil)
mw := NewJWTAuthMiddleware(authSvc, userSvc)

View File

@@ -2,8 +2,11 @@ package middleware
import (
"context"
"net/http"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
@@ -71,3 +74,48 @@ func AbortWithError(c *gin.Context, statusCode int, code, message string) {
c.JSON(statusCode, NewErrorResponse(code, message))
c.Abort()
}
// ──────────────────────────────────────────────────────────
// RequireGroupAssignment — 未分组 Key 拦截中间件
// ──────────────────────────────────────────────────────────
// GatewayErrorWriter 定义网关错误响应格式(不同协议使用不同格式)
type GatewayErrorWriter func(c *gin.Context, status int, message string)
// AnthropicErrorWriter 按 Anthropic API 规范输出错误
func AnthropicErrorWriter(c *gin.Context, status int, message string) {
c.JSON(status, gin.H{
"type": "error",
"error": gin.H{"type": "permission_error", "message": message},
})
}
// GoogleErrorWriter 按 Google API 规范输出错误
func GoogleErrorWriter(c *gin.Context, status int, message string) {
c.JSON(status, gin.H{
"error": gin.H{
"code": status,
"message": message,
"status": googleapi.HTTPStatusToGoogleStatus(status),
},
})
}
// RequireGroupAssignment 检查 API Key 是否已分配到分组,
// 如果未分组且系统设置不允许未分组 Key 调度则返回 403。
func RequireGroupAssignment(settingService *service.SettingService, writeError GatewayErrorWriter) gin.HandlerFunc {
return func(c *gin.Context) {
apiKey, ok := GetAPIKeyFromContext(c)
if !ok || apiKey.GroupID != nil {
c.Next()
return
}
// 未分组 Key — 检查系统设置
if settingService.IsUngroupedKeySchedulingAllowed(c.Request.Context()) {
c.Next()
return
}
writeError(c, http.StatusForbidden, "API Key is not assigned to any group and cannot be used. Please contact the administrator to assign it to a group.")
c.Abort()
}
}

View File

@@ -41,9 +41,9 @@ func GetNonceFromContext(c *gin.Context) string {
}
// SecurityHeaders sets baseline security headers for all responses.
// getFrameSrc is an optional function that returns an extra origin to inject into frame-src;
// getFrameSrcOrigins is an optional function that returns extra origins to inject into frame-src;
// pass nil to disable dynamic frame-src injection.
func SecurityHeaders(cfg config.CSPConfig, getFrameSrc func() string) gin.HandlerFunc {
func SecurityHeaders(cfg config.CSPConfig, getFrameSrcOrigins func() []string) gin.HandlerFunc {
policy := strings.TrimSpace(cfg.Policy)
if policy == "" {
policy = config.DefaultCSPPolicy
@@ -54,15 +54,21 @@ func SecurityHeaders(cfg config.CSPConfig, getFrameSrc func() string) gin.Handle
return func(c *gin.Context) {
finalPolicy := policy
if getFrameSrc != nil {
if origin := getFrameSrc(); origin != "" {
finalPolicy = addToDirective(finalPolicy, "frame-src", origin)
if getFrameSrcOrigins != nil {
for _, origin := range getFrameSrcOrigins() {
if origin != "" {
finalPolicy = addToDirective(finalPolicy, "frame-src", origin)
}
}
}
c.Header("X-Content-Type-Options", "nosniff")
c.Header("X-Frame-Options", "DENY")
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
if isAPIRoutePath(c) {
c.Next()
return
}
if cfg.Enabled {
// Generate nonce for this request
@@ -80,6 +86,18 @@ func SecurityHeaders(cfg config.CSPConfig, getFrameSrc func() string) gin.Handle
}
}
func isAPIRoutePath(c *gin.Context) bool {
if c == nil || c.Request == nil || c.Request.URL == nil {
return false
}
path := c.Request.URL.Path
return strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/v1beta/") ||
strings.HasPrefix(path, "/antigravity/") ||
strings.HasPrefix(path, "/sora/") ||
strings.HasPrefix(path, "/responses")
}
// enhanceCSPPolicy ensures the CSP policy includes nonce support and Cloudflare Insights domain.
// This allows the application to work correctly even if the config file has an older CSP policy.
func enhanceCSPPolicy(policy string) string {

View File

@@ -131,6 +131,26 @@ func TestSecurityHeaders(t *testing.T) {
assert.Contains(t, csp, CloudflareInsightsDomain)
})
t.Run("api_route_skips_csp_nonce_generation", func(t *testing.T) {
cfg := config.CSPConfig{
Enabled: true,
Policy: "default-src 'self'; script-src 'self' __CSP_NONCE__",
}
middleware := SecurityHeaders(cfg, nil)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
middleware(c)
assert.Equal(t, "nosniff", w.Header().Get("X-Content-Type-Options"))
assert.Equal(t, "DENY", w.Header().Get("X-Frame-Options"))
assert.Equal(t, "strict-origin-when-cross-origin", w.Header().Get("Referrer-Policy"))
assert.Empty(t, w.Header().Get("Content-Security-Policy"))
assert.Empty(t, GetNonceFromContext(c))
})
t.Run("csp_enabled_with_nonce_placeholder", func(t *testing.T) {
cfg := config.CSPConfig{
Enabled: true,

View File

@@ -3,8 +3,6 @@ package server
import (
"context"
"log"
"net/url"
"strings"
"sync/atomic"
"time"
@@ -19,24 +17,7 @@ import (
"github.com/redis/go-redis/v9"
)
// extractOrigin returns the scheme+host origin from rawURL, or "" on error.
// Only http and https schemes are accepted; other values (e.g. "//host/path") return "".
func extractOrigin(rawURL string) string {
rawURL = strings.TrimSpace(rawURL)
if rawURL == "" {
return ""
}
u, err := url.Parse(rawURL)
if err != nil || u.Host == "" {
return ""
}
if u.Scheme != "http" && u.Scheme != "https" {
return ""
}
return u.Scheme + "://" + u.Host
}
const paymentOriginFetchTimeout = 5 * time.Second
const frameSrcRefreshTimeout = 5 * time.Second
// SetupRouter 配置路由器中间件和路由
func SetupRouter(
@@ -52,38 +33,32 @@ func SetupRouter(
cfg *config.Config,
redisClient *redis.Client,
) *gin.Engine {
// 缓存 purchase_subscription_url 的 origin用于动态注入 CSP frame-src
var cachedPaymentOrigin atomic.Pointer[string]
empty := ""
cachedPaymentOrigin.Store(&empty)
// 缓存 iframe 页面的 origin 列表,用于动态注入 CSP frame-src
var cachedFrameOrigins atomic.Pointer[[]string]
emptyOrigins := []string{}
cachedFrameOrigins.Store(&emptyOrigins)
refreshPaymentOrigin := func() {
ctx, cancel := context.WithTimeout(context.Background(), paymentOriginFetchTimeout)
refreshFrameOrigins := func() {
ctx, cancel := context.WithTimeout(context.Background(), frameSrcRefreshTimeout)
defer cancel()
settings, err := settingService.GetPublicSettings(ctx)
origins, err := settingService.GetFrameSrcOrigins(ctx)
if err != nil {
// 获取失败时保留已有缓存,避免 frame-src 被意外清空
return
}
if settings.PurchaseSubscriptionEnabled {
origin := extractOrigin(settings.PurchaseSubscriptionURL)
cachedPaymentOrigin.Store(&origin)
} else {
e := ""
cachedPaymentOrigin.Store(&e)
}
cachedFrameOrigins.Store(&origins)
}
refreshPaymentOrigin() // 启动时初始化
refreshFrameOrigins() // 启动时初始化
// 应用中间件
r.Use(middleware2.RequestLogger())
r.Use(middleware2.Logger())
r.Use(middleware2.CORS(cfg.CORS))
r.Use(middleware2.SecurityHeaders(cfg.Security.CSP, func() string {
if p := cachedPaymentOrigin.Load(); p != nil {
r.Use(middleware2.SecurityHeaders(cfg.Security.CSP, func() []string {
if p := cachedFrameOrigins.Load(); p != nil {
return *p
}
return ""
return nil
}))
// Serve embedded frontend with settings injection if available
@@ -92,21 +67,21 @@ func SetupRouter(
if err != nil {
log.Printf("Warning: Failed to create frontend server with settings injection: %v, using legacy mode", err)
r.Use(web.ServeEmbeddedFrontend())
settingService.SetOnUpdateCallback(refreshPaymentOrigin)
settingService.SetOnUpdateCallback(refreshFrameOrigins)
} else {
// Register combined callback: invalidate HTML cache + refresh payment origin
// Register combined callback: invalidate HTML cache + refresh frame origins
settingService.SetOnUpdateCallback(func() {
frontendServer.InvalidateCache()
refreshPaymentOrigin()
refreshFrameOrigins()
})
r.Use(frontendServer.Middleware())
}
} else {
settingService.SetOnUpdateCallback(refreshPaymentOrigin)
settingService.SetOnUpdateCallback(refreshFrameOrigins)
}
// 注册路由
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg, redisClient)
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)
return r
}
@@ -121,6 +96,7 @@ func registerRoutes(
apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService,
opsService *service.OpsService,
settingService *service.SettingService,
cfg *config.Config,
redisClient *redis.Client,
) {
@@ -133,6 +109,7 @@ func registerRoutes(
// 注册各模块路由
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient)
routes.RegisterUserRoutes(v1, h, jwtAuth)
routes.RegisterSoraClientRoutes(v1, h, jwtAuth)
routes.RegisterAdminRoutes(v1, h, adminAuth)
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg)
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
}

View File

@@ -55,6 +55,9 @@ func RegisterAdminRoutes(
// 系统设置
registerSettingsRoutes(admin, h)
// 数据管理
registerDataManagementRoutes(admin, h)
// 运维监控Ops
registerOpsRoutes(admin, h)
@@ -72,6 +75,16 @@ func RegisterAdminRoutes(
// 错误透传规则管理
registerErrorPassthroughRoutes(admin, h)
// API Key 管理
registerAdminAPIKeyRoutes(admin, h)
}
}
func registerAdminAPIKeyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
apiKeys := admin.Group("/api-keys")
{
apiKeys.PUT("/:id", h.Admin.APIKey.UpdateGroup)
}
}
@@ -155,6 +168,7 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
ops.GET("/system-logs/health", h.Admin.Ops.GetSystemLogIngestionHealth)
// Dashboard (vNext - raw path for MVP)
ops.GET("/dashboard/snapshot-v2", h.Admin.Ops.GetDashboardSnapshotV2)
ops.GET("/dashboard/overview", h.Admin.Ops.GetDashboardOverview)
ops.GET("/dashboard/throughput-trend", h.Admin.Ops.GetDashboardThroughputTrend)
ops.GET("/dashboard/latency-histogram", h.Admin.Ops.GetDashboardLatencyHistogram)
@@ -167,6 +181,7 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
dashboard := admin.Group("/dashboard")
{
dashboard.GET("/snapshot-v2", h.Admin.Dashboard.GetSnapshotV2)
dashboard.GET("/stats", h.Admin.Dashboard.GetStats)
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
@@ -232,6 +247,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts.POST("/:id/clear-error", h.Admin.Account.ClearError)
accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
accounts.POST("/today-stats/batch", h.Admin.Account.GetBatchTodayStats)
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
accounts.GET("/:id/temp-unschedulable", h.Admin.Account.GetTempUnschedulable)
accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable)
@@ -372,6 +388,38 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
// 流超时处理配置
adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings)
adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings)
// Sora S3 存储配置
adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings)
adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings)
adminSettings.POST("/sora-s3/test", h.Admin.Setting.TestSoraS3Connection)
adminSettings.GET("/sora-s3/profiles", h.Admin.Setting.ListSoraS3Profiles)
adminSettings.POST("/sora-s3/profiles", h.Admin.Setting.CreateSoraS3Profile)
adminSettings.PUT("/sora-s3/profiles/:profile_id", h.Admin.Setting.UpdateSoraS3Profile)
adminSettings.DELETE("/sora-s3/profiles/:profile_id", h.Admin.Setting.DeleteSoraS3Profile)
adminSettings.POST("/sora-s3/profiles/:profile_id/activate", h.Admin.Setting.SetActiveSoraS3Profile)
}
}
func registerDataManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
dataManagement := admin.Group("/data-management")
{
dataManagement.GET("/agent/health", h.Admin.DataManagement.GetAgentHealth)
dataManagement.GET("/config", h.Admin.DataManagement.GetConfig)
dataManagement.PUT("/config", h.Admin.DataManagement.UpdateConfig)
dataManagement.GET("/sources/:source_type/profiles", h.Admin.DataManagement.ListSourceProfiles)
dataManagement.POST("/sources/:source_type/profiles", h.Admin.DataManagement.CreateSourceProfile)
dataManagement.PUT("/sources/:source_type/profiles/:profile_id", h.Admin.DataManagement.UpdateSourceProfile)
dataManagement.DELETE("/sources/:source_type/profiles/:profile_id", h.Admin.DataManagement.DeleteSourceProfile)
dataManagement.POST("/sources/:source_type/profiles/:profile_id/activate", h.Admin.DataManagement.SetActiveSourceProfile)
dataManagement.POST("/s3/test", h.Admin.DataManagement.TestS3)
dataManagement.GET("/s3/profiles", h.Admin.DataManagement.ListS3Profiles)
dataManagement.POST("/s3/profiles", h.Admin.DataManagement.CreateS3Profile)
dataManagement.PUT("/s3/profiles/:profile_id", h.Admin.DataManagement.UpdateS3Profile)
dataManagement.DELETE("/s3/profiles/:profile_id", h.Admin.DataManagement.DeleteS3Profile)
dataManagement.POST("/s3/profiles/:profile_id/activate", h.Admin.DataManagement.SetActiveS3Profile)
dataManagement.POST("/backups", h.Admin.DataManagement.CreateBackupJob)
dataManagement.GET("/backups", h.Admin.DataManagement.ListBackupJobs)
dataManagement.GET("/backups/:job_id", h.Admin.DataManagement.GetBackupJob)
}
}

View File

@@ -19,6 +19,7 @@ func RegisterGatewayRoutes(
apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService,
opsService *service.OpsService,
settingService *service.SettingService,
cfg *config.Config,
) {
bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize)
@@ -30,12 +31,17 @@ func RegisterGatewayRoutes(
clientRequestID := middleware.ClientRequestID()
opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService)
// 未分组 Key 拦截中间件(按协议格式区分错误响应)
requireGroupAnthropic := middleware.RequireGroupAssignment(settingService, middleware.AnthropicErrorWriter)
requireGroupGoogle := middleware.RequireGroupAssignment(settingService, middleware.GoogleErrorWriter)
// API网关Claude API兼容
gateway := r.Group("/v1")
gateway.Use(bodyLimit)
gateway.Use(clientRequestID)
gateway.Use(opsErrorLogger)
gateway.Use(gin.HandlerFunc(apiKeyAuth))
gateway.Use(requireGroupAnthropic)
{
gateway.POST("/messages", h.Gateway.Messages)
gateway.POST("/messages/count_tokens", h.Gateway.CountTokens)
@@ -43,6 +49,7 @@ func RegisterGatewayRoutes(
gateway.GET("/usage", h.Gateway.Usage)
// OpenAI Responses API
gateway.POST("/responses", h.OpenAIGateway.Responses)
gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket)
// 明确阻止旧协议入口OpenAI 仅支持 Responses API避免客户端误解为会自动路由到其它平台。
gateway.POST("/chat/completions", func(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{
@@ -60,6 +67,7 @@ func RegisterGatewayRoutes(
gemini.Use(clientRequestID)
gemini.Use(opsErrorLogger)
gemini.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
gemini.Use(requireGroupGoogle)
{
gemini.GET("/models", h.Gateway.GeminiV1BetaListModels)
gemini.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
@@ -68,10 +76,11 @@ func RegisterGatewayRoutes(
}
// OpenAI Responses API不带v1前缀的别名
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses)
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
// Antigravity 模型列表
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), h.Gateway.AntigravityModels)
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels)
// Antigravity 专用路由(仅使用 antigravity 账户,不混合调度)
antigravityV1 := r.Group("/antigravity/v1")
@@ -80,6 +89,7 @@ func RegisterGatewayRoutes(
antigravityV1.Use(opsErrorLogger)
antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1.Use(gin.HandlerFunc(apiKeyAuth))
antigravityV1.Use(requireGroupAnthropic)
{
antigravityV1.POST("/messages", h.Gateway.Messages)
antigravityV1.POST("/messages/count_tokens", h.Gateway.CountTokens)
@@ -93,6 +103,7 @@ func RegisterGatewayRoutes(
antigravityV1Beta.Use(opsErrorLogger)
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1Beta.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
antigravityV1Beta.Use(requireGroupGoogle)
{
antigravityV1Beta.GET("/models", h.Gateway.GeminiV1BetaListModels)
antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
@@ -106,6 +117,7 @@ func RegisterGatewayRoutes(
soraV1.Use(opsErrorLogger)
soraV1.Use(middleware.ForcePlatform(service.PlatformSora))
soraV1.Use(gin.HandlerFunc(apiKeyAuth))
soraV1.Use(requireGroupAnthropic)
{
soraV1.POST("/chat/completions", h.SoraGateway.ChatCompletions)
soraV1.GET("/models", h.Gateway.Models)

View File

@@ -0,0 +1,33 @@
package routes
import (
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin"
)
// RegisterSoraClientRoutes 注册 Sora 客户端 API 路由(需要用户认证)。
func RegisterSoraClientRoutes(
v1 *gin.RouterGroup,
h *handler.Handlers,
jwtAuth middleware.JWTAuthMiddleware,
) {
if h.SoraClient == nil {
return
}
authenticated := v1.Group("/sora")
authenticated.Use(gin.HandlerFunc(jwtAuth))
{
authenticated.POST("/generate", h.SoraClient.Generate)
authenticated.GET("/generations", h.SoraClient.ListGenerations)
authenticated.GET("/generations/:id", h.SoraClient.GetGeneration)
authenticated.DELETE("/generations/:id", h.SoraClient.DeleteGeneration)
authenticated.POST("/generations/:id/cancel", h.SoraClient.CancelGeneration)
authenticated.POST("/generations/:id/save", h.SoraClient.SaveToStorage)
authenticated.GET("/quota", h.SoraClient.GetQuota)
authenticated.GET("/models", h.SoraClient.GetModels)
authenticated.GET("/storage-status", h.SoraClient.GetStorageStatus)
}
}