mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-17 13:24:45 +08:00
Merge tag 'v0.1.90' into merge/upstream-v0.1.90
注册邮箱域名白名单策略上线,后台大数据场景性能大幅优化。 - 注册邮箱域名白名单:支持管理员配置允许注册的邮箱域名策略 - Keys 页面表单筛选:用户 /keys 页面支持按条件筛选 API Key - Settings 页面分 Tab 拆分:管理后台设置页面按功能模块分 Tab 展示 - 后台大数据场景加载性能优化:仪表盘/用户/账号/Ops 页面大数据集加载显著提速 - Usage 大表分页优化:默认避免全量 COUNT(*),大幅降低分页查询耗时 - 消除重复的 normalizeAccountIDList,补充新增组件的单元测试 - 清理无用文件和过时文档,精简项目结构 - EmailVerifyView 硬编码英文字符串替换为 i18n 调用 - 修复 Anthropic 平台无限流重置时间的 429 误标记账号限流问题 - 修复自定义菜单页面管理员视角菜单不生效问题 - 修复 Ops 错误详情弹窗未展示真实上游 payload 的问题 - 修复充值/订阅菜单 icon 显示问题 # Conflicts: # .gitignore # backend/cmd/server/VERSION # backend/ent/group.go # backend/ent/runtime/runtime.go # backend/ent/schema/group.go # backend/go.sum # backend/internal/handler/admin/account_handler.go # backend/internal/handler/admin/dashboard_handler.go # backend/internal/pkg/usagestats/usage_log_types.go # backend/internal/repository/group_repo.go # backend/internal/repository/usage_log_repo.go # backend/internal/server/middleware/security_headers.go # backend/internal/server/router.go # backend/internal/service/account_usage_service.go # backend/internal/service/admin_service_bulk_update_test.go # backend/internal/service/dashboard_service.go # backend/internal/service/gateway_service.go # frontend/src/api/admin/dashboard.ts # frontend/src/components/account/BulkEditAccountModal.vue # frontend/src/components/charts/GroupDistributionChart.vue # frontend/src/components/layout/AppSidebar.vue # frontend/src/i18n/locales/en.ts # frontend/src/i18n/locales/zh.ts # frontend/src/views/admin/GroupsView.vue # frontend/src/views/admin/SettingsView.vue # frontend/src/views/admin/UsageView.vue # frontend/src/views/user/PurchaseSubscriptionView.vue
This commit is contained in:
@@ -86,6 +86,15 @@ func TestAPIContracts(t *testing.T) {
|
||||
"last_used_at": null,
|
||||
"quota": 0,
|
||||
"quota_used": 0,
|
||||
"rate_limit_5h": 0,
|
||||
"rate_limit_1d": 0,
|
||||
"rate_limit_7d": 0,
|
||||
"usage_5h": 0,
|
||||
"usage_1d": 0,
|
||||
"usage_7d": 0,
|
||||
"window_5h_start": null,
|
||||
"window_1d_start": null,
|
||||
"window_7d_start": null,
|
||||
"expires_at": null,
|
||||
"created_at": "2025-01-02T03:04:05Z",
|
||||
"updated_at": "2025-01-02T03:04:05Z"
|
||||
@@ -126,6 +135,15 @@ func TestAPIContracts(t *testing.T) {
|
||||
"last_used_at": null,
|
||||
"quota": 0,
|
||||
"quota_used": 0,
|
||||
"rate_limit_5h": 0,
|
||||
"rate_limit_1d": 0,
|
||||
"rate_limit_7d": 0,
|
||||
"usage_5h": 0,
|
||||
"usage_1d": 0,
|
||||
"usage_7d": 0,
|
||||
"window_5h_start": null,
|
||||
"window_1d_start": null,
|
||||
"window_7d_start": null,
|
||||
"expires_at": null,
|
||||
"created_at": "2025-01-02T03:04:05Z",
|
||||
"updated_at": "2025-01-02T03:04:05Z"
|
||||
@@ -186,11 +204,12 @@ func TestAPIContracts(t *testing.T) {
|
||||
"image_price_1k": null,
|
||||
"image_price_2k": null,
|
||||
"image_price_4k": null,
|
||||
"sora_image_price_360": null,
|
||||
"sora_image_price_540": null,
|
||||
"sora_video_price_per_request": null,
|
||||
"sora_video_price_per_request_hd": null,
|
||||
"claude_code_only": false,
|
||||
"sora_image_price_360": null,
|
||||
"sora_image_price_540": null,
|
||||
"sora_storage_quota_bytes": 0,
|
||||
"sora_video_price_per_request": null,
|
||||
"sora_video_price_per_request_hd": null,
|
||||
"claude_code_only": false,
|
||||
"fallback_group_id": null,
|
||||
"fallback_group_id_on_invalid_request": null,
|
||||
"created_at": "2025-01-02T03:04:05Z",
|
||||
@@ -384,10 +403,12 @@ func TestAPIContracts(t *testing.T) {
|
||||
"user_id": 1,
|
||||
"api_key_id": 100,
|
||||
"account_id": 200,
|
||||
"request_id": "req_123",
|
||||
"model": "claude-3",
|
||||
"group_id": null,
|
||||
"subscription_id": null,
|
||||
"request_id": "req_123",
|
||||
"model": "claude-3",
|
||||
"request_type": "stream",
|
||||
"openai_ws_mode": false,
|
||||
"group_id": null,
|
||||
"subscription_id": null,
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 20,
|
||||
"cache_creation_tokens": 1,
|
||||
@@ -425,9 +446,10 @@ func TestAPIContracts(t *testing.T) {
|
||||
setup: func(t *testing.T, deps *contractDeps) {
|
||||
t.Helper()
|
||||
deps.settingRepo.SetAll(map[string]string{
|
||||
service.SettingKeyRegistrationEnabled: "true",
|
||||
service.SettingKeyEmailVerifyEnabled: "false",
|
||||
service.SettingKeyPromoCodeEnabled: "true",
|
||||
service.SettingKeyRegistrationEnabled: "true",
|
||||
service.SettingKeyEmailVerifyEnabled: "false",
|
||||
service.SettingKeyRegistrationEmailSuffixWhitelist: "[]",
|
||||
service.SettingKeyPromoCodeEnabled: "true",
|
||||
|
||||
service.SettingKeySMTPHost: "smtp.example.com",
|
||||
service.SettingKeySMTPPort: "587",
|
||||
@@ -466,6 +488,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"data": {
|
||||
"registration_enabled": true,
|
||||
"email_verify_enabled": false,
|
||||
"registration_email_suffix_whitelist": [],
|
||||
"promo_code_enabled": true,
|
||||
"password_reset_enabled": false,
|
||||
"totp_enabled": false,
|
||||
@@ -496,18 +519,23 @@ func TestAPIContracts(t *testing.T) {
|
||||
"doc_url": "https://docs.example.com",
|
||||
"default_concurrency": 5,
|
||||
"default_balance": 1.25,
|
||||
"default_subscriptions": [],
|
||||
"enable_model_fallback": false,
|
||||
"fallback_model_anthropic": "claude-3-5-sonnet-20241022",
|
||||
"fallback_model_antigravity": "gemini-2.5-pro",
|
||||
"fallback_model_gemini": "gemini-2.5-pro",
|
||||
"fallback_model_openai": "gpt-4o",
|
||||
"enable_identity_patch": true,
|
||||
"identity_patch_prompt": "",
|
||||
"invitation_code_enabled": false,
|
||||
"home_content": "",
|
||||
"fallback_model_openai": "gpt-4o",
|
||||
"enable_identity_patch": true,
|
||||
"identity_patch_prompt": "",
|
||||
"sora_client_enabled": false,
|
||||
"invitation_code_enabled": false,
|
||||
"home_content": "",
|
||||
"hide_ccs_import_button": false,
|
||||
"purchase_subscription_enabled": false,
|
||||
"purchase_subscription_url": ""
|
||||
"purchase_subscription_url": "",
|
||||
"min_claude_code_version": "",
|
||||
"allow_ungrouped_key_scheduling": false,
|
||||
"custom_menu_items": []
|
||||
}
|
||||
}`,
|
||||
},
|
||||
@@ -615,12 +643,12 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
settingRepo := newStubSettingRepo()
|
||||
settingService := service.NewSettingService(settingRepo, cfg)
|
||||
|
||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil)
|
||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
|
||||
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil)
|
||||
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
jwtAuth := func(c *gin.Context) {
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
|
||||
@@ -775,6 +803,10 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
@@ -1016,6 +1048,14 @@ func (s *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Conte
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
@@ -1373,7 +1413,7 @@ func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
ids := make([]int64, 0, len(r.byID))
|
||||
for id := range r.byID {
|
||||
if r.byID[id].UserID == userID {
|
||||
@@ -1487,6 +1527,16 @@ func (r *stubApiKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt ti
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubApiKeyRepo) ResetRateLimitWindows(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubApiKeyRepo) GetRateLimitData(ctx context.Context, id int64) (*service.APIKeyRateLimitData, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type stubUsageLogRepo struct {
|
||||
userLogs map[int64][]service.UsageLog
|
||||
}
|
||||
@@ -1555,11 +1605,15 @@ func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.D
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
|
||||
func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
||||
func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
|
||||
authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil)
|
||||
authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
admin := &service.User{
|
||||
ID: 1,
|
||||
@@ -181,6 +181,10 @@ func (s *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
|
||||
panic("unexpected RemoveGroupFromAllowedGroups call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
|
||||
panic("unexpected AddGroupToAllowedGroups call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
|
||||
panic("unexpected UpdateTotpSecret call")
|
||||
}
|
||||
|
||||
@@ -19,8 +19,16 @@ func NewAPIKeyAuthMiddleware(apiKeyService *service.APIKeyService, subscriptionS
|
||||
}
|
||||
|
||||
// apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证)
|
||||
//
|
||||
// 中间件职责分为两层:
|
||||
// - 鉴权(Authentication):验证 Key 有效性、用户状态、IP 限制 —— 始终执行
|
||||
// - 计费执行(Billing Enforcement):过期/配额/订阅/余额检查 —— skipBilling 时整块跳过
|
||||
//
|
||||
// /v1/usage 端点只需鉴权,不需要计费执行(允许过期/配额耗尽的 Key 查询自身用量)。
|
||||
func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// ── 1. 提取 API Key ──────────────────────────────────────────
|
||||
|
||||
queryKey := strings.TrimSpace(c.Query("key"))
|
||||
queryApiKey := strings.TrimSpace(c.Query("api_key"))
|
||||
if queryKey != "" || queryApiKey != "" {
|
||||
@@ -56,7 +64,8 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
return
|
||||
}
|
||||
|
||||
// 从数据库验证API key
|
||||
// ── 2. 验证 Key 存在 ─────────────────────────────────────────
|
||||
|
||||
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrAPIKeyNotFound) {
|
||||
@@ -67,29 +76,13 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
return
|
||||
}
|
||||
|
||||
// 检查API key是否激活
|
||||
if !apiKey.IsActive() {
|
||||
// Provide more specific error message based on status
|
||||
switch apiKey.Status {
|
||||
case service.StatusAPIKeyQuotaExhausted:
|
||||
AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完")
|
||||
case service.StatusAPIKeyExpired:
|
||||
AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期")
|
||||
default:
|
||||
AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled")
|
||||
}
|
||||
return
|
||||
}
|
||||
// ── 3. 基础鉴权(始终执行) ─────────────────────────────────
|
||||
|
||||
// 检查API Key是否过期(即使状态是active,也要检查时间)
|
||||
if apiKey.IsExpired() {
|
||||
AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查API Key配额是否耗尽
|
||||
if apiKey.IsQuotaExhausted() {
|
||||
AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完")
|
||||
// disabled / 未知状态 → 无条件拦截(expired 和 quota_exhausted 留给计费阶段)
|
||||
if !apiKey.IsActive() &&
|
||||
apiKey.Status != service.StatusAPIKeyExpired &&
|
||||
apiKey.Status != service.StatusAPIKeyQuotaExhausted {
|
||||
AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -97,7 +90,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
// 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制
|
||||
if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 {
|
||||
clientIP := ip.GetTrustedClientIP(c)
|
||||
allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist)
|
||||
allowed, _ := ip.CheckIPRestrictionWithCompiledRules(clientIP, apiKey.CompiledIPWhitelist, apiKey.CompiledIPBlacklist)
|
||||
if !allowed {
|
||||
AbortWithError(c, 403, "ACCESS_DENIED", "Access denied")
|
||||
return
|
||||
@@ -116,8 +109,9 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
return
|
||||
}
|
||||
|
||||
// ── 4. SimpleMode → early return ─────────────────────────────
|
||||
|
||||
if cfg.RunMode == config.RunModeSimple {
|
||||
// 简易模式:跳过余额和订阅检查,但仍需设置必要的上下文
|
||||
c.Set(string(ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: apiKey.User.ID,
|
||||
@@ -130,54 +124,89 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
return
|
||||
}
|
||||
|
||||
// 判断计费方式:订阅模式 vs 余额模式
|
||||
// ── 5. 加载订阅(订阅模式时始终加载) ───────────────────────
|
||||
|
||||
// skipBilling: /v1/usage 只需鉴权,跳过所有计费执行
|
||||
skipBilling := c.Request.URL.Path == "/v1/usage"
|
||||
|
||||
var subscription *service.UserSubscription
|
||||
isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
|
||||
|
||||
if isSubscriptionType && subscriptionService != nil {
|
||||
// 订阅模式:获取订阅(L1 缓存 + singleflight)
|
||||
subscription, err := subscriptionService.GetActiveSubscription(
|
||||
sub, subErr := subscriptionService.GetActiveSubscription(
|
||||
c.Request.Context(),
|
||||
apiKey.User.ID,
|
||||
apiKey.Group.ID,
|
||||
)
|
||||
if err != nil {
|
||||
AbortWithError(c, 403, "SUBSCRIPTION_NOT_FOUND", "No active subscription found for this group")
|
||||
return
|
||||
}
|
||||
|
||||
// 合并验证 + 限额检查(纯内存操作)
|
||||
needsMaintenance, err := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group)
|
||||
if err != nil {
|
||||
code := "SUBSCRIPTION_INVALID"
|
||||
status := 403
|
||||
if errors.Is(err, service.ErrDailyLimitExceeded) ||
|
||||
errors.Is(err, service.ErrWeeklyLimitExceeded) ||
|
||||
errors.Is(err, service.ErrMonthlyLimitExceeded) {
|
||||
code = "USAGE_LIMIT_EXCEEDED"
|
||||
status = 429
|
||||
if subErr != nil {
|
||||
if !skipBilling {
|
||||
AbortWithError(c, 403, "SUBSCRIPTION_NOT_FOUND", "No active subscription found for this group")
|
||||
return
|
||||
}
|
||||
AbortWithError(c, status, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 将订阅信息存入上下文
|
||||
c.Set(string(ContextKeySubscription), subscription)
|
||||
|
||||
// 窗口维护异步化(不阻塞请求)
|
||||
// 传递独立拷贝,避免与 handler 读取 context 中的 subscription 产生 data race
|
||||
if needsMaintenance {
|
||||
maintenanceCopy := *subscription
|
||||
subscriptionService.DoWindowMaintenance(&maintenanceCopy)
|
||||
}
|
||||
} else {
|
||||
// 余额模式:检查用户余额
|
||||
if apiKey.User.Balance <= 0 {
|
||||
AbortWithError(c, 403, "INSUFFICIENT_BALANCE", "Insufficient account balance")
|
||||
return
|
||||
// skipBilling: 订阅不存在也放行,handler 会返回可用的数据
|
||||
} else {
|
||||
subscription = sub
|
||||
}
|
||||
}
|
||||
|
||||
// 将API key和用户信息存入上下文
|
||||
// ── 6. 计费执行(skipBilling 时整块跳过) ────────────────────
|
||||
|
||||
if !skipBilling {
|
||||
// Key 状态检查
|
||||
switch apiKey.Status {
|
||||
case service.StatusAPIKeyQuotaExhausted:
|
||||
AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完")
|
||||
return
|
||||
case service.StatusAPIKeyExpired:
|
||||
AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期")
|
||||
return
|
||||
}
|
||||
|
||||
// 运行时过期/配额检查(即使状态是 active,也要检查时间和用量)
|
||||
if apiKey.IsExpired() {
|
||||
AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期")
|
||||
return
|
||||
}
|
||||
if apiKey.IsQuotaExhausted() {
|
||||
AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完")
|
||||
return
|
||||
}
|
||||
|
||||
// 订阅模式:验证订阅限额
|
||||
if subscription != nil {
|
||||
needsMaintenance, validateErr := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group)
|
||||
if validateErr != nil {
|
||||
code := "SUBSCRIPTION_INVALID"
|
||||
status := 403
|
||||
if errors.Is(validateErr, service.ErrDailyLimitExceeded) ||
|
||||
errors.Is(validateErr, service.ErrWeeklyLimitExceeded) ||
|
||||
errors.Is(validateErr, service.ErrMonthlyLimitExceeded) {
|
||||
code = "USAGE_LIMIT_EXCEEDED"
|
||||
status = 429
|
||||
}
|
||||
AbortWithError(c, status, code, validateErr.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 窗口维护异步化(不阻塞请求)
|
||||
if needsMaintenance {
|
||||
maintenanceCopy := *subscription
|
||||
subscriptionService.DoWindowMaintenance(&maintenanceCopy)
|
||||
}
|
||||
} else {
|
||||
// 非订阅模式 或 订阅模式但 subscriptionService 未注入:回退到余额检查
|
||||
if apiKey.User.Balance <= 0 {
|
||||
AbortWithError(c, 403, "INSUFFICIENT_BALANCE", "Insufficient account balance")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── 7. 设置上下文 → Next ─────────────────────────────────────
|
||||
|
||||
if subscription != nil {
|
||||
c.Set(string(ContextKeySubscription), subscription)
|
||||
}
|
||||
c.Set(string(ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: apiKey.User.ID,
|
||||
|
||||
@@ -80,17 +80,25 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
|
||||
abortWithGoogleError(c, 403, "No active subscription found for this group")
|
||||
return
|
||||
}
|
||||
if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil {
|
||||
abortWithGoogleError(c, 403, err.Error())
|
||||
return
|
||||
}
|
||||
_ = subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription)
|
||||
_ = subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription)
|
||||
if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil {
|
||||
abortWithGoogleError(c, 429, err.Error())
|
||||
|
||||
needsMaintenance, err := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group)
|
||||
if err != nil {
|
||||
status := 403
|
||||
if errors.Is(err, service.ErrDailyLimitExceeded) ||
|
||||
errors.Is(err, service.ErrWeeklyLimitExceeded) ||
|
||||
errors.Is(err, service.ErrMonthlyLimitExceeded) {
|
||||
status = 429
|
||||
}
|
||||
abortWithGoogleError(c, status, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.Set(string(ContextKeySubscription), subscription)
|
||||
|
||||
if needsMaintenance {
|
||||
maintenanceCopy := *subscription
|
||||
subscriptionService.DoWindowMaintenance(&maintenanceCopy)
|
||||
}
|
||||
} else {
|
||||
if apiKey.User.Balance <= 0 {
|
||||
abortWithGoogleError(c, 403, "Insufficient account balance")
|
||||
|
||||
@@ -23,6 +23,15 @@ type fakeAPIKeyRepo struct {
|
||||
updateLastUsed func(ctx context.Context, id int64, usedAt time.Time) error
|
||||
}
|
||||
|
||||
type fakeGoogleSubscriptionRepo struct {
|
||||
getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error)
|
||||
updateStatus func(ctx context.Context, subscriptionID int64, status string) error
|
||||
activateWindow func(ctx context.Context, id int64, start time.Time) error
|
||||
resetDaily func(ctx context.Context, id int64, start time.Time) error
|
||||
resetWeekly func(ctx context.Context, id int64, start time.Time) error
|
||||
resetMonthly func(ctx context.Context, id int64, start time.Time) error
|
||||
}
|
||||
|
||||
func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
@@ -47,7 +56,7 @@ func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
|
||||
func (f fakeAPIKeyRepo) Delete(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
func (f fakeAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
@@ -86,6 +95,94 @@ func (f fakeAPIKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt tim
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (f fakeAPIKeyRepo) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
|
||||
return nil
|
||||
}
|
||||
func (f fakeAPIKeyRepo) ResetRateLimitWindows(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (f fakeAPIKeyRepo) GetRateLimitData(ctx context.Context, id int64) (*service.APIKeyRateLimitData, error) {
|
||||
return &service.APIKeyRateLimitData{}, nil
|
||||
}
|
||||
|
||||
func (f fakeGoogleSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
|
||||
if f.getActive != nil {
|
||||
return f.getActive(ctx, userID, groupID)
|
||||
}
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) Update(ctx context.Context, sub *service.UserSubscription) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) Delete(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
|
||||
return false, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error {
|
||||
if f.updateStatus != nil {
|
||||
return f.updateStatus(ctx, subscriptionID, status)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) ActivateWindows(ctx context.Context, id int64, start time.Time) error {
|
||||
if f.activateWindow != nil {
|
||||
return f.activateWindow(ctx, id, start)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) ResetDailyUsage(ctx context.Context, id int64, start time.Time) error {
|
||||
if f.resetDaily != nil {
|
||||
return f.resetDaily(ctx, id, start)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) ResetWeeklyUsage(ctx context.Context, id int64, start time.Time) error {
|
||||
if f.resetWeekly != nil {
|
||||
return f.resetWeekly(ctx, id, start)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) ResetMonthlyUsage(ctx context.Context, id int64, start time.Time) error {
|
||||
if f.resetMonthly != nil {
|
||||
return f.resetMonthly(ctx, id, start)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
type googleErrorResponse struct {
|
||||
Error struct {
|
||||
@@ -505,3 +602,85 @@ func TestApiKeyAuthWithSubscriptionGoogle_TouchesLastUsedInStandardMode(t *testi
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, 1, touchCalls)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_SubscriptionLimitExceededReturns429(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
limit := 1.0
|
||||
group := &service.Group{
|
||||
ID: 77,
|
||||
Name: "gemini-sub",
|
||||
Status: service.StatusActive,
|
||||
Platform: service.PlatformGemini,
|
||||
Hydrated: true,
|
||||
SubscriptionType: service.SubscriptionTypeSubscription,
|
||||
DailyLimitUSD: &limit,
|
||||
}
|
||||
user := &service.User{
|
||||
ID: 999,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 501,
|
||||
UserID: user.ID,
|
||||
Key: "google-sub-limit",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
Group: group,
|
||||
}
|
||||
apiKey.GroupID = &group.ID
|
||||
|
||||
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
})
|
||||
|
||||
now := time.Now()
|
||||
sub := &service.UserSubscription{
|
||||
ID: 601,
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: now.Add(24 * time.Hour),
|
||||
DailyWindowStart: &now,
|
||||
DailyUsageUSD: 10,
|
||||
}
|
||||
subscriptionService := service.NewSubscriptionService(nil, fakeGoogleSubscriptionRepo{
|
||||
getActive: func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
|
||||
if userID != user.ID || groupID != group.ID {
|
||||
return nil, service.ErrSubscriptionNotFound
|
||||
}
|
||||
clone := *sub
|
||||
return &clone, nil
|
||||
},
|
||||
updateStatus: func(ctx context.Context, subscriptionID int64, status string) error { return nil },
|
||||
activateWindow: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
resetDaily: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
}, nil, nil, &config.Config{RunMode: config.RunModeStandard})
|
||||
|
||||
r := gin.New()
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, &config.Config{RunMode: config.RunModeStandard}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
req.Header.Set("x-goog-api-key", apiKey.Key)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||
var resp googleErrorResponse
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, http.StatusTooManyRequests, resp.Error.Code)
|
||||
require.Equal(t, "RESOURCE_EXHAUSTED", resp.Error.Status)
|
||||
require.Contains(t, resp.Error.Message, "daily usage limit exceeded")
|
||||
}
|
||||
|
||||
@@ -537,7 +537,7 @@ func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -588,6 +588,16 @@ func (r *stubApiKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt ti
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubApiKeyRepo) ResetRateLimitWindows(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubApiKeyRepo) GetRateLimitData(ctx context.Context, id int64) (*service.APIKeyRateLimitData, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type stubUserSubscriptionRepo struct {
|
||||
getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error)
|
||||
updateStatus func(ctx context.Context, subscriptionID int64, status string) error
|
||||
|
||||
@@ -40,7 +40,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer
|
||||
cfg.JWT.AccessTokenExpireMinutes = 60
|
||||
|
||||
userRepo := &stubJWTUserRepo{users: users}
|
||||
authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil)
|
||||
authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
userSvc := service.NewUserService(userRepo, nil, nil)
|
||||
mw := NewJWTAuthMiddleware(authSvc, userSvc)
|
||||
|
||||
|
||||
@@ -2,8 +2,11 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -71,3 +74,48 @@ func AbortWithError(c *gin.Context, statusCode int, code, message string) {
|
||||
c.JSON(statusCode, NewErrorResponse(code, message))
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// RequireGroupAssignment — 未分组 Key 拦截中间件
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
// GatewayErrorWriter 定义网关错误响应格式(不同协议使用不同格式)
|
||||
type GatewayErrorWriter func(c *gin.Context, status int, message string)
|
||||
|
||||
// AnthropicErrorWriter 按 Anthropic API 规范输出错误
|
||||
func AnthropicErrorWriter(c *gin.Context, status int, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{"type": "permission_error", "message": message},
|
||||
})
|
||||
}
|
||||
|
||||
// GoogleErrorWriter 按 Google API 规范输出错误
|
||||
func GoogleErrorWriter(c *gin.Context, status int, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"code": status,
|
||||
"message": message,
|
||||
"status": googleapi.HTTPStatusToGoogleStatus(status),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// RequireGroupAssignment 检查 API Key 是否已分配到分组,
|
||||
// 如果未分组且系统设置不允许未分组 Key 调度则返回 403。
|
||||
func RequireGroupAssignment(settingService *service.SettingService, writeError GatewayErrorWriter) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
apiKey, ok := GetAPIKeyFromContext(c)
|
||||
if !ok || apiKey.GroupID != nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
// 未分组 Key — 检查系统设置
|
||||
if settingService.IsUngroupedKeySchedulingAllowed(c.Request.Context()) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
writeError(c, http.StatusForbidden, "API Key is not assigned to any group and cannot be used. Please contact the administrator to assign it to a group.")
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,9 +41,9 @@ func GetNonceFromContext(c *gin.Context) string {
|
||||
}
|
||||
|
||||
// SecurityHeaders sets baseline security headers for all responses.
|
||||
// getFrameSrc is an optional function that returns an extra origin to inject into frame-src;
|
||||
// getFrameSrcOrigins is an optional function that returns extra origins to inject into frame-src;
|
||||
// pass nil to disable dynamic frame-src injection.
|
||||
func SecurityHeaders(cfg config.CSPConfig, getFrameSrc func() string) gin.HandlerFunc {
|
||||
func SecurityHeaders(cfg config.CSPConfig, getFrameSrcOrigins func() []string) gin.HandlerFunc {
|
||||
policy := strings.TrimSpace(cfg.Policy)
|
||||
if policy == "" {
|
||||
policy = config.DefaultCSPPolicy
|
||||
@@ -54,15 +54,21 @@ func SecurityHeaders(cfg config.CSPConfig, getFrameSrc func() string) gin.Handle
|
||||
|
||||
return func(c *gin.Context) {
|
||||
finalPolicy := policy
|
||||
if getFrameSrc != nil {
|
||||
if origin := getFrameSrc(); origin != "" {
|
||||
finalPolicy = addToDirective(finalPolicy, "frame-src", origin)
|
||||
if getFrameSrcOrigins != nil {
|
||||
for _, origin := range getFrameSrcOrigins() {
|
||||
if origin != "" {
|
||||
finalPolicy = addToDirective(finalPolicy, "frame-src", origin)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
c.Header("X-Frame-Options", "DENY")
|
||||
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
if isAPIRoutePath(c) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
if cfg.Enabled {
|
||||
// Generate nonce for this request
|
||||
@@ -80,6 +86,18 @@ func SecurityHeaders(cfg config.CSPConfig, getFrameSrc func() string) gin.Handle
|
||||
}
|
||||
}
|
||||
|
||||
func isAPIRoutePath(c *gin.Context) bool {
|
||||
if c == nil || c.Request == nil || c.Request.URL == nil {
|
||||
return false
|
||||
}
|
||||
path := c.Request.URL.Path
|
||||
return strings.HasPrefix(path, "/v1/") ||
|
||||
strings.HasPrefix(path, "/v1beta/") ||
|
||||
strings.HasPrefix(path, "/antigravity/") ||
|
||||
strings.HasPrefix(path, "/sora/") ||
|
||||
strings.HasPrefix(path, "/responses")
|
||||
}
|
||||
|
||||
// enhanceCSPPolicy ensures the CSP policy includes nonce support and Cloudflare Insights domain.
|
||||
// This allows the application to work correctly even if the config file has an older CSP policy.
|
||||
func enhanceCSPPolicy(policy string) string {
|
||||
|
||||
@@ -131,6 +131,26 @@ func TestSecurityHeaders(t *testing.T) {
|
||||
assert.Contains(t, csp, CloudflareInsightsDomain)
|
||||
})
|
||||
|
||||
t.Run("api_route_skips_csp_nonce_generation", func(t *testing.T) {
|
||||
cfg := config.CSPConfig{
|
||||
Enabled: true,
|
||||
Policy: "default-src 'self'; script-src 'self' __CSP_NONCE__",
|
||||
}
|
||||
middleware := SecurityHeaders(cfg, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Equal(t, "nosniff", w.Header().Get("X-Content-Type-Options"))
|
||||
assert.Equal(t, "DENY", w.Header().Get("X-Frame-Options"))
|
||||
assert.Equal(t, "strict-origin-when-cross-origin", w.Header().Get("Referrer-Policy"))
|
||||
assert.Empty(t, w.Header().Get("Content-Security-Policy"))
|
||||
assert.Empty(t, GetNonceFromContext(c))
|
||||
})
|
||||
|
||||
t.Run("csp_enabled_with_nonce_placeholder", func(t *testing.T) {
|
||||
cfg := config.CSPConfig{
|
||||
Enabled: true,
|
||||
|
||||
@@ -3,8 +3,6 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -19,24 +17,7 @@ import (
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// extractOrigin returns the scheme+host origin from rawURL, or "" on error.
|
||||
// Only http and https schemes are accepted; other values (e.g. "//host/path") return "".
|
||||
func extractOrigin(rawURL string) string {
|
||||
rawURL = strings.TrimSpace(rawURL)
|
||||
if rawURL == "" {
|
||||
return ""
|
||||
}
|
||||
u, err := url.Parse(rawURL)
|
||||
if err != nil || u.Host == "" {
|
||||
return ""
|
||||
}
|
||||
if u.Scheme != "http" && u.Scheme != "https" {
|
||||
return ""
|
||||
}
|
||||
return u.Scheme + "://" + u.Host
|
||||
}
|
||||
|
||||
const paymentOriginFetchTimeout = 5 * time.Second
|
||||
const frameSrcRefreshTimeout = 5 * time.Second
|
||||
|
||||
// SetupRouter 配置路由器中间件和路由
|
||||
func SetupRouter(
|
||||
@@ -52,38 +33,32 @@ func SetupRouter(
|
||||
cfg *config.Config,
|
||||
redisClient *redis.Client,
|
||||
) *gin.Engine {
|
||||
// 缓存 purchase_subscription_url 的 origin,用于动态注入 CSP frame-src
|
||||
var cachedPaymentOrigin atomic.Pointer[string]
|
||||
empty := ""
|
||||
cachedPaymentOrigin.Store(&empty)
|
||||
// 缓存 iframe 页面的 origin 列表,用于动态注入 CSP frame-src
|
||||
var cachedFrameOrigins atomic.Pointer[[]string]
|
||||
emptyOrigins := []string{}
|
||||
cachedFrameOrigins.Store(&emptyOrigins)
|
||||
|
||||
refreshPaymentOrigin := func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), paymentOriginFetchTimeout)
|
||||
refreshFrameOrigins := func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), frameSrcRefreshTimeout)
|
||||
defer cancel()
|
||||
settings, err := settingService.GetPublicSettings(ctx)
|
||||
origins, err := settingService.GetFrameSrcOrigins(ctx)
|
||||
if err != nil {
|
||||
// 获取失败时保留已有缓存,避免 frame-src 被意外清空
|
||||
return
|
||||
}
|
||||
if settings.PurchaseSubscriptionEnabled {
|
||||
origin := extractOrigin(settings.PurchaseSubscriptionURL)
|
||||
cachedPaymentOrigin.Store(&origin)
|
||||
} else {
|
||||
e := ""
|
||||
cachedPaymentOrigin.Store(&e)
|
||||
}
|
||||
cachedFrameOrigins.Store(&origins)
|
||||
}
|
||||
refreshPaymentOrigin() // 启动时初始化
|
||||
refreshFrameOrigins() // 启动时初始化
|
||||
|
||||
// 应用中间件
|
||||
r.Use(middleware2.RequestLogger())
|
||||
r.Use(middleware2.Logger())
|
||||
r.Use(middleware2.CORS(cfg.CORS))
|
||||
r.Use(middleware2.SecurityHeaders(cfg.Security.CSP, func() string {
|
||||
if p := cachedPaymentOrigin.Load(); p != nil {
|
||||
r.Use(middleware2.SecurityHeaders(cfg.Security.CSP, func() []string {
|
||||
if p := cachedFrameOrigins.Load(); p != nil {
|
||||
return *p
|
||||
}
|
||||
return ""
|
||||
return nil
|
||||
}))
|
||||
|
||||
// Serve embedded frontend with settings injection if available
|
||||
@@ -92,21 +67,21 @@ func SetupRouter(
|
||||
if err != nil {
|
||||
log.Printf("Warning: Failed to create frontend server with settings injection: %v, using legacy mode", err)
|
||||
r.Use(web.ServeEmbeddedFrontend())
|
||||
settingService.SetOnUpdateCallback(refreshPaymentOrigin)
|
||||
settingService.SetOnUpdateCallback(refreshFrameOrigins)
|
||||
} else {
|
||||
// Register combined callback: invalidate HTML cache + refresh payment origin
|
||||
// Register combined callback: invalidate HTML cache + refresh frame origins
|
||||
settingService.SetOnUpdateCallback(func() {
|
||||
frontendServer.InvalidateCache()
|
||||
refreshPaymentOrigin()
|
||||
refreshFrameOrigins()
|
||||
})
|
||||
r.Use(frontendServer.Middleware())
|
||||
}
|
||||
} else {
|
||||
settingService.SetOnUpdateCallback(refreshPaymentOrigin)
|
||||
settingService.SetOnUpdateCallback(refreshFrameOrigins)
|
||||
}
|
||||
|
||||
// 注册路由
|
||||
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg, redisClient)
|
||||
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)
|
||||
|
||||
return r
|
||||
}
|
||||
@@ -121,6 +96,7 @@ func registerRoutes(
|
||||
apiKeyService *service.APIKeyService,
|
||||
subscriptionService *service.SubscriptionService,
|
||||
opsService *service.OpsService,
|
||||
settingService *service.SettingService,
|
||||
cfg *config.Config,
|
||||
redisClient *redis.Client,
|
||||
) {
|
||||
@@ -133,6 +109,7 @@ func registerRoutes(
|
||||
// 注册各模块路由
|
||||
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient)
|
||||
routes.RegisterUserRoutes(v1, h, jwtAuth)
|
||||
routes.RegisterSoraClientRoutes(v1, h, jwtAuth)
|
||||
routes.RegisterAdminRoutes(v1, h, adminAuth)
|
||||
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg)
|
||||
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
|
||||
}
|
||||
|
||||
@@ -55,6 +55,9 @@ func RegisterAdminRoutes(
|
||||
// 系统设置
|
||||
registerSettingsRoutes(admin, h)
|
||||
|
||||
// 数据管理
|
||||
registerDataManagementRoutes(admin, h)
|
||||
|
||||
// 运维监控(Ops)
|
||||
registerOpsRoutes(admin, h)
|
||||
|
||||
@@ -72,6 +75,16 @@ func RegisterAdminRoutes(
|
||||
|
||||
// 错误透传规则管理
|
||||
registerErrorPassthroughRoutes(admin, h)
|
||||
|
||||
// API Key 管理
|
||||
registerAdminAPIKeyRoutes(admin, h)
|
||||
}
|
||||
}
|
||||
|
||||
func registerAdminAPIKeyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
apiKeys := admin.Group("/api-keys")
|
||||
{
|
||||
apiKeys.PUT("/:id", h.Admin.APIKey.UpdateGroup)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -155,6 +168,7 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
ops.GET("/system-logs/health", h.Admin.Ops.GetSystemLogIngestionHealth)
|
||||
|
||||
// Dashboard (vNext - raw path for MVP)
|
||||
ops.GET("/dashboard/snapshot-v2", h.Admin.Ops.GetDashboardSnapshotV2)
|
||||
ops.GET("/dashboard/overview", h.Admin.Ops.GetDashboardOverview)
|
||||
ops.GET("/dashboard/throughput-trend", h.Admin.Ops.GetDashboardThroughputTrend)
|
||||
ops.GET("/dashboard/latency-histogram", h.Admin.Ops.GetDashboardLatencyHistogram)
|
||||
@@ -167,6 +181,7 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
dashboard := admin.Group("/dashboard")
|
||||
{
|
||||
dashboard.GET("/snapshot-v2", h.Admin.Dashboard.GetSnapshotV2)
|
||||
dashboard.GET("/stats", h.Admin.Dashboard.GetStats)
|
||||
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
|
||||
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
|
||||
@@ -232,6 +247,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
accounts.POST("/:id/clear-error", h.Admin.Account.ClearError)
|
||||
accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
|
||||
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
|
||||
accounts.POST("/today-stats/batch", h.Admin.Account.GetBatchTodayStats)
|
||||
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
|
||||
accounts.GET("/:id/temp-unschedulable", h.Admin.Account.GetTempUnschedulable)
|
||||
accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable)
|
||||
@@ -372,6 +388,38 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
// 流超时处理配置
|
||||
adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings)
|
||||
adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings)
|
||||
// Sora S3 存储配置
|
||||
adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings)
|
||||
adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings)
|
||||
adminSettings.POST("/sora-s3/test", h.Admin.Setting.TestSoraS3Connection)
|
||||
adminSettings.GET("/sora-s3/profiles", h.Admin.Setting.ListSoraS3Profiles)
|
||||
adminSettings.POST("/sora-s3/profiles", h.Admin.Setting.CreateSoraS3Profile)
|
||||
adminSettings.PUT("/sora-s3/profiles/:profile_id", h.Admin.Setting.UpdateSoraS3Profile)
|
||||
adminSettings.DELETE("/sora-s3/profiles/:profile_id", h.Admin.Setting.DeleteSoraS3Profile)
|
||||
adminSettings.POST("/sora-s3/profiles/:profile_id/activate", h.Admin.Setting.SetActiveSoraS3Profile)
|
||||
}
|
||||
}
|
||||
|
||||
func registerDataManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
dataManagement := admin.Group("/data-management")
|
||||
{
|
||||
dataManagement.GET("/agent/health", h.Admin.DataManagement.GetAgentHealth)
|
||||
dataManagement.GET("/config", h.Admin.DataManagement.GetConfig)
|
||||
dataManagement.PUT("/config", h.Admin.DataManagement.UpdateConfig)
|
||||
dataManagement.GET("/sources/:source_type/profiles", h.Admin.DataManagement.ListSourceProfiles)
|
||||
dataManagement.POST("/sources/:source_type/profiles", h.Admin.DataManagement.CreateSourceProfile)
|
||||
dataManagement.PUT("/sources/:source_type/profiles/:profile_id", h.Admin.DataManagement.UpdateSourceProfile)
|
||||
dataManagement.DELETE("/sources/:source_type/profiles/:profile_id", h.Admin.DataManagement.DeleteSourceProfile)
|
||||
dataManagement.POST("/sources/:source_type/profiles/:profile_id/activate", h.Admin.DataManagement.SetActiveSourceProfile)
|
||||
dataManagement.POST("/s3/test", h.Admin.DataManagement.TestS3)
|
||||
dataManagement.GET("/s3/profiles", h.Admin.DataManagement.ListS3Profiles)
|
||||
dataManagement.POST("/s3/profiles", h.Admin.DataManagement.CreateS3Profile)
|
||||
dataManagement.PUT("/s3/profiles/:profile_id", h.Admin.DataManagement.UpdateS3Profile)
|
||||
dataManagement.DELETE("/s3/profiles/:profile_id", h.Admin.DataManagement.DeleteS3Profile)
|
||||
dataManagement.POST("/s3/profiles/:profile_id/activate", h.Admin.DataManagement.SetActiveS3Profile)
|
||||
dataManagement.POST("/backups", h.Admin.DataManagement.CreateBackupJob)
|
||||
dataManagement.GET("/backups", h.Admin.DataManagement.ListBackupJobs)
|
||||
dataManagement.GET("/backups/:job_id", h.Admin.DataManagement.GetBackupJob)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ func RegisterGatewayRoutes(
|
||||
apiKeyService *service.APIKeyService,
|
||||
subscriptionService *service.SubscriptionService,
|
||||
opsService *service.OpsService,
|
||||
settingService *service.SettingService,
|
||||
cfg *config.Config,
|
||||
) {
|
||||
bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize)
|
||||
@@ -30,12 +31,17 @@ func RegisterGatewayRoutes(
|
||||
clientRequestID := middleware.ClientRequestID()
|
||||
opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService)
|
||||
|
||||
// 未分组 Key 拦截中间件(按协议格式区分错误响应)
|
||||
requireGroupAnthropic := middleware.RequireGroupAssignment(settingService, middleware.AnthropicErrorWriter)
|
||||
requireGroupGoogle := middleware.RequireGroupAssignment(settingService, middleware.GoogleErrorWriter)
|
||||
|
||||
// API网关(Claude API兼容)
|
||||
gateway := r.Group("/v1")
|
||||
gateway.Use(bodyLimit)
|
||||
gateway.Use(clientRequestID)
|
||||
gateway.Use(opsErrorLogger)
|
||||
gateway.Use(gin.HandlerFunc(apiKeyAuth))
|
||||
gateway.Use(requireGroupAnthropic)
|
||||
{
|
||||
gateway.POST("/messages", h.Gateway.Messages)
|
||||
gateway.POST("/messages/count_tokens", h.Gateway.CountTokens)
|
||||
@@ -43,6 +49,7 @@ func RegisterGatewayRoutes(
|
||||
gateway.GET("/usage", h.Gateway.Usage)
|
||||
// OpenAI Responses API
|
||||
gateway.POST("/responses", h.OpenAIGateway.Responses)
|
||||
gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket)
|
||||
// 明确阻止旧协议入口:OpenAI 仅支持 Responses API,避免客户端误解为会自动路由到其它平台。
|
||||
gateway.POST("/chat/completions", func(c *gin.Context) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
@@ -60,6 +67,7 @@ func RegisterGatewayRoutes(
|
||||
gemini.Use(clientRequestID)
|
||||
gemini.Use(opsErrorLogger)
|
||||
gemini.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
|
||||
gemini.Use(requireGroupGoogle)
|
||||
{
|
||||
gemini.GET("/models", h.Gateway.GeminiV1BetaListModels)
|
||||
gemini.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
|
||||
@@ -68,10 +76,11 @@ func RegisterGatewayRoutes(
|
||||
}
|
||||
|
||||
// OpenAI Responses API(不带v1前缀的别名)
|
||||
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses)
|
||||
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
||||
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
|
||||
|
||||
// Antigravity 模型列表
|
||||
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), h.Gateway.AntigravityModels)
|
||||
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels)
|
||||
|
||||
// Antigravity 专用路由(仅使用 antigravity 账户,不混合调度)
|
||||
antigravityV1 := r.Group("/antigravity/v1")
|
||||
@@ -80,6 +89,7 @@ func RegisterGatewayRoutes(
|
||||
antigravityV1.Use(opsErrorLogger)
|
||||
antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity))
|
||||
antigravityV1.Use(gin.HandlerFunc(apiKeyAuth))
|
||||
antigravityV1.Use(requireGroupAnthropic)
|
||||
{
|
||||
antigravityV1.POST("/messages", h.Gateway.Messages)
|
||||
antigravityV1.POST("/messages/count_tokens", h.Gateway.CountTokens)
|
||||
@@ -93,6 +103,7 @@ func RegisterGatewayRoutes(
|
||||
antigravityV1Beta.Use(opsErrorLogger)
|
||||
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
|
||||
antigravityV1Beta.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
|
||||
antigravityV1Beta.Use(requireGroupGoogle)
|
||||
{
|
||||
antigravityV1Beta.GET("/models", h.Gateway.GeminiV1BetaListModels)
|
||||
antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
|
||||
@@ -106,6 +117,7 @@ func RegisterGatewayRoutes(
|
||||
soraV1.Use(opsErrorLogger)
|
||||
soraV1.Use(middleware.ForcePlatform(service.PlatformSora))
|
||||
soraV1.Use(gin.HandlerFunc(apiKeyAuth))
|
||||
soraV1.Use(requireGroupAnthropic)
|
||||
{
|
||||
soraV1.POST("/chat/completions", h.SoraGateway.ChatCompletions)
|
||||
soraV1.GET("/models", h.Gateway.Models)
|
||||
|
||||
33
backend/internal/server/routes/sora_client.go
Normal file
33
backend/internal/server/routes/sora_client.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RegisterSoraClientRoutes 注册 Sora 客户端 API 路由(需要用户认证)。
|
||||
func RegisterSoraClientRoutes(
|
||||
v1 *gin.RouterGroup,
|
||||
h *handler.Handlers,
|
||||
jwtAuth middleware.JWTAuthMiddleware,
|
||||
) {
|
||||
if h.SoraClient == nil {
|
||||
return
|
||||
}
|
||||
|
||||
authenticated := v1.Group("/sora")
|
||||
authenticated.Use(gin.HandlerFunc(jwtAuth))
|
||||
{
|
||||
authenticated.POST("/generate", h.SoraClient.Generate)
|
||||
authenticated.GET("/generations", h.SoraClient.ListGenerations)
|
||||
authenticated.GET("/generations/:id", h.SoraClient.GetGeneration)
|
||||
authenticated.DELETE("/generations/:id", h.SoraClient.DeleteGeneration)
|
||||
authenticated.POST("/generations/:id/cancel", h.SoraClient.CancelGeneration)
|
||||
authenticated.POST("/generations/:id/save", h.SoraClient.SaveToStorage)
|
||||
authenticated.GET("/quota", h.SoraClient.GetQuota)
|
||||
authenticated.GET("/models", h.SoraClient.GetModels)
|
||||
authenticated.GET("/storage-status", h.SoraClient.GetStorageStatus)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user