Merge pull request #1106 from geminiwen/feat/subscription-platform-filter

feat: add platform type filter to subscription management
This commit is contained in:
Wesley Liddick
2026-03-18 11:32:35 +08:00
committed by GitHub
13 changed files with 40 additions and 13 deletions

View File

@@ -5,6 +5,7 @@ import (
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -190,7 +191,7 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil
}
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
client := clientFromContext(ctx, r.client)
q := client.UserSubscription.Query()
if userID != nil {
@@ -199,6 +200,9 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
if groupID != nil {
q = q.Where(usersubscription.GroupIDEQ(*groupID))
}
if platform != "" {
q = q.Where(usersubscription.HasGroupWith(group.PlatformEQ(platform)))
}
// Status filtering with real-time expiration check
now := time.Now()

View File

@@ -271,7 +271,7 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
group := s.mustCreateGroup("g-list")
s.mustCreateSubscription(user.ID, group.ID, nil)
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "")
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "", "")
s.Require().NoError(err, "List")
s.Require().Len(subs, 1)
s.Require().Equal(int64(1), page.Total)
@@ -285,7 +285,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
s.mustCreateSubscription(user1.ID, group.ID, nil)
s.mustCreateSubscription(user2.ID, group.ID, nil)
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "")
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "", "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(user1.ID, subs[0].UserID)
@@ -299,7 +299,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
s.mustCreateSubscription(user.ID, g1.ID, nil)
s.mustCreateSubscription(user.ID, g2.ID, nil)
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "")
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "", "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(g1.ID, subs[0].GroupID)
@@ -320,7 +320,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
})
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "")
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "", "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status)