diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index f767bbea..dab35577 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -65,7 +65,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { userGroupRateRepository := repository.NewUserGroupRateRepository(db) billingCacheService := service.ProvideBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, userRPMCache, userGroupRateRepository, configConfig) apiKeyCache := repository.NewAPIKeyCache(redisClient) - apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig) + apiKeyService := service.ProvideAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig, billingCacheService) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig) diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index 2fe29fa3..b187b47f 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -565,6 +565,22 @@ func (s *stubAdminService) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i return nil, service.ErrAPIKeyNotFound } +func (s *stubAdminService) AdminResetAPIKeyRateLimitUsage(ctx context.Context, keyID int64) (*service.APIKey, error) { + for i := range s.apiKeys { + if s.apiKeys[i].ID == keyID { + s.apiKeys[i].Usage5h = 0 + s.apiKeys[i].Usage1d = 0 + s.apiKeys[i].Usage7d = 0 + s.apiKeys[i].Window5hStart = nil + s.apiKeys[i].Window1dStart = nil + s.apiKeys[i].Window7dStart = nil + k := s.apiKeys[i] + return &k, nil + } + } + return nil, service.ErrAPIKeyNotFound +} + func (s *stubAdminService) ResetAccountQuota(ctx context.Context, id int64) error { return nil } diff --git a/backend/internal/handler/admin/apikey_handler.go b/backend/internal/handler/admin/apikey_handler.go index 8dd245a4..5e405bdd 100644 --- a/backend/internal/handler/admin/apikey_handler.go +++ b/backend/internal/handler/admin/apikey_handler.go @@ -22,12 +22,13 @@ func NewAdminAPIKeyHandler(adminService service.AdminService) *AdminAPIKeyHandle } } -// AdminUpdateAPIKeyGroupRequest represents the request to update an API key's group +// AdminUpdateAPIKeyGroupRequest represents the request to update an API key. type AdminUpdateAPIKeyGroupRequest struct { - GroupID *int64 `json:"group_id"` // nil=不修改, 0=解绑, >0=绑定到目标分组 + GroupID *int64 `json:"group_id"` // nil=不修改, 0=解绑, >0=绑定到目标分组 + ResetRateLimitUsage *bool `json:"reset_rate_limit_usage"` // true=重置 5h/1d/7d 限速用量 } -// UpdateGroup handles updating an API key's group binding +// UpdateGroup handles updating an API key's admin-managed fields. // PUT /api/v1/admin/api-keys/:id func (h *AdminAPIKeyHandler) UpdateGroup(c *gin.Context) { keyID, err := strconv.ParseInt(c.Param("id"), 10, 64) @@ -42,11 +43,23 @@ func (h *AdminAPIKeyHandler) UpdateGroup(c *gin.Context) { return } + var resetKey *service.APIKey + if req.ResetRateLimitUsage != nil && *req.ResetRateLimitUsage { + resetKey, err = h.adminService.AdminResetAPIKeyRateLimitUsage(c.Request.Context(), keyID) + if err != nil { + response.ErrorFrom(c, err) + return + } + } + result, err := h.adminService.AdminUpdateAPIKeyGroupID(c.Request.Context(), keyID, req.GroupID) if err != nil { response.ErrorFrom(c, err) return } + if resetKey != nil && req.GroupID == nil { + result.APIKey = resetKey + } resp := struct { APIKey *dto.APIKey `json:"api_key"` diff --git a/backend/internal/handler/admin/apikey_handler_test.go b/backend/internal/handler/admin/apikey_handler_test.go index bf128b18..6ac6d52f 100644 --- a/backend/internal/handler/admin/apikey_handler_test.go +++ b/backend/internal/handler/admin/apikey_handler_test.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/service" @@ -117,6 +118,45 @@ func TestAdminAPIKeyHandler_UpdateGroup_Unbind(t *testing.T) { require.Nil(t, resp.Data.APIKey.GroupID) } +func TestAdminAPIKeyHandler_ResetRateLimitUsage(t *testing.T) { + svc := newStubAdminService() + now := time.Now() + svc.apiKeys[0].Usage5h = 1.2 + svc.apiKeys[0].Usage1d = 3.4 + svc.apiKeys[0].Usage7d = 5.6 + svc.apiKeys[0].Window5hStart = &now + svc.apiKeys[0].Window1dStart = &now + svc.apiKeys[0].Window7dStart = &now + router := setupAPIKeyHandler(svc) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{"reset_rate_limit_usage":true}`)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var resp struct { + Data struct { + APIKey struct { + Usage5h float64 `json:"usage_5h"` + Usage1d float64 `json:"usage_1d"` + Usage7d float64 `json:"usage_7d"` + Window5hStart *time.Time `json:"window_5h_start"` + Window1dStart *time.Time `json:"window_1d_start"` + Window7dStart *time.Time `json:"window_7d_start"` + } `json:"api_key"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Zero(t, resp.Data.APIKey.Usage5h) + require.Zero(t, resp.Data.APIKey.Usage1d) + require.Zero(t, resp.Data.APIKey.Usage7d) + require.Nil(t, resp.Data.APIKey.Window5hStart) + require.Nil(t, resp.Data.APIKey.Window1dStart) + require.Nil(t, resp.Data.APIKey.Window7dStart) +} + func TestAdminAPIKeyHandler_UpdateGroup_ServiceError(t *testing.T) { svc := &failingUpdateGroupService{ stubAdminService: newStubAdminService(), diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 434f1f38..cb0c5339 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -58,6 +58,7 @@ type AdminService interface { // API Key management (admin) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error) + AdminResetAPIKeyRateLimitUsage(ctx context.Context, keyID int64) (*APIKey, error) // ReplaceUserGroup 替换用户的专属分组:授予新分组权限、迁移 Key、移除旧分组权限 ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) @@ -1961,6 +1962,30 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i return result, nil } +// AdminResetAPIKeyRateLimitUsage resets all API key rate-limit usage windows. +func (s *adminServiceImpl) AdminResetAPIKeyRateLimitUsage(ctx context.Context, keyID int64) (*APIKey, error) { + apiKey, err := s.apiKeyRepo.GetByID(ctx, keyID) + if err != nil { + return nil, err + } + apiKey.Usage5h = 0 + apiKey.Usage1d = 0 + apiKey.Usage7d = 0 + apiKey.Window5hStart = nil + apiKey.Window1dStart = nil + apiKey.Window7dStart = nil + if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil { + return nil, fmt.Errorf("reset api key rate limit usage: %w", err) + } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key) + } + if s.billingCacheService != nil { + _ = s.billingCacheService.InvalidateAPIKeyRateLimit(ctx, apiKey.ID) + } + return apiKey, nil +} + // ReplaceUserGroup 替换用户的专属分组 func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) { if oldGroupID == newGroupID { diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go index 4e695eb9..050db55b 100644 --- a/backend/internal/service/billing_cache_service.go +++ b/backend/internal/service/billing_cache_service.go @@ -508,6 +508,18 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID return nil } +// InvalidateAPIKeyRateLimit invalidates the Redis rate-limit usage cache for an API key. +func (s *BillingCacheService) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error { + if s.cache == nil { + return nil + } + if err := s.cache.InvalidateAPIKeyRateLimit(ctx, keyID); err != nil { + logger.LegacyPrintf("service.billing_cache", "Warning: invalidate api key rate limit cache failed for key %d: %v", keyID, err) + return err + } + return nil +} + // ============================================ // API Key 限速缓存方法 // ============================================ diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index b1d9aaba..8b50e478 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -404,12 +404,28 @@ func ProvideBillingCacheService( return NewBillingCacheService(cache, userRepo, subRepo, apiKeyRepo, rpmCache, rateRepo, cfg) } +// ProvideAPIKeyService wires APIKeyService and connects rate-limit cache invalidation. +func ProvideAPIKeyService( + apiKeyRepo APIKeyRepository, + userRepo UserRepository, + groupRepo GroupRepository, + userSubRepo UserSubscriptionRepository, + userGroupRateRepo UserGroupRateRepository, + cache APIKeyCache, + cfg *config.Config, + billingCacheService *BillingCacheService, +) *APIKeyService { + svc := NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, userGroupRateRepo, cache, cfg) + svc.SetRateLimitCacheInvalidator(billingCacheService) + return svc +} + // ProviderSet is the Wire provider set for all services var ProviderSet = wire.NewSet( // Core services NewAuthService, NewUserService, - NewAPIKeyService, + ProvideAPIKeyService, ProvideAPIKeyAuthCacheInvalidator, NewGroupService, NewAccountService,