Merge pull request #1538 from IanShaw027/fix/bug-cleanup-main

fix: 修复多个 UI 和功能问题 - 表格排序搜索、导出逻辑、分页配置和状态筛选
This commit is contained in:
Wesley Liddick
2026-04-11 16:29:55 +08:00
committed by GitHub
117 changed files with 3961 additions and 870 deletions

View File

@@ -10,6 +10,7 @@ import (
"log/slog" "log/slog"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
@@ -359,7 +360,7 @@ func (h *AccountHandler) listAllProxies(ctx context.Context) ([]service.Proxy, e
pageSize := dataPageCap pageSize := dataPageCap
var out []service.Proxy var out []service.Proxy
for { for {
items, total, err := h.adminService.ListProxies(ctx, page, pageSize, "", "", "") items, total, err := h.adminService.ListProxies(ctx, page, pageSize, "", "", "", "created_at", "desc")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -372,12 +373,12 @@ func (h *AccountHandler) listAllProxies(ctx context.Context) ([]service.Proxy, e
return out, nil return out, nil
} }
func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, accountType, status, search string) ([]service.Account, error) { func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, accountType, status, search string, groupID int64, privacyMode, sortBy, sortOrder string) ([]service.Account, error) {
page := 1 page := 1
pageSize := dataPageCap pageSize := dataPageCap
var out []service.Account var out []service.Account
for { for {
items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0, "") items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, groupID, privacyMode, sortBy, sortOrder)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -409,11 +410,28 @@ func (h *AccountHandler) resolveExportAccounts(ctx context.Context, ids []int64,
platform := c.Query("platform") platform := c.Query("platform")
accountType := c.Query("type") accountType := c.Query("type")
status := c.Query("status") status := c.Query("status")
privacyMode := strings.TrimSpace(c.Query("privacy_mode"))
search := strings.TrimSpace(c.Query("search")) search := strings.TrimSpace(c.Query("search"))
sortBy := c.DefaultQuery("sort_by", "name")
sortOrder := c.DefaultQuery("sort_order", "asc")
if len(search) > 100 { if len(search) > 100 {
search = search[:100] search = search[:100]
} }
return h.listAccountsFiltered(ctx, platform, accountType, status, search)
groupID := int64(0)
if groupIDStr := c.Query("group"); groupIDStr != "" {
if groupIDStr == accountListGroupUngroupedQueryValue {
groupID = service.AccountListGroupUngrouped
} else {
parsedGroupID, parseErr := strconv.ParseInt(groupIDStr, 10, 64)
if parseErr != nil || parsedGroupID <= 0 {
return nil, infraerrors.BadRequest("INVALID_GROUP_FILTER", "invalid group filter")
}
groupID = parsedGroupID
}
}
return h.listAccountsFiltered(ctx, platform, accountType, status, search, groupID, privacyMode, sortBy, sortOrder)
} }
func (h *AccountHandler) resolveExportProxies(ctx context.Context, accounts []service.Account) ([]service.Proxy, error) { func (h *AccountHandler) resolveExportProxies(ctx context.Context, accounts []service.Account) ([]service.Proxy, error) {

View File

@@ -172,6 +172,51 @@ func TestExportDataWithoutProxies(t *testing.T) {
require.Nil(t, resp.Data.Accounts[0].ProxyKey) require.Nil(t, resp.Data.Accounts[0].ProxyKey)
} }
func TestExportDataPassesAccountFiltersAndSort(t *testing.T) {
router, adminSvc := setupAccountDataRouter()
adminSvc.accounts = []service.Account{
{ID: 1, Name: "acc-1", Status: service.StatusActive},
}
rec := httptest.NewRecorder()
req := httptest.NewRequest(
http.MethodGet,
"/api/v1/admin/accounts/data?platform=openai&type=oauth&status=active&group=12&privacy_mode=blocked&search=keyword&sort_by=priority&sort_order=desc",
nil,
)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, 1, adminSvc.lastListAccounts.calls)
require.Equal(t, "openai", adminSvc.lastListAccounts.platform)
require.Equal(t, "oauth", adminSvc.lastListAccounts.accountType)
require.Equal(t, "active", adminSvc.lastListAccounts.status)
require.Equal(t, int64(12), adminSvc.lastListAccounts.groupID)
require.Equal(t, "blocked", adminSvc.lastListAccounts.privacyMode)
require.Equal(t, "keyword", adminSvc.lastListAccounts.search)
require.Equal(t, "priority", adminSvc.lastListAccounts.sortBy)
require.Equal(t, "desc", adminSvc.lastListAccounts.sortOrder)
}
func TestExportDataSelectedIDsOverrideFilters(t *testing.T) {
router, adminSvc := setupAccountDataRouter()
rec := httptest.NewRecorder()
req := httptest.NewRequest(
http.MethodGet,
"/api/v1/admin/accounts/data?ids=1,2&platform=openai&search=keyword&sort_by=priority&sort_order=desc",
nil,
)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp dataResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Len(t, resp.Data.Accounts, 2)
require.Equal(t, 0, adminSvc.lastListAccounts.calls)
}
func TestImportDataReusesProxyAndSkipsDefaultGroup(t *testing.T) { func TestImportDataReusesProxyAndSkipsDefaultGroup(t *testing.T) {
router, adminSvc := setupAccountDataRouter() router, adminSvc := setupAccountDataRouter()

View File

@@ -221,6 +221,8 @@ func (h *AccountHandler) List(c *gin.Context) {
status := c.Query("status") status := c.Query("status")
search := c.Query("search") search := c.Query("search")
privacyMode := strings.TrimSpace(c.Query("privacy_mode")) privacyMode := strings.TrimSpace(c.Query("privacy_mode"))
sortBy := c.DefaultQuery("sort_by", "name")
sortOrder := c.DefaultQuery("sort_order", "asc")
// 标准化和验证 search 参数 // 标准化和验证 search 参数
search = strings.TrimSpace(search) search = strings.TrimSpace(search)
if len(search) > 100 { if len(search) > 100 {
@@ -246,7 +248,7 @@ func (h *AccountHandler) List(c *gin.Context) {
} }
} }
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID, privacyMode) accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID, privacyMode, sortBy, sortOrder)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
@@ -2029,7 +2031,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
accounts := make([]*service.Account, 0) accounts := make([]*service.Account, 0)
if len(req.AccountIDs) == 0 { if len(req.AccountIDs) == 0 {
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0, "") allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0, "", "name", "asc")
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return

View File

@@ -31,6 +31,33 @@ type stubAdminService struct {
platform string platform string
groupIDs []int64 groupIDs []int64
} }
lastListAccounts struct {
platform string
accountType string
status string
search string
groupID int64
privacyMode string
sortBy string
sortOrder string
calls int
}
lastListProxies struct {
protocol string
status string
search string
sortBy string
sortOrder string
calls int
}
lastListRedeemCodes struct {
codeType string
status string
search string
sortBy string
sortOrder string
calls int
}
mu sync.Mutex mu sync.Mutex
} }
@@ -99,7 +126,7 @@ func newStubAdminService() *stubAdminService {
} }
} }
func (s *stubAdminService) ListUsers(ctx context.Context, page, pageSize int, filters service.UserListFilters) ([]service.User, int64, error) { func (s *stubAdminService) ListUsers(ctx context.Context, page, pageSize int, filters service.UserListFilters, sortBy, sortOrder string) ([]service.User, int64, error) {
return s.users, int64(len(s.users)), nil return s.users, int64(len(s.users)), nil
} }
@@ -132,7 +159,7 @@ func (s *stubAdminService) UpdateUserBalance(ctx context.Context, userID int64,
return &user, nil return &user, nil
} }
func (s *stubAdminService) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]service.APIKey, int64, error) { func (s *stubAdminService) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]service.APIKey, int64, error) {
return s.apiKeys, int64(len(s.apiKeys)), nil return s.apiKeys, int64(len(s.apiKeys)), nil
} }
@@ -140,7 +167,7 @@ func (s *stubAdminService) GetUserUsageStats(ctx context.Context, userID int64,
return map[string]any{"user_id": userID}, nil return map[string]any{"user_id": userID}, nil
} }
func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]service.Group, int64, error) { func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]service.Group, int64, error) {
return s.groups, int64(len(s.groups)), nil return s.groups, int64(len(s.groups)), nil
} }
@@ -187,7 +214,16 @@ func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int
return nil return nil
} }
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, int64, error) { func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, sortBy, sortOrder string) ([]service.Account, int64, error) {
s.lastListAccounts.platform = platform
s.lastListAccounts.accountType = accountType
s.lastListAccounts.status = status
s.lastListAccounts.search = search
s.lastListAccounts.groupID = groupID
s.lastListAccounts.privacyMode = privacyMode
s.lastListAccounts.sortBy = sortBy
s.lastListAccounts.sortOrder = sortOrder
s.lastListAccounts.calls++
return s.accounts, int64(len(s.accounts)), nil return s.accounts, int64(len(s.accounts)), nil
} }
@@ -261,7 +297,13 @@ func (s *stubAdminService) CheckMixedChannelRisk(ctx context.Context, currentAcc
return s.checkMixedErr return s.checkMixedErr
} }
func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) { func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]service.Proxy, int64, error) {
s.lastListProxies.protocol = protocol
s.lastListProxies.status = status
s.lastListProxies.search = search
s.lastListProxies.sortBy = sortBy
s.lastListProxies.sortOrder = sortOrder
s.lastListProxies.calls++
search = strings.TrimSpace(strings.ToLower(search)) search = strings.TrimSpace(strings.ToLower(search))
filtered := make([]service.Proxy, 0, len(s.proxies)) filtered := make([]service.Proxy, 0, len(s.proxies))
for _, proxy := range s.proxies { for _, proxy := range s.proxies {
@@ -283,7 +325,7 @@ func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int,
return filtered, int64(len(filtered)), nil return filtered, int64(len(filtered)), nil
} }
func (s *stubAdminService) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.ProxyWithAccountCount, int64, error) { func (s *stubAdminService) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]service.ProxyWithAccountCount, int64, error) {
return s.proxyCounts, int64(len(s.proxyCounts)), nil return s.proxyCounts, int64(len(s.proxyCounts)), nil
} }
@@ -384,7 +426,13 @@ func (s *stubAdminService) CheckProxyQuality(ctx context.Context, id int64) (*se
}, nil }, nil
} }
func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) { func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string, sortBy, sortOrder string) ([]service.RedeemCode, int64, error) {
s.lastListRedeemCodes.codeType = codeType
s.lastListRedeemCodes.status = status
s.lastListRedeemCodes.search = search
s.lastListRedeemCodes.sortBy = sortBy
s.lastListRedeemCodes.sortOrder = sortOrder
s.lastListRedeemCodes.calls++
return s.redeems, int64(len(s.redeems)), nil return s.redeems, int64(len(s.redeems)), nil
} }

View File

@@ -52,13 +52,17 @@ func (h *AnnouncementHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c) page, pageSize := response.ParsePagination(c)
status := strings.TrimSpace(c.Query("status")) status := strings.TrimSpace(c.Query("status"))
search := strings.TrimSpace(c.Query("search")) search := strings.TrimSpace(c.Query("search"))
sortBy := c.DefaultQuery("sort_by", "created_at")
sortOrder := c.DefaultQuery("sort_order", "desc")
if len(search) > 200 { if len(search) > 200 {
search = search[:200] search = search[:200]
} }
params := pagination.PaginationParams{ params := pagination.PaginationParams{
Page: page, Page: page,
PageSize: pageSize, PageSize: pageSize,
SortBy: sortBy,
SortOrder: sortOrder,
} }
items, paginationResult, err := h.announcementService.List( items, paginationResult, err := h.announcementService.List(
@@ -227,8 +231,10 @@ func (h *AnnouncementHandler) ListReadStatus(c *gin.Context) {
page, pageSize := response.ParsePagination(c) page, pageSize := response.ParsePagination(c)
params := pagination.PaginationParams{ params := pagination.PaginationParams{
Page: page, Page: page,
PageSize: pageSize, PageSize: pageSize,
SortBy: c.DefaultQuery("sort_by", "email"),
SortOrder: c.DefaultQuery("sort_order", "asc"),
} }
search := strings.TrimSpace(c.Query("search")) search := strings.TrimSpace(c.Query("search"))
if len(search) > 200 { if len(search) > 200 {

View File

@@ -0,0 +1,138 @@
package admin
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type announcementRepoCapture struct {
service.AnnouncementRepository
listParams pagination.PaginationParams
}
func (r *announcementRepoCapture) List(ctx context.Context, params pagination.PaginationParams, filters service.AnnouncementListFilters) ([]service.Announcement, *pagination.PaginationResult, error) {
r.listParams = params
return []service.Announcement{}, &pagination.PaginationResult{
Total: 0,
Page: params.Page,
PageSize: params.PageSize,
Pages: 0,
}, nil
}
func (r *announcementRepoCapture) GetByID(ctx context.Context, id int64) (*service.Announcement, error) {
return &service.Announcement{
ID: id,
Title: "announcement",
Content: "content",
Status: service.AnnouncementStatusActive,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}, nil
}
type announcementUserRepoCapture struct {
service.UserRepository
listParams pagination.PaginationParams
}
func (r *announcementUserRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
r.listParams = params
return []service.User{}, &pagination.PaginationResult{
Total: 0,
Page: params.Page,
PageSize: params.PageSize,
Pages: 0,
}, nil
}
type announcementReadRepoCapture struct {
service.AnnouncementReadRepository
}
func (r *announcementReadRepoCapture) GetReadMapByUsers(ctx context.Context, announcementID int64, userIDs []int64) (map[int64]time.Time, error) {
return map[int64]time.Time{}, nil
}
type announcementUserSubRepoCapture struct {
service.UserSubscriptionRepository
}
func newAnnouncementSortTestRouter(announcementRepo *announcementRepoCapture, userRepo *announcementUserRepoCapture) *gin.Engine {
gin.SetMode(gin.TestMode)
svc := service.NewAnnouncementService(
announcementRepo,
&announcementReadRepoCapture{},
userRepo,
&announcementUserSubRepoCapture{},
)
handler := NewAnnouncementHandler(svc)
router := gin.New()
router.GET("/admin/announcements", handler.List)
router.GET("/admin/announcements/:id/read-status", handler.ListReadStatus)
return router
}
func TestAdminAnnouncementListSortParams(t *testing.T) {
announcementRepo := &announcementRepoCapture{}
userRepo := &announcementUserRepoCapture{}
router := newAnnouncementSortTestRouter(announcementRepo, userRepo)
req := httptest.NewRequest(http.MethodGet, "/admin/announcements?sort_by=title&sort_order=ASC", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "title", announcementRepo.listParams.SortBy)
require.Equal(t, "ASC", announcementRepo.listParams.SortOrder)
}
func TestAdminAnnouncementListSortDefaults(t *testing.T) {
announcementRepo := &announcementRepoCapture{}
userRepo := &announcementUserRepoCapture{}
router := newAnnouncementSortTestRouter(announcementRepo, userRepo)
req := httptest.NewRequest(http.MethodGet, "/admin/announcements", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "created_at", announcementRepo.listParams.SortBy)
require.Equal(t, "desc", announcementRepo.listParams.SortOrder)
}
func TestAdminAnnouncementReadStatusSortParams(t *testing.T) {
announcementRepo := &announcementRepoCapture{}
userRepo := &announcementUserRepoCapture{}
router := newAnnouncementSortTestRouter(announcementRepo, userRepo)
req := httptest.NewRequest(http.MethodGet, "/admin/announcements/1/read-status?sort_by=balance&sort_order=DESC", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "balance", userRepo.listParams.SortBy)
require.Equal(t, "DESC", userRepo.listParams.SortOrder)
}
func TestAdminAnnouncementReadStatusSortDefaults(t *testing.T) {
announcementRepo := &announcementRepoCapture{}
userRepo := &announcementUserRepoCapture{}
router := newAnnouncementSortTestRouter(announcementRepo, userRepo)
req := httptest.NewRequest(http.MethodGet, "/admin/announcements/1/read-status", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "email", userRepo.listParams.SortBy)
require.Equal(t, "asc", userRepo.listParams.SortOrder)
}

View File

@@ -245,7 +245,12 @@ func (h *ChannelHandler) List(c *gin.Context) {
search = search[:100] search = search[:100]
} }
channels, pag, err := h.channelService.List(c.Request.Context(), pagination.PaginationParams{Page: page, PageSize: pageSize}, status, search) channels, pag, err := h.channelService.List(c.Request.Context(), pagination.PaginationParams{
Page: page,
PageSize: pageSize,
SortBy: c.DefaultQuery("sort_by", "created_at"),
SortOrder: c.DefaultQuery("sort_order", "desc"),
}, status, search)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return

View File

@@ -162,6 +162,8 @@ func (h *GroupHandler) List(c *gin.Context) {
search = search[:100] search = search[:100]
} }
isExclusiveStr := c.Query("is_exclusive") isExclusiveStr := c.Query("is_exclusive")
sortBy := c.DefaultQuery("sort_by", "sort_order")
sortOrder := c.DefaultQuery("sort_order", "asc")
var isExclusive *bool var isExclusive *bool
if isExclusiveStr != "" { if isExclusiveStr != "" {
@@ -169,7 +171,7 @@ func (h *GroupHandler) List(c *gin.Context) {
isExclusive = &val isExclusive = &val
} }
groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, search, isExclusive) groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, search, isExclusive, sortBy, sortOrder)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return

View File

@@ -55,8 +55,10 @@ func (h *PromoHandler) List(c *gin.Context) {
} }
params := pagination.PaginationParams{ params := pagination.PaginationParams{
Page: page, Page: page,
PageSize: pageSize, PageSize: pageSize,
SortBy: c.DefaultQuery("sort_by", "created_at"),
SortOrder: c.DefaultQuery("sort_order", "desc"),
} }
codes, paginationResult, err := h.promoService.List(c.Request.Context(), params, status, search) codes, paginationResult, err := h.promoService.List(c.Request.Context(), params, status, search)

View File

@@ -33,11 +33,13 @@ func (h *ProxyHandler) ExportData(c *gin.Context) {
protocol := c.Query("protocol") protocol := c.Query("protocol")
status := c.Query("status") status := c.Query("status")
search := strings.TrimSpace(c.Query("search")) search := strings.TrimSpace(c.Query("search"))
sortBy := c.DefaultQuery("sort_by", "id")
sortOrder := c.DefaultQuery("sort_order", "desc")
if len(search) > 100 { if len(search) > 100 {
search = search[:100] search = search[:100]
} }
proxies, err = h.listProxiesFiltered(ctx, protocol, status, search) proxies, err = h.listProxiesFiltered(ctx, protocol, status, search, sortBy, sortOrder)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
@@ -89,7 +91,7 @@ func (h *ProxyHandler) ImportData(c *gin.Context) {
ctx := c.Request.Context() ctx := c.Request.Context()
result := DataImportResult{} result := DataImportResult{}
existingProxies, err := h.listProxiesFiltered(ctx, "", "", "") existingProxies, err := h.listProxiesFiltered(ctx, "", "", "", "id", "desc")
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
@@ -220,18 +222,33 @@ func parseProxyIDs(c *gin.Context) ([]int64, error) {
return ids, nil return ids, nil
} }
func (h *ProxyHandler) listProxiesFiltered(ctx context.Context, protocol, status, search string) ([]service.Proxy, error) { func (h *ProxyHandler) listProxiesFiltered(ctx context.Context, protocol, status, search, sortBy, sortOrder string) ([]service.Proxy, error) {
page := 1 page := 1
pageSize := dataPageCap pageSize := dataPageCap
var out []service.Proxy var out []service.Proxy
sortBy = strings.TrimSpace(sortBy)
useAccountCountSort := strings.EqualFold(sortBy, "account_count")
for { for {
items, total, err := h.adminService.ListProxies(ctx, page, pageSize, protocol, status, search) if useAccountCountSort {
if err != nil { items, total, err := h.adminService.ListProxiesWithAccountCount(ctx, page, pageSize, protocol, status, search, sortBy, sortOrder)
return nil, err if err != nil {
} return nil, err
out = append(out, items...) }
if len(out) >= int(total) || len(items) == 0 { for i := range items {
break out = append(out, items[i].Proxy)
}
if len(out) >= int(total) || len(items) == 0 {
break
}
} else {
items, total, err := h.adminService.ListProxies(ctx, page, pageSize, protocol, status, search, sortBy, sortOrder)
if err != nil {
return nil, err
}
out = append(out, items...)
if len(out) >= int(total) || len(items) == 0 {
break
}
} }
page++ page++
} }

View File

@@ -74,6 +74,10 @@ func TestProxyExportDataRespectsFilters(t *testing.T) {
require.Len(t, resp.Data.Proxies, 1) require.Len(t, resp.Data.Proxies, 1)
require.Len(t, resp.Data.Accounts, 0) require.Len(t, resp.Data.Accounts, 0)
require.Equal(t, "https", resp.Data.Proxies[0].Protocol) require.Equal(t, "https", resp.Data.Proxies[0].Protocol)
require.Equal(t, 1, adminSvc.lastListProxies.calls)
require.Equal(t, "https", adminSvc.lastListProxies.protocol)
require.Equal(t, "id", adminSvc.lastListProxies.sortBy)
require.Equal(t, "desc", adminSvc.lastListProxies.sortOrder)
} }
func TestProxyExportDataWithSelectedIDs(t *testing.T) { func TestProxyExportDataWithSelectedIDs(t *testing.T) {
@@ -113,6 +117,96 @@ func TestProxyExportDataWithSelectedIDs(t *testing.T) {
require.Len(t, resp.Data.Proxies, 1) require.Len(t, resp.Data.Proxies, 1)
require.Equal(t, "https", resp.Data.Proxies[0].Protocol) require.Equal(t, "https", resp.Data.Proxies[0].Protocol)
require.Equal(t, "10.0.0.2", resp.Data.Proxies[0].Host) require.Equal(t, "10.0.0.2", resp.Data.Proxies[0].Host)
require.Equal(t, 0, adminSvc.lastListProxies.calls)
}
func TestProxyExportDataPassesSortParams(t *testing.T) {
router, adminSvc := setupProxyDataRouter()
adminSvc.proxies = []service.Proxy{
{
ID: 1,
Name: "proxy-a",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Username: "user",
Password: "pass",
Status: service.StatusActive,
},
}
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/data?protocol=http&status=active&search=proxy&sort_by=name&sort_order=asc", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, 1, adminSvc.lastListProxies.calls)
require.Equal(t, "http", adminSvc.lastListProxies.protocol)
require.Equal(t, "active", adminSvc.lastListProxies.status)
require.Equal(t, "proxy", adminSvc.lastListProxies.search)
require.Equal(t, "name", adminSvc.lastListProxies.sortBy)
require.Equal(t, "asc", adminSvc.lastListProxies.sortOrder)
}
func TestProxyExportDataSortByAccountCountUsesAccountCountListing(t *testing.T) {
router, adminSvc := setupProxyDataRouter()
adminSvc.proxies = []service.Proxy{
{
ID: 1,
Name: "proxy-id-1",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Status: service.StatusActive,
},
{
ID: 2,
Name: "proxy-id-2",
Protocol: "http",
Host: "127.0.0.2",
Port: 8081,
Status: service.StatusActive,
},
}
adminSvc.proxyCounts = []service.ProxyWithAccountCount{
{
Proxy: service.Proxy{
ID: 2,
Name: "proxy-count-high",
Protocol: "http",
Host: "127.0.0.2",
Port: 8081,
Status: service.StatusActive,
},
AccountCount: 9,
},
{
Proxy: service.Proxy{
ID: 1,
Name: "proxy-count-low",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Status: service.StatusActive,
},
AccountCount: 1,
},
}
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/data?sort_by=account_count&sort_order=desc", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp proxyDataResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Len(t, resp.Data.Proxies, 2)
require.Equal(t, "proxy-count-high", resp.Data.Proxies[0].Name)
require.Equal(t, "proxy-count-low", resp.Data.Proxies[1].Name)
require.Equal(t, 0, adminSvc.lastListProxies.calls)
} }
func TestProxyImportDataReusesAndTriggersLatencyProbe(t *testing.T) { func TestProxyImportDataReusesAndTriggersLatencyProbe(t *testing.T) {

View File

@@ -52,13 +52,15 @@ func (h *ProxyHandler) List(c *gin.Context) {
protocol := c.Query("protocol") protocol := c.Query("protocol")
status := c.Query("status") status := c.Query("status")
search := c.Query("search") search := c.Query("search")
sortBy := c.DefaultQuery("sort_by", "id")
sortOrder := c.DefaultQuery("sort_order", "desc")
// 标准化和验证 search 参数 // 标准化和验证 search 参数
search = strings.TrimSpace(search) search = strings.TrimSpace(search)
if len(search) > 100 { if len(search) > 100 {
search = search[:100] search = search[:100]
} }
proxies, total, err := h.adminService.ListProxiesWithAccountCount(c.Request.Context(), page, pageSize, protocol, status, search) proxies, total, err := h.adminService.ListProxiesWithAccountCount(c.Request.Context(), page, pageSize, protocol, status, search, sortBy, sortOrder)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return

View File

@@ -0,0 +1,49 @@
package admin
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func setupRedeemExportRouter() (*gin.Engine, *stubAdminService) {
gin.SetMode(gin.TestMode)
router := gin.New()
adminSvc := newStubAdminService()
h := NewRedeemHandler(adminSvc, nil)
router.GET("/api/v1/admin/redeem-codes/export", h.Export)
return router, adminSvc
}
func TestRedeemExportPassesSearchAndSort(t *testing.T) {
router, adminSvc := setupRedeemExportRouter()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes/export?type=balance&status=unused&search=ABC&sort_by=value&sort_order=asc", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, 1, adminSvc.lastListRedeemCodes.calls)
require.Equal(t, "balance", adminSvc.lastListRedeemCodes.codeType)
require.Equal(t, "unused", adminSvc.lastListRedeemCodes.status)
require.Equal(t, "ABC", adminSvc.lastListRedeemCodes.search)
require.Equal(t, "value", adminSvc.lastListRedeemCodes.sortBy)
require.Equal(t, "asc", adminSvc.lastListRedeemCodes.sortOrder)
}
func TestRedeemExportSortDefaults(t *testing.T) {
router, adminSvc := setupRedeemExportRouter()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes/export", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, 1, adminSvc.lastListRedeemCodes.calls)
require.Equal(t, "id", adminSvc.lastListRedeemCodes.sortBy)
require.Equal(t, "desc", adminSvc.lastListRedeemCodes.sortOrder)
}

View File

@@ -59,13 +59,15 @@ func (h *RedeemHandler) List(c *gin.Context) {
codeType := c.Query("type") codeType := c.Query("type")
status := c.Query("status") status := c.Query("status")
search := c.Query("search") search := c.Query("search")
sortBy := c.DefaultQuery("sort_by", "id")
sortOrder := c.DefaultQuery("sort_order", "desc")
// 标准化和验证 search 参数 // 标准化和验证 search 参数
search = strings.TrimSpace(search) search = strings.TrimSpace(search)
if len(search) > 100 { if len(search) > 100 {
search = search[:100] search = search[:100]
} }
codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search) codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search, sortBy, sortOrder)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
@@ -300,9 +302,15 @@ func (h *RedeemHandler) GetStats(c *gin.Context) {
func (h *RedeemHandler) Export(c *gin.Context) { func (h *RedeemHandler) Export(c *gin.Context) {
codeType := c.Query("type") codeType := c.Query("type")
status := c.Query("status") status := c.Query("status")
search := strings.TrimSpace(c.Query("search"))
sortBy := c.DefaultQuery("sort_by", "id")
sortOrder := c.DefaultQuery("sort_order", "desc")
if len(search) > 100 {
search = search[:100]
}
// Get all codes without pagination (use large page size) // Get all codes without pagination (use large page size)
codes, _, err := h.adminService.ListRedeemCodes(c.Request.Context(), 1, 10000, codeType, status, "") codes, _, err := h.adminService.ListRedeemCodes(c.Request.Context(), 1, 10000, codeType, status, search, sortBy, sortOrder)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return

View File

@@ -137,6 +137,8 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
HideCcsImportButton: settings.HideCcsImportButton, HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
TableDefaultPageSize: settings.TableDefaultPageSize,
TablePageSizeOptions: settings.TablePageSizeOptions,
CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems), CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
DefaultConcurrency: settings.DefaultConcurrency, DefaultConcurrency: settings.DefaultConcurrency,
@@ -230,6 +232,8 @@ type UpdateSettingsRequest struct {
HideCcsImportButton bool `json:"hide_ccs_import_button"` HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"` PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"` PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
TableDefaultPageSize int `json:"table_default_page_size"`
TablePageSizeOptions []int `json:"table_page_size_options"`
CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"` CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"` CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"`
@@ -292,6 +296,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
if req.DefaultBalance < 0 { if req.DefaultBalance < 0 {
req.DefaultBalance = 0 req.DefaultBalance = 0
} }
// 通用表格配置:兼容旧客户端未传字段时保留当前值。
if req.TableDefaultPageSize <= 0 {
req.TableDefaultPageSize = previousSettings.TableDefaultPageSize
}
if req.TablePageSizeOptions == nil {
req.TablePageSizeOptions = previousSettings.TablePageSizeOptions
}
req.SMTPHost = strings.TrimSpace(req.SMTPHost) req.SMTPHost = strings.TrimSpace(req.SMTPHost)
req.SMTPUsername = strings.TrimSpace(req.SMTPUsername) req.SMTPUsername = strings.TrimSpace(req.SMTPUsername)
req.SMTPPassword = strings.TrimSpace(req.SMTPPassword) req.SMTPPassword = strings.TrimSpace(req.SMTPPassword)
@@ -757,6 +768,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
HideCcsImportButton: req.HideCcsImportButton, HideCcsImportButton: req.HideCcsImportButton,
PurchaseSubscriptionEnabled: purchaseEnabled, PurchaseSubscriptionEnabled: purchaseEnabled,
PurchaseSubscriptionURL: purchaseURL, PurchaseSubscriptionURL: purchaseURL,
TableDefaultPageSize: req.TableDefaultPageSize,
TablePageSizeOptions: req.TablePageSizeOptions,
CustomMenuItems: customMenuJSON, CustomMenuItems: customMenuJSON,
CustomEndpoints: customEndpointsJSON, CustomEndpoints: customEndpointsJSON,
DefaultConcurrency: req.DefaultConcurrency, DefaultConcurrency: req.DefaultConcurrency,
@@ -894,6 +907,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
HideCcsImportButton: updatedSettings.HideCcsImportButton, HideCcsImportButton: updatedSettings.HideCcsImportButton,
PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled, PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL, PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
TableDefaultPageSize: updatedSettings.TableDefaultPageSize,
TablePageSizeOptions: updatedSettings.TablePageSizeOptions,
CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems), CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints), CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints),
DefaultConcurrency: updatedSettings.DefaultConcurrency, DefaultConcurrency: updatedSettings.DefaultConcurrency,
@@ -1152,6 +1167,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.PurchaseSubscriptionURL != after.PurchaseSubscriptionURL { if before.PurchaseSubscriptionURL != after.PurchaseSubscriptionURL {
changed = append(changed, "purchase_subscription_url") changed = append(changed, "purchase_subscription_url")
} }
if before.TableDefaultPageSize != after.TableDefaultPageSize {
changed = append(changed, "table_default_page_size")
}
if !equalIntSlice(before.TablePageSizeOptions, after.TablePageSizeOptions) {
changed = append(changed, "table_page_size_options")
}
if before.CustomMenuItems != after.CustomMenuItems { if before.CustomMenuItems != after.CustomMenuItems {
changed = append(changed, "custom_menu_items") changed = append(changed, "custom_menu_items")
} }
@@ -1208,6 +1229,18 @@ func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool {
return true return true
} }
func equalIntSlice(a, b []int) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
// TestSMTPRequest 测试SMTP连接请求 // TestSMTPRequest 测试SMTP连接请求
type TestSMTPRequest struct { type TestSMTPRequest struct {
SMTPHost string `json:"smtp_host"` SMTPHost string `json:"smtp_host"`

View File

@@ -165,7 +165,12 @@ func (h *UsageHandler) List(c *gin.Context) {
endTime = &t endTime = &t
} }
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{
Page: page,
PageSize: pageSize,
SortBy: c.DefaultQuery("sort_by", "created_at"),
SortOrder: c.DefaultQuery("sort_order", "desc"),
}
filters := usagestats.UsageLogFilters{ filters := usagestats.UsageLogFilters{
UserID: userID, UserID: userID,
APIKeyID: apiKeyID, APIKeyID: apiKeyID,
@@ -339,7 +344,7 @@ func (h *UsageHandler) SearchUsers(c *gin.Context) {
} }
// Limit to 30 results // Limit to 30 results
users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, service.UserListFilters{Search: keyword}) users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, service.UserListFilters{Search: keyword}, "email", "asc")
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return

View File

@@ -15,11 +15,13 @@ import (
type adminUsageRepoCapture struct { type adminUsageRepoCapture struct {
service.UsageLogRepository service.UsageLogRepository
listParams pagination.PaginationParams
listFilters usagestats.UsageLogFilters listFilters usagestats.UsageLogFilters
statsFilters usagestats.UsageLogFilters statsFilters usagestats.UsageLogFilters
} }
func (s *adminUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) { func (s *adminUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
s.listParams = params
s.listFilters = filters s.listFilters = filters
return []service.UsageLog{}, &pagination.PaginationResult{ return []service.UsageLog{}, &pagination.PaginationResult{
Total: 0, Total: 0,

View File

@@ -0,0 +1,35 @@
package admin
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
)
func TestAdminUsageListSortParams(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage?sort_by=model&sort_order=ASC", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "model", repo.listParams.SortBy)
require.Equal(t, "ASC", repo.listParams.SortOrder)
}
func TestAdminUsageListSortDefaults(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "created_at", repo.listParams.SortBy)
require.Equal(t, "desc", repo.listParams.SortOrder)
}

View File

@@ -91,12 +91,14 @@ func (h *UserHandler) List(c *gin.Context) {
GroupName: strings.TrimSpace(c.Query("group_name")), GroupName: strings.TrimSpace(c.Query("group_name")),
Attributes: parseAttributeFilters(c), Attributes: parseAttributeFilters(c),
} }
sortBy := c.DefaultQuery("sort_by", "created_at")
sortOrder := c.DefaultQuery("sort_order", "desc")
if raw, ok := c.GetQuery("include_subscriptions"); ok { if raw, ok := c.GetQuery("include_subscriptions"); ok {
includeSubscriptions := parseBoolQueryWithDefault(raw, true) includeSubscriptions := parseBoolQueryWithDefault(raw, true)
filters.IncludeSubscriptions = &includeSubscriptions filters.IncludeSubscriptions = &includeSubscriptions
} }
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, filters) users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, filters, sortBy, sortOrder)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
@@ -290,8 +292,10 @@ func (h *UserHandler) GetUserAPIKeys(c *gin.Context) {
} }
page, pageSize := response.ParsePagination(c) page, pageSize := response.ParsePagination(c)
sortBy := c.DefaultQuery("sort_by", "created_at")
sortOrder := c.DefaultQuery("sort_order", "desc")
keys, total, err := h.adminService.GetUserAPIKeys(c.Request.Context(), userID, page, pageSize) keys, total, err := h.adminService.GetUserAPIKeys(c.Request.Context(), userID, page, pageSize, sortBy, sortOrder)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return

View File

@@ -72,7 +72,12 @@ func (h *APIKeyHandler) List(c *gin.Context) {
} }
page, pageSize := response.ParsePagination(c) page, pageSize := response.ParsePagination(c)
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{
Page: page,
PageSize: pageSize,
SortBy: c.DefaultQuery("sort_by", "created_at"),
SortOrder: c.DefaultQuery("sort_order", "desc"),
}
// Parse filter parameters // Parse filter parameters
var filters service.APIKeyListFilters var filters service.APIKeyListFilters

View File

@@ -84,6 +84,8 @@ type SystemSettings struct {
HideCcsImportButton bool `json:"hide_ccs_import_button"` HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"` PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
TableDefaultPageSize int `json:"table_default_page_size"`
TablePageSizeOptions []int `json:"table_page_size_options"`
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"` CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
@@ -148,6 +150,8 @@ type PublicSettings struct {
HideCcsImportButton bool `json:"hide_ccs_import_button"` HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"` PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
TableDefaultPageSize int `json:"table_default_page_size"`
TablePageSizeOptions []int `json:"table_page_size_options"`
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"` CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`

View File

@@ -51,6 +51,8 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
HideCcsImportButton: settings.HideCcsImportButton, HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
TableDefaultPageSize: settings.TableDefaultPageSize,
TablePageSizeOptions: settings.TablePageSizeOptions,
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems), CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,

View File

@@ -119,7 +119,12 @@ func (h *UsageHandler) List(c *gin.Context) {
endTime = &t endTime = &t
} }
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{
Page: page,
PageSize: pageSize,
SortBy: c.DefaultQuery("sort_by", "created_at"),
SortOrder: c.DefaultQuery("sort_order", "desc"),
}
filters := usagestats.UsageLogFilters{ filters := usagestats.UsageLogFilters{
UserID: subject.UserID, // Always filter by current user for security UserID: subject.UserID, // Always filter by current user for security
APIKeyID: apiKeyID, APIKeyID: apiKeyID,

View File

@@ -16,10 +16,12 @@ import (
type userUsageRepoCapture struct { type userUsageRepoCapture struct {
service.UsageLogRepository service.UsageLogRepository
listParams pagination.PaginationParams
listFilters usagestats.UsageLogFilters listFilters usagestats.UsageLogFilters
} }
func (s *userUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) { func (s *userUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
s.listParams = params
s.listFilters = filters s.listFilters = filters
return []service.UsageLog{}, &pagination.PaginationResult{ return []service.UsageLog{}, &pagination.PaginationResult{
Total: 0, Total: 0,

View File

@@ -0,0 +1,35 @@
package handler
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
)
func TestUserUsageListSortParams(t *testing.T) {
repo := &userUsageRepoCapture{}
router := newUserUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/usage?sort_by=model&sort_order=ASC", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "model", repo.listParams.SortBy)
require.Equal(t, "ASC", repo.listParams.SortOrder)
}
func TestUserUsageListSortDefaults(t *testing.T) {
repo := &userUsageRepoCapture{}
router := newUserUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/usage", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "created_at", repo.listParams.SortBy)
require.Equal(t, "desc", repo.listParams.SortOrder)
}

View File

@@ -1,10 +1,19 @@
// Package pagination provides types and helpers for paginated responses. // Package pagination provides types and helpers for paginated responses.
package pagination package pagination
import "strings"
const (
SortOrderAsc = "asc"
SortOrderDesc = "desc"
)
// PaginationParams 分页参数 // PaginationParams 分页参数
type PaginationParams struct { type PaginationParams struct {
Page int Page int
PageSize int PageSize int
SortBy string
SortOrder string
} }
// PaginationResult 分页结果 // PaginationResult 分页结果
@@ -18,8 +27,9 @@ type PaginationResult struct {
// DefaultPagination 默认分页参数 // DefaultPagination 默认分页参数
func DefaultPagination() PaginationParams { func DefaultPagination() PaginationParams {
return PaginationParams{ return PaginationParams{
Page: 1, Page: 1,
PageSize: 20, PageSize: 20,
SortOrder: SortOrderDesc,
} }
} }
@@ -36,8 +46,32 @@ func (p PaginationParams) Limit() int {
if p.PageSize < 1 { if p.PageSize < 1 {
return 20 return 20
} }
if p.PageSize > 100 { if p.PageSize > 1000 {
return 100 return 1000
} }
return p.PageSize return p.PageSize
} }
// NormalizeSortOrder normalizes sort order to asc/desc and falls back to defaultOrder.
func NormalizeSortOrder(order string, defaultOrder string) string {
switch strings.ToLower(strings.TrimSpace(defaultOrder)) {
case SortOrderAsc:
defaultOrder = SortOrderAsc
default:
defaultOrder = SortOrderDesc
}
switch strings.ToLower(strings.TrimSpace(order)) {
case SortOrderAsc:
return SortOrderAsc
case SortOrderDesc:
return SortOrderDesc
default:
return defaultOrder
}
}
// NormalizedSortOrder returns the normalized sort order using defaultOrder as fallback.
func (p PaginationParams) NormalizedSortOrder(defaultOrder string) string {
return NormalizeSortOrder(p.SortOrder, defaultOrder)
}

View File

@@ -0,0 +1,71 @@
package pagination
import "testing"
func TestNormalizeSortOrder(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
defaultOrder string
want string
}{
{name: "asc", input: "asc", defaultOrder: "desc", want: "asc"},
{name: "uppercase asc", input: "ASC", defaultOrder: "desc", want: "asc"},
{name: "desc", input: "desc", defaultOrder: "asc", want: "desc"},
{name: "trim spaces", input: " desc ", defaultOrder: "asc", want: "desc"},
{name: "invalid falls back", input: "sideways", defaultOrder: "asc", want: "asc"},
{name: "empty falls back", input: "", defaultOrder: "desc", want: "desc"},
{name: "invalid default falls back to desc", input: "", defaultOrder: "wat", want: "desc"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := NormalizeSortOrder(tt.input, tt.defaultOrder); got != tt.want {
t.Fatalf("NormalizeSortOrder(%q, %q) = %q, want %q", tt.input, tt.defaultOrder, got, tt.want)
}
})
}
}
func TestPaginationParamsNormalizedSortOrder(t *testing.T) {
t.Parallel()
params := PaginationParams{SortOrder: "ASC"}
if got := params.NormalizedSortOrder("desc"); got != "asc" {
t.Fatalf("NormalizedSortOrder = %q, want asc", got)
}
params = PaginationParams{SortOrder: "bad"}
if got := params.NormalizedSortOrder("asc"); got != "asc" {
t.Fatalf("NormalizedSortOrder invalid fallback = %q, want asc", got)
}
}
func TestPaginationParamsLimit(t *testing.T) {
t.Parallel()
tests := []struct {
name string
pageSize int
want int
}{
{name: "non-positive falls back to default", pageSize: 0, want: 20},
{name: "negative falls back to default", pageSize: -1, want: 20},
{name: "normal value keeps", pageSize: 50, want: 50},
{name: "max value keeps", pageSize: 1000, want: 1000},
{name: "beyond max clamps to 1000", pageSize: 1500, want: 1000},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
p := PaginationParams{PageSize: tt.pageSize}
if got := p.Limit(); got != tt.want {
t.Fatalf("Limit() for PageSize=%d = %d, want %d", tt.pageSize, got, tt.want)
}
})
}
}

View File

@@ -471,21 +471,58 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
case service.StatusActive: case service.StatusActive:
q = q.Where( q = q.Where(
dbaccount.StatusEQ(status), dbaccount.StatusEQ(status),
dbaccount.SchedulableEQ(true),
dbaccount.Or( dbaccount.Or(
dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtIsNil(),
dbaccount.RateLimitResetAtLTE(time.Now()), dbaccount.RateLimitResetAtLTE(time.Now()),
), ),
dbpredicate.Account(func(s *entsql.Selector) {
col := s.C("temp_unschedulable_until")
s.Where(entsql.Or(
entsql.IsNull(col),
entsql.LTE(col, entsql.Expr("NOW()")),
))
}),
) )
case "rate_limited": case "rate_limited":
q = q.Where(dbaccount.RateLimitResetAtGT(time.Now())) q = q.Where(
dbaccount.StatusEQ(service.StatusActive),
dbaccount.RateLimitResetAtGT(time.Now()),
dbpredicate.Account(func(s *entsql.Selector) {
col := s.C("temp_unschedulable_until")
s.Where(entsql.Or(
entsql.IsNull(col),
entsql.LTE(col, entsql.Expr("NOW()")),
))
}),
)
case "temp_unschedulable": case "temp_unschedulable":
q = q.Where(dbpredicate.Account(func(s *entsql.Selector) { q = q.Where(
col := s.C("temp_unschedulable_until") dbaccount.StatusEQ(service.StatusActive),
s.Where(entsql.And( dbpredicate.Account(func(s *entsql.Selector) {
entsql.Not(entsql.IsNull(col)), col := s.C("temp_unschedulable_until")
entsql.GT(col, entsql.Expr("NOW()")), s.Where(entsql.And(
)) entsql.Not(entsql.IsNull(col)),
})) entsql.GT(col, entsql.Expr("NOW()")),
))
}),
)
case "unschedulable":
q = q.Where(
dbaccount.StatusEQ(service.StatusActive),
dbaccount.SchedulableEQ(false),
dbaccount.Or(
dbaccount.RateLimitResetAtIsNil(),
dbaccount.RateLimitResetAtLTE(time.Now()),
),
dbpredicate.Account(func(s *entsql.Selector) {
col := s.C("temp_unschedulable_until")
s.Where(entsql.Or(
entsql.IsNull(col),
entsql.LTE(col, entsql.Expr("NOW()")),
))
}),
)
default: default:
q = q.Where(dbaccount.StatusEQ(status)) q = q.Where(dbaccount.StatusEQ(status))
} }
@@ -518,11 +555,14 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
return nil, nil, err return nil, nil, err
} }
accounts, err := q. accountsQuery := q.
Offset(params.Offset()). Offset(params.Offset()).
Limit(params.Limit()). Limit(params.Limit())
Order(dbent.Desc(dbaccount.FieldID)). for _, order := range accountListOrder(params) {
All(ctx) accountsQuery = accountsQuery.Order(order)
}
accounts, err := accountsQuery.All(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -534,6 +574,50 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
return outAccounts, paginationResultFromTotal(int64(total), params), nil return outAccounts, paginationResultFromTotal(int64(total), params), nil
} }
func accountListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := params.NormalizedSortOrder(pagination.SortOrderAsc)
field := dbaccount.FieldName
defaultOrder := true
switch sortBy {
case "", "name":
field = dbaccount.FieldName
case "id":
field = dbaccount.FieldID
defaultOrder = false
case "status":
field = dbaccount.FieldStatus
defaultOrder = false
case "schedulable":
field = dbaccount.FieldSchedulable
defaultOrder = false
case "priority":
field = dbaccount.FieldPriority
defaultOrder = false
case "rate_multiplier":
field = dbaccount.FieldRateMultiplier
defaultOrder = false
case "last_used_at":
field = dbaccount.FieldLastUsedAt
defaultOrder = false
case "expires_at":
field = dbaccount.FieldExpiresAt
defaultOrder = false
case "created_at":
field = dbaccount.FieldCreatedAt
defaultOrder = false
}
if sortOrder == pagination.SortOrderDesc {
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(dbaccount.FieldID)}
}
if defaultOrder {
return []func(*entsql.Selector){dbent.Asc(dbaccount.FieldName), dbent.Asc(dbaccount.FieldID)}
}
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(dbaccount.FieldID)}
}
func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) { func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
accounts, err := r.queryAccountsByGroup(ctx, groupID, accountGroupQueryOptions{ accounts, err := r.queryAccountsByGroup(ctx, groupID, accountGroupQueryOptions{
status: service.StatusActive, status: service.StatusActive,

View File

@@ -256,7 +256,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
}, },
}, },
{ {
name: "filter_by_status_active_excludes_rate_limited", name: "filter_by_status_active_excludes_runtime_blocked_accounts",
setup: func(client *dbent.Client) { setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), client, &service.Account{Name: "active-normal", Status: service.StatusActive}) mustCreateAccount(s.T(), client, &service.Account{Name: "active-normal", Status: service.StatusActive})
rateLimited := mustCreateAccount(s.T(), client, &service.Account{Name: "active-rate-limited", Status: service.StatusActive}) rateLimited := mustCreateAccount(s.T(), client, &service.Account{Name: "active-rate-limited", Status: service.StatusActive})
@@ -264,6 +264,16 @@ func (s *AccountRepoSuite) TestListWithFilters() {
SetRateLimitResetAt(time.Now().Add(10 * time.Minute)). SetRateLimitResetAt(time.Now().Add(10 * time.Minute)).
Exec(context.Background()) Exec(context.Background())
s.Require().NoError(err) s.Require().NoError(err)
tempUnsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-temp-unsched", Status: service.StatusActive})
err = client.Account.UpdateOneID(tempUnsched.ID).
SetTempUnschedulableUntil(time.Now().Add(15 * time.Minute)).
Exec(context.Background())
s.Require().NoError(err)
unsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-unsched", Status: service.StatusActive})
err = client.Account.UpdateOneID(unsched.ID).
SetSchedulable(false).
Exec(context.Background())
s.Require().NoError(err)
}, },
status: service.StatusActive, status: service.StatusActive,
wantCount: 1, wantCount: 1,
@@ -271,6 +281,75 @@ func (s *AccountRepoSuite) TestListWithFilters() {
s.Require().Equal("active-normal", accounts[0].Name) s.Require().Equal("active-normal", accounts[0].Name)
}, },
}, },
{
name: "filter_by_status_unschedulable_excludes_rate_limited_and_temp_unschedulable",
setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), client, &service.Account{Name: "active-normal", Status: service.StatusActive, Schedulable: true})
unsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-unsched", Status: service.StatusActive})
err := client.Account.UpdateOneID(unsched.ID).
SetSchedulable(false).
Exec(context.Background())
s.Require().NoError(err)
rateLimited := mustCreateAccount(s.T(), client, &service.Account{Name: "active-rate-limited", Status: service.StatusActive})
err = client.Account.UpdateOneID(rateLimited.ID).
SetSchedulable(false).
SetRateLimitResetAt(time.Now().Add(10 * time.Minute)).
Exec(context.Background())
s.Require().NoError(err)
tempUnsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-temp-unsched", Status: service.StatusActive})
err = client.Account.UpdateOneID(tempUnsched.ID).
SetSchedulable(false).
SetTempUnschedulableUntil(time.Now().Add(15 * time.Minute)).
Exec(context.Background())
s.Require().NoError(err)
},
status: "unschedulable",
wantCount: 1,
validate: func(accounts []service.Account) {
s.Require().Equal("active-unsched", accounts[0].Name)
},
},
{
name: "filter_by_status_rate_limited_excludes_temp_unschedulable",
setup: func(client *dbent.Client) {
rateLimited := mustCreateAccount(s.T(), client, &service.Account{Name: "active-rate-limited", Status: service.StatusActive})
err := client.Account.UpdateOneID(rateLimited.ID).
SetRateLimitResetAt(time.Now().Add(10 * time.Minute)).
Exec(context.Background())
s.Require().NoError(err)
tempUnsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-temp-unsched", Status: service.StatusActive})
err = client.Account.UpdateOneID(tempUnsched.ID).
SetRateLimitResetAt(time.Now().Add(20 * time.Minute)).
SetTempUnschedulableUntil(time.Now().Add(15 * time.Minute)).
Exec(context.Background())
s.Require().NoError(err)
},
status: "rate_limited",
wantCount: 1,
validate: func(accounts []service.Account) {
s.Require().Equal("active-rate-limited", accounts[0].Name)
},
},
{
name: "filter_by_status_temp_unschedulable_excludes_manually_unschedulable",
setup: func(client *dbent.Client) {
tempUnsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-temp-unsched", Status: service.StatusActive, Schedulable: true})
err := client.Account.UpdateOneID(tempUnsched.ID).
SetTempUnschedulableUntil(time.Now().Add(15 * time.Minute)).
Exec(context.Background())
s.Require().NoError(err)
unsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-unsched", Status: service.StatusActive})
err = client.Account.UpdateOneID(unsched.ID).
SetSchedulable(false).
Exec(context.Background())
s.Require().NoError(err)
},
status: "temp_unschedulable",
wantCount: 1,
validate: func(accounts []service.Account) {
s.Require().Equal("active-temp-unsched", accounts[0].Name)
},
},
{ {
name: "filter_by_search", name: "filter_by_search",
setup: func(client *dbent.Client) { setup: func(client *dbent.Client) {

View File

@@ -0,0 +1,35 @@
//go:build integration
package repository
import (
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (s *AccountRepoSuite) TestList_DefaultSortByNameAsc() {
mustCreateAccount(s.T(), s.client, &service.Account{Name: "z-account"})
mustCreateAccount(s.T(), s.client, &service.Account{Name: "a-account"})
accounts, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err)
s.Require().Len(accounts, 2)
s.Require().Equal("a-account", accounts[0].Name)
s.Require().Equal("z-account", accounts[1].Name)
}
func (s *AccountRepoSuite) TestListWithFilters_SortByPriorityDesc() {
mustCreateAccount(s.T(), s.client, &service.Account{Name: "low-priority", Priority: 10})
mustCreateAccount(s.T(), s.client, &service.Account{Name: "high-priority", Priority: 90})
accounts, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
Page: 1,
PageSize: 10,
SortBy: "priority",
SortOrder: "desc",
}, "", "", "", "", 0, "")
s.Require().NoError(err)
s.Require().Len(accounts, 2)
s.Require().Equal("high-priority", accounts[0].Name)
s.Require().Equal("low-priority", accounts[1].Name)
}

View File

@@ -2,12 +2,15 @@ package repository
import ( import (
"context" "context"
"strings"
"time" "time"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
entsql "entgo.io/ent/dialect/sql"
) )
type announcementRepository struct { type announcementRepository struct {
@@ -128,11 +131,14 @@ func (r *announcementRepository) List(
return nil, nil, err return nil, nil, err
} }
items, err := q. itemsQuery := q.
Offset(params.Offset()). Offset(params.Offset()).
Limit(params.Limit()). Limit(params.Limit())
Order(dbent.Desc(announcement.FieldID)). for _, order := range announcementListOrders(params) {
All(ctx) itemsQuery = itemsQuery.Order(order)
}
items, err := itemsQuery.All(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -141,6 +147,56 @@ func (r *announcementRepository) List(
return out, paginationResultFromTotal(int64(total), params), nil return out, paginationResultFromTotal(int64(total), params), nil
} }
func announcementListOrder(params pagination.PaginationParams) (string, string) {
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
switch sortBy {
case "title":
return announcement.FieldTitle, sortOrder
case "status":
return announcement.FieldStatus, sortOrder
case "notify_mode":
return announcement.FieldNotifyMode, sortOrder
case "starts_at":
return announcement.FieldStartsAt, sortOrder
case "ends_at":
return announcement.FieldEndsAt, sortOrder
case "id":
return announcement.FieldID, sortOrder
case "", "created_at":
return announcement.FieldCreatedAt, sortOrder
default:
return announcement.FieldCreatedAt, pagination.SortOrderDesc
}
}
func announcementListOrders(params pagination.PaginationParams) []func(*entsql.Selector) {
field, sortOrder := announcementListOrder(params)
if sortOrder == pagination.SortOrderAsc {
if field == announcement.FieldID {
return []func(*entsql.Selector){
dbent.Asc(field),
}
}
return []func(*entsql.Selector){
dbent.Asc(field),
dbent.Asc(announcement.FieldID),
}
}
if field == announcement.FieldID {
return []func(*entsql.Selector){
dbent.Desc(field),
}
}
return []func(*entsql.Selector){
dbent.Desc(field),
dbent.Desc(announcement.FieldID),
}
}
func (r *announcementRepository) ListActive(ctx context.Context, now time.Time) ([]service.Announcement, error) { func (r *announcementRepository) ListActive(ctx context.Context, now time.Time) ([]service.Announcement, error) {
q := r.client.Announcement.Query(). q := r.client.Announcement.Query().
Where( Where(

View File

@@ -0,0 +1,63 @@
package repository
import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
func TestAnnouncementListOrder(t *testing.T) {
t.Parallel()
tests := []struct {
name string
params pagination.PaginationParams
wantBy string
want string
}{
{
name: "default created_at desc",
params: pagination.PaginationParams{},
wantBy: "created_at",
want: "desc",
},
{
name: "title asc",
params: pagination.PaginationParams{
SortBy: "title",
SortOrder: "ASC",
},
wantBy: "title",
want: "asc",
},
{
name: "status desc",
params: pagination.PaginationParams{
SortBy: "status",
SortOrder: "desc",
},
wantBy: "status",
want: "desc",
},
{
name: "invalid falls back",
params: pagination.PaginationParams{
SortBy: "sideways",
SortOrder: "wat",
},
wantBy: "created_at",
want: "desc",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
gotBy, gotOrder := announcementListOrder(tt.params)
if gotBy != tt.wantBy || gotOrder != tt.want {
t.Fatalf("announcementListOrder(%+v) = (%q, %q), want (%q, %q)", tt.params, gotBy, gotOrder, tt.wantBy, tt.want)
}
})
}
}

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"strings"
"time" "time"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -14,6 +15,8 @@ import (
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
entsql "entgo.io/ent/dialect/sql"
) )
type apiKeyRepository struct { type apiKeyRepository struct {
@@ -164,6 +167,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group.FieldSupportedModelScopes, group.FieldSupportedModelScopes,
group.FieldAllowMessagesDispatch, group.FieldAllowMessagesDispatch,
group.FieldDefaultMappedModel, group.FieldDefaultMappedModel,
group.FieldMessagesDispatchModelConfig,
) )
}). }).
Only(ctx) Only(ctx)
@@ -309,12 +313,15 @@ func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
return nil, nil, err return nil, nil, err
} }
keys, err := q. keysQuery := q.
WithGroup(). WithGroup().
Offset(params.Offset()). Offset(params.Offset()).
Limit(params.Limit()). Limit(params.Limit())
Order(dbent.Desc(apikey.FieldID)). for _, order := range apiKeyListOrder(params) {
All(ctx) keysQuery = keysQuery.Order(order)
}
keys, err := keysQuery.All(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -359,12 +366,15 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
return nil, nil, err return nil, nil, err
} }
keys, err := q. keysQuery := q.
WithUser(). WithUser().
Offset(params.Offset()). Offset(params.Offset()).
Limit(params.Limit()). Limit(params.Limit())
Order(dbent.Desc(apikey.FieldID)). for _, order := range apiKeyListOrder(params) {
All(ctx) keysQuery = keysQuery.Order(order)
}
keys, err := keysQuery.All(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -377,6 +387,32 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
return outKeys, paginationResultFromTotal(int64(total), params), nil return outKeys, paginationResultFromTotal(int64(total), params), nil
} }
func apiKeyListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
var field string
switch sortBy {
case "name":
field = apikey.FieldName
case "status":
field = apikey.FieldStatus
case "expires_at":
field = apikey.FieldExpiresAt
case "last_used_at":
field = apikey.FieldLastUsedAt
case "created_at":
field = apikey.FieldCreatedAt
default:
field = apikey.FieldID
}
if sortOrder == pagination.SortOrderAsc {
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(apikey.FieldID)}
}
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(apikey.FieldID)}
}
// SearchAPIKeys searches API keys by user ID and/or keyword (name) // SearchAPIKeys searches API keys by user ID and/or keyword (name)
func (r *apiKeyRepository) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) { func (r *apiKeyRepository) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
q := r.activeQuery() q := r.activeQuery()
@@ -654,6 +690,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
RequireOAuthOnly: g.RequireOauthOnly, RequireOAuthOnly: g.RequireOauthOnly,
RequirePrivacySet: g.RequirePrivacySet, RequirePrivacySet: g.RequirePrivacySet,
DefaultMappedModel: g.DefaultMappedModel, DefaultMappedModel: g.DefaultMappedModel,
MessagesDispatchModelConfig: g.MessagesDispatchModelConfig,
CreatedAt: g.CreatedAt, CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt, UpdatedAt: g.UpdatedAt,
} }

View File

@@ -86,6 +86,45 @@ func (s *APIKeyRepoSuite) TestGetByKey_NotFound() {
s.Require().Error(err, "expected error for non-existent key") s.Require().Error(err, "expected error for non-existent key")
} }
func (s *APIKeyRepoSuite) TestGetByKeyForAuth_PreservesMessagesDispatchModelConfig() {
user := s.mustCreateUser("getbykey-auth-dispatch@test.com")
group, err := s.client.Group.Create().
SetName("g-auth-dispatch").
SetPlatform(service.PlatformOpenAI).
SetStatus(service.StatusActive).
SetSubscriptionType(service.SubscriptionTypeStandard).
SetRateMultiplier(1).
SetAllowMessagesDispatch(true).
SetDefaultMappedModel("gpt-5.4").
SetMessagesDispatchModelConfig(service.OpenAIMessagesDispatchModelConfig{
OpusMappedModel: "gpt-5.4-nano",
SonnetMappedModel: "gpt-5.3-codex",
HaikuMappedModel: "gpt-5.4-mini",
ExactModelMappings: map[string]string{
"claude-sonnet-4.5": "gpt-5.4-nano",
},
}).
Save(s.ctx)
s.Require().NoError(err)
key := &service.APIKey{
UserID: user.ID,
Key: "sk-getbykey-auth-dispatch",
Name: "Dispatch Key",
GroupID: &group.ID,
Status: service.StatusActive,
}
s.Require().NoError(s.repo.Create(s.ctx, key))
got, err := s.repo.GetByKeyForAuth(s.ctx, key.Key)
s.Require().NoError(err)
s.Require().NotNil(got.Group)
s.Require().True(got.Group.AllowMessagesDispatch)
s.Require().Equal("gpt-5.4", got.Group.DefaultMappedModel)
s.Require().Equal("gpt-5.4-nano", got.Group.MessagesDispatchModelConfig.OpusMappedModel)
s.Require().Equal("gpt-5.4-nano", got.Group.MessagesDispatchModelConfig.ExactModelMappings["claude-sonnet-4.5"])
}
// --- Update --- // --- Update ---
func (s *APIKeyRepoSuite) TestUpdate() { func (s *APIKeyRepoSuite) TestUpdate() {

View File

@@ -0,0 +1,74 @@
package repository
import (
"context"
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func TestGroupEntityToService_PreservesMessagesDispatchModelConfig(t *testing.T) {
group := &dbent.Group{
ID: 1,
Name: "openai-dispatch",
Platform: service.PlatformOpenAI,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
RateMultiplier: 1,
AllowMessagesDispatch: true,
DefaultMappedModel: "gpt-5.4",
MessagesDispatchModelConfig: service.OpenAIMessagesDispatchModelConfig{
OpusMappedModel: "gpt-5.4-nano",
SonnetMappedModel: "gpt-5.3-codex",
HaikuMappedModel: "gpt-5.4-mini",
ExactModelMappings: map[string]string{
"claude-sonnet-4.5": "gpt-5.4-nano",
},
},
}
got := groupEntityToService(group)
require.NotNil(t, got)
require.Equal(t, group.MessagesDispatchModelConfig, got.MessagesDispatchModelConfig)
}
func TestAPIKeyRepository_GetByKeyForAuth_PreservesMessagesDispatchModelConfig_SQLite(t *testing.T) {
repo, client := newAPIKeyRepoSQLite(t)
ctx := context.Background()
user := mustCreateAPIKeyRepoUser(t, ctx, client, "getbykey-auth-dispatch-unit@test.com")
group, err := client.Group.Create().
SetName("g-auth-dispatch-unit").
SetPlatform(service.PlatformOpenAI).
SetStatus(service.StatusActive).
SetSubscriptionType(service.SubscriptionTypeStandard).
SetRateMultiplier(1).
SetAllowMessagesDispatch(true).
SetDefaultMappedModel("gpt-5.4").
SetMessagesDispatchModelConfig(service.OpenAIMessagesDispatchModelConfig{
OpusMappedModel: "gpt-5.4-nano",
SonnetMappedModel: "gpt-5.3-codex",
HaikuMappedModel: "gpt-5.4-mini",
ExactModelMappings: map[string]string{
"claude-sonnet-4.5": "gpt-5.4-nano",
},
}).
Save(ctx)
require.NoError(t, err)
key := &service.APIKey{
UserID: user.ID,
Key: "sk-getbykey-auth-dispatch-unit",
Name: "Dispatch Key Unit",
GroupID: &group.ID,
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, key))
got, err := repo.GetByKeyForAuth(ctx, key.Key)
require.NoError(t, err)
require.NotNil(t, got.Group)
require.Equal(t, group.MessagesDispatchModelConfig, got.Group.MessagesDispatchModelConfig)
}

View File

@@ -0,0 +1,25 @@
//go:build integration
package repository
import (
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (s *APIKeyRepoSuite) TestListByUserID_SortByNameAsc() {
user := s.mustCreateUser("sort-name@example.com")
s.mustCreateApiKey(user.ID, "sk-z", "z-key", nil)
s.mustCreateApiKey(user.ID, "sk-a", "a-key", nil)
keys, _, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{
Page: 1,
PageSize: 10,
SortBy: "name",
SortOrder: "asc",
}, service.APIKeyListFilters{})
s.Require().NoError(err)
s.Require().Len(keys, 2)
s.Require().Equal("a-key", keys[0].Name)
s.Require().Equal("z-key", keys[1].Name)
}

View File

@@ -188,8 +188,8 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
// 查询 channel 列表 // 查询 channel 列表
dataQuery := fmt.Sprintf( dataQuery := fmt.Sprintf(
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.created_at, c.updated_at `SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.created_at, c.updated_at
FROM channels c WHERE %s ORDER BY c.id ASC LIMIT $%d OFFSET $%d`, FROM channels c WHERE %s ORDER BY %s LIMIT $%d OFFSET $%d`,
whereClause, argIdx, argIdx+1, whereClause, channelListOrderBy(params), argIdx, argIdx+1,
) )
args = append(args, pageSize, offset) args = append(args, pageSize, offset)
@@ -246,6 +246,31 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
return channels, paginationResult, nil return channels, paginationResult, nil
} }
func channelListOrderBy(params pagination.PaginationParams) string {
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := strings.ToUpper(params.NormalizedSortOrder(pagination.SortOrderAsc))
var column string
switch sortBy {
case "":
column = "c.id"
sortOrder = "ASC"
case "id":
column = "c.id"
case "name":
column = "c.name"
case "status":
column = "c.status"
case "created_at":
column = "c.created_at"
default:
column = "c.id"
sortOrder = "ASC"
}
return fmt.Sprintf("%s %s, c.id %s", column, sortOrder, sortOrder)
}
func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) { func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
rows, err := r.db.QueryContext(ctx, rows, err := r.db.QueryContext(ctx,
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at FROM channels ORDER BY id`, `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at FROM channels ORDER BY id`,

View File

@@ -8,6 +8,7 @@ import (
"fmt" "fmt"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -225,3 +226,12 @@ func TestIsUniqueViolation(t *testing.T) {
}) })
} }
} }
func TestChannelListOrderBy_AllowsDescendingIDSort(t *testing.T) {
params := pagination.PaginationParams{
SortBy: "id",
SortOrder: "desc",
}
require.Equal(t, "c.id DESC, c.id DESC", channelListOrderBy(params))
}

View File

@@ -5,6 +5,7 @@ import (
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"sort"
"strings" "strings"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -14,6 +15,8 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq" "github.com/lib/pq"
entsql "entgo.io/ent/dialect/sql"
) )
type sqlExecutor interface { type sqlExecutor interface {
@@ -40,6 +43,7 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetDescription(groupIn.Description). SetDescription(groupIn.Description).
SetPlatform(groupIn.Platform). SetPlatform(groupIn.Platform).
SetRateMultiplier(groupIn.RateMultiplier). SetRateMultiplier(groupIn.RateMultiplier).
SetSortOrder(groupIn.SortOrder).
SetIsExclusive(groupIn.IsExclusive). SetIsExclusive(groupIn.IsExclusive).
SetStatus(groupIn.Status). SetStatus(groupIn.Status).
SetSubscriptionType(groupIn.SubscriptionType). SetSubscriptionType(groupIn.SubscriptionType).
@@ -233,11 +237,18 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
return nil, nil, err return nil, nil, err
} }
groups, err := q. if strings.EqualFold(strings.TrimSpace(params.SortBy), "account_count") {
return r.listWithAccountCountSort(ctx, q, params, total)
}
groupsQuery := q.
Offset(params.Offset()). Offset(params.Offset()).
Limit(params.Limit()). Limit(params.Limit())
Order(dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)). for _, order := range groupListOrder(params) {
All(ctx) groupsQuery = groupsQuery.Order(order)
}
groups, err := groupsQuery.All(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -263,6 +274,104 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
return outGroups, paginationResultFromTotal(int64(total), params), nil return outGroups, paginationResultFromTotal(int64(total), params), nil
} }
func (r *groupRepository) listWithAccountCountSort(ctx context.Context, q *dbent.GroupQuery, params pagination.PaginationParams, total int) ([]service.Group, *pagination.PaginationResult, error) {
groups, err := q.
Order(dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
groupIDs := make([]int64, 0, len(groups))
outGroups := make([]service.Group, 0, len(groups))
for i := range groups {
g := groupEntityToService(groups[i])
outGroups = append(outGroups, *g)
groupIDs = append(groupIDs, g.ID)
}
counts, err := r.loadAccountCounts(ctx, groupIDs)
if err != nil {
return nil, nil, err
}
for i := range outGroups {
c := counts[outGroups[i].ID]
outGroups[i].AccountCount = c.Total
outGroups[i].ActiveAccountCount = c.Active
outGroups[i].RateLimitedAccountCount = c.RateLimited
}
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
sort.SliceStable(outGroups, func(i, j int) bool {
if outGroups[i].AccountCount == outGroups[j].AccountCount {
if outGroups[i].SortOrder == outGroups[j].SortOrder {
return outGroups[i].ID < outGroups[j].ID
}
return outGroups[i].SortOrder < outGroups[j].SortOrder
}
if sortOrder == pagination.SortOrderAsc {
return outGroups[i].AccountCount < outGroups[j].AccountCount
}
return outGroups[i].AccountCount > outGroups[j].AccountCount
})
return paginateSlice(outGroups, params), paginationResultFromTotal(int64(total), params), nil
}
func groupListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := params.NormalizedSortOrder(pagination.SortOrderAsc)
var field string
tieField := group.FieldID
defaultOrder := true
switch sortBy {
case "", "sort_order":
field = group.FieldSortOrder
case "name":
field = group.FieldName
defaultOrder = false
case "platform":
field = group.FieldPlatform
defaultOrder = false
case "billing_type", "subscription_type":
field = group.FieldSubscriptionType
defaultOrder = false
case "rate_multiplier":
field = group.FieldRateMultiplier
defaultOrder = false
case "is_exclusive":
field = group.FieldIsExclusive
defaultOrder = false
case "status":
field = group.FieldStatus
defaultOrder = false
case "created_at":
field = group.FieldCreatedAt
defaultOrder = false
case "id":
field = group.FieldID
defaultOrder = false
tieField = ""
default:
field = group.FieldSortOrder
}
if sortOrder == pagination.SortOrderDesc && sortBy != "" {
if tieField == "" {
return []func(*entsql.Selector){dbent.Desc(field)}
}
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(tieField)}
}
if defaultOrder {
return []func(*entsql.Selector){dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)}
}
if tieField == "" {
return []func(*entsql.Selector){dbent.Asc(field)}
}
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(tieField)}
}
func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, error) { func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, error) {
groups, err := r.client.Group.Query(). groups, err := r.client.Group.Query().
Where(group.StatusEQ(service.StatusActive)). Where(group.StatusEQ(service.StatusActive)).

View File

@@ -113,6 +113,33 @@ func (s *GroupRepoSuite) TestUpdate() {
s.Require().Equal("updated", got.Name) s.Require().Equal("updated", got.Name)
} }
func (s *GroupRepoSuite) TestGetByID_PreservesMessagesDispatchModelConfig() {
group := &service.Group{
Name: "openai-dispatch",
Platform: service.PlatformOpenAI,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
AllowMessagesDispatch: true,
DefaultMappedModel: "gpt-5.4",
MessagesDispatchModelConfig: service.OpenAIMessagesDispatchModelConfig{
OpusMappedModel: "gpt-5.4",
SonnetMappedModel: "gpt-5.3-codex",
HaikuMappedModel: "gpt-5.4-mini",
ExactModelMappings: map[string]string{
"claude-sonnet-4.5": "gpt-5.4-nano",
},
},
}
s.Require().NoError(s.repo.Create(s.ctx, group))
got, err := s.repo.GetByID(s.ctx, group.ID)
s.Require().NoError(err)
s.Require().Equal(group.MessagesDispatchModelConfig, got.MessagesDispatchModelConfig)
}
func (s *GroupRepoSuite) TestDelete() { func (s *GroupRepoSuite) TestDelete() {
group := &service.Group{ group := &service.Group{
Name: "to-delete", Name: "to-delete",

View File

@@ -0,0 +1,50 @@
//go:build integration
package repository
import (
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (s *GroupRepoSuite) TestList_DefaultSortBySortOrderAsc() {
g1 := &service.Group{Name: "g1", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 20}
g2 := &service.Group{Name: "g2", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 10}
s.Require().NoError(s.repo.Create(s.ctx, g1))
s.Require().NoError(s.repo.Create(s.ctx, g2))
groups, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 100})
s.Require().NoError(err)
s.Require().GreaterOrEqual(len(groups), 2)
indexByID := make(map[int64]int, len(groups))
for i, g := range groups {
indexByID[g.ID] = i
}
s.Require().Contains(indexByID, g1.ID)
s.Require().Contains(indexByID, g2.ID)
// g2 has SortOrder=10, g1 has SortOrder=20; ascending means g2 comes first
s.Require().Less(indexByID[g2.ID], indexByID[g1.ID])
}
func (s *GroupRepoSuite) TestList_SortBySortOrderDesc() {
g1 := &service.Group{Name: "g1", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 40}
g2 := &service.Group{Name: "g2", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 50}
s.Require().NoError(s.repo.Create(s.ctx, g1))
s.Require().NoError(s.repo.Create(s.ctx, g2))
groups, _, err := s.repo.List(s.ctx, pagination.PaginationParams{
Page: 1,
PageSize: 10,
SortBy: "sort_order",
SortOrder: "desc",
})
s.Require().NoError(err)
s.Require().GreaterOrEqual(len(groups), 2)
indexByID := make(map[int64]int, len(groups))
for i, group := range groups {
indexByID[group.ID] = i
}
s.Require().Contains(indexByID, g1.ID)
s.Require().Contains(indexByID, g2.ID)
s.Require().Less(indexByID[g2.ID], indexByID[g1.ID])
}

View File

@@ -14,3 +14,22 @@ func paginationResultFromTotal(total int64, params pagination.PaginationParams)
Pages: pages, Pages: pages,
} }
} }
func paginateSlice[T any](items []T, params pagination.PaginationParams) []T {
if len(items) == 0 {
return []T{}
}
offset := params.Offset()
if offset >= len(items) {
return []T{}
}
limit := params.Limit()
end := offset + limit
if end > len(items) {
end = len(items)
}
return items[offset:end]
}

View File

@@ -2,12 +2,15 @@ package repository
import ( import (
"context" "context"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
entsql "entgo.io/ent/dialect/sql"
) )
type promoCodeRepository struct { type promoCodeRepository struct {
@@ -137,11 +140,14 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina
return nil, nil, err return nil, nil, err
} }
codes, err := q. codesQuery := q.
Offset(params.Offset()). Offset(params.Offset()).
Limit(params.Limit()). Limit(params.Limit())
Order(dbent.Desc(promocode.FieldID)). for _, order := range promoCodeListOrder(params) {
All(ctx) codesQuery = codesQuery.Order(order)
}
codes, err := codesQuery.All(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -151,6 +157,32 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina
return outCodes, paginationResultFromTotal(int64(total), params), nil return outCodes, paginationResultFromTotal(int64(total), params), nil
} }
func promoCodeListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
var field string
switch sortBy {
case "bonus_amount":
field = promocode.FieldBonusAmount
case "status":
field = promocode.FieldStatus
case "expires_at":
field = promocode.FieldExpiresAt
case "created_at":
field = promocode.FieldCreatedAt
case "code":
field = promocode.FieldCode
default:
field = promocode.FieldID
}
if sortOrder == pagination.SortOrderAsc {
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(promocode.FieldID)}
}
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(promocode.FieldID)}
}
func (r *promoCodeRepository) CreateUsage(ctx context.Context, usage *service.PromoCodeUsage) error { func (r *promoCodeRepository) CreateUsage(ctx context.Context, usage *service.PromoCodeUsage) error {
client := clientFromContext(ctx, r.client) client := clientFromContext(ctx, r.client)
created, err := client.PromoCodeUsage.Create(). created, err := client.PromoCodeUsage.Create().

View File

@@ -3,12 +3,16 @@ package repository
import ( import (
"context" "context"
"database/sql" "database/sql"
"sort"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/proxy" "github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
entsql "entgo.io/ent/dialect/sql"
) )
type sqlQuerier interface { type sqlQuerier interface {
@@ -135,11 +139,14 @@ func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination
return nil, nil, err return nil, nil, err
} }
proxies, err := q. proxiesQuery := q.
Offset(params.Offset()). Offset(params.Offset()).
Limit(params.Limit()). Limit(params.Limit())
Order(dbent.Desc(proxy.FieldID)). for _, order := range proxyListOrder(params) {
All(ctx) proxiesQuery = proxiesQuery.Order(order)
}
proxies, err := proxiesQuery.All(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -170,22 +177,58 @@ func (r *proxyRepository) ListWithFiltersAndAccountCount(ctx context.Context, pa
return nil, nil, err return nil, nil, err
} }
proxies, err := q. if strings.EqualFold(strings.TrimSpace(params.SortBy), "account_count") {
return r.listWithAccountCountSort(ctx, q, params, total)
}
proxiesQuery := q.
Offset(params.Offset()). Offset(params.Offset()).
Limit(params.Limit()). Limit(params.Limit())
for _, order := range proxyListOrder(params) {
proxiesQuery = proxiesQuery.Order(order)
}
proxies, err := proxiesQuery.All(ctx)
if err != nil {
return nil, nil, err
}
return r.buildProxyWithAccountCountResult(ctx, proxies, params, int64(total))
}
func (r *proxyRepository) listWithAccountCountSort(ctx context.Context, q *dbent.ProxyQuery, params pagination.PaginationParams, total int) ([]service.ProxyWithAccountCount, *pagination.PaginationResult, error) {
proxies, err := q.
Order(dbent.Desc(proxy.FieldID)). Order(dbent.Desc(proxy.FieldID)).
All(ctx) All(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
// Get account counts result, _, err := r.buildProxyWithAccountCountResult(ctx, proxies, params, int64(total))
if err != nil {
return nil, nil, err
}
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
sort.SliceStable(result, func(i, j int) bool {
if result[i].AccountCount == result[j].AccountCount {
return result[i].ID > result[j].ID
}
if sortOrder == pagination.SortOrderAsc {
return result[i].AccountCount < result[j].AccountCount
}
return result[i].AccountCount > result[j].AccountCount
})
return paginateSlice(result, params), paginationResultFromTotal(int64(total), params), nil
}
func (r *proxyRepository) buildProxyWithAccountCountResult(ctx context.Context, proxies []*dbent.Proxy, params pagination.PaginationParams, total int64) ([]service.ProxyWithAccountCount, *pagination.PaginationResult, error) {
counts, err := r.GetAccountCountsForProxies(ctx) counts, err := r.GetAccountCountsForProxies(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
// Build result with account counts
result := make([]service.ProxyWithAccountCount, 0, len(proxies)) result := make([]service.ProxyWithAccountCount, 0, len(proxies))
for i := range proxies { for i := range proxies {
proxyOut := proxyEntityToService(proxies[i]) proxyOut := proxyEntityToService(proxies[i])
@@ -198,7 +241,31 @@ func (r *proxyRepository) ListWithFiltersAndAccountCount(ctx context.Context, pa
}) })
} }
return result, paginationResultFromTotal(int64(total), params), nil return result, paginationResultFromTotal(total, params), nil
}
func proxyListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
var field string
switch sortBy {
case "name":
field = proxy.FieldName
case "protocol":
field = proxy.FieldProtocol
case "status":
field = proxy.FieldStatus
case "created_at":
field = proxy.FieldCreatedAt
default:
field = proxy.FieldID
}
if sortOrder == pagination.SortOrderAsc {
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(proxy.FieldID)}
}
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(proxy.FieldID)}
} }
func (r *proxyRepository) ListActive(ctx context.Context) ([]service.Proxy, error) { func (r *proxyRepository) ListActive(ctx context.Context) ([]service.Proxy, error) {

View File

@@ -0,0 +1,28 @@
//go:build integration
package repository
import (
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (s *ProxyRepoSuite) TestListWithFiltersAndAccountCount_SortByAccountCountDesc() {
p1 := s.mustCreateProxy(&service.Proxy{Name: "p1", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
p2 := s.mustCreateProxy(&service.Proxy{Name: "p2", Protocol: "http", Host: "127.0.0.1", Port: 8081, Status: service.StatusActive})
s.mustInsertAccount("a1", &p1.ID)
s.mustInsertAccount("a2", &p1.ID)
s.mustInsertAccount("a3", &p2.ID)
proxies, _, err := s.repo.ListWithFiltersAndAccountCount(s.ctx, pagination.PaginationParams{
Page: 1,
PageSize: 10,
SortBy: "account_count",
SortOrder: "desc",
}, "", "", "")
s.Require().NoError(err)
s.Require().Len(proxies, 2)
s.Require().Equal(p1.ID, proxies[0].ID)
s.Require().Equal(int64(2), proxies[0].AccountCount)
s.Require().Equal(p2.ID, proxies[1].ID)
}

View File

@@ -2,6 +2,7 @@ package repository
import ( import (
"context" "context"
"strings"
"time" "time"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -9,6 +10,8 @@ import (
"github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
entsql "entgo.io/ent/dialect/sql"
) )
type redeemCodeRepository struct { type redeemCodeRepository struct {
@@ -120,13 +123,16 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin
return nil, nil, err return nil, nil, err
} }
codes, err := q. codesQuery := q.
WithUser(). WithUser().
WithGroup(). WithGroup().
Offset(params.Offset()). Offset(params.Offset()).
Limit(params.Limit()). Limit(params.Limit())
Order(dbent.Desc(redeemcode.FieldID)). for _, order := range redeemCodeListOrder(params) {
All(ctx) codesQuery = codesQuery.Order(order)
}
codes, err := codesQuery.All(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -136,6 +142,34 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin
return outCodes, paginationResultFromTotal(int64(total), params), nil return outCodes, paginationResultFromTotal(int64(total), params), nil
} }
func redeemCodeListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
var field string
switch sortBy {
case "type":
field = redeemcode.FieldType
case "value":
field = redeemcode.FieldValue
case "status":
field = redeemcode.FieldStatus
case "used_at":
field = redeemcode.FieldUsedAt
case "created_at":
field = redeemcode.FieldCreatedAt
case "code":
field = redeemcode.FieldCode
default:
field = redeemcode.FieldID
}
if sortOrder == pagination.SortOrderAsc {
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(redeemcode.FieldID)}
}
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(redeemcode.FieldID)}
}
func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemCode) error { func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemCode) error {
up := r.client.RedeemCode.UpdateOneID(code.ID). up := r.client.RedeemCode.UpdateOneID(code.ID).
SetCode(code.Code). SetCode(code.Code).

View File

@@ -0,0 +1,24 @@
//go:build integration
package repository
import (
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (s *RedeemCodeRepoSuite) TestListWithFilters_SortByValueAsc() {
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "VALUE-20", Type: service.RedeemTypeBalance, Value: 20, Status: service.StatusUnused}))
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "VALUE-10", Type: service.RedeemTypeBalance, Value: 10, Status: service.StatusUnused}))
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
Page: 1,
PageSize: 10,
SortBy: "value",
SortOrder: "asc",
}, "", "", "")
s.Require().NoError(err)
s.Require().Len(codes, 2)
s.Require().Equal("VALUE-10", codes[0].Code)
s.Require().Equal("VALUE-20", codes[1].Code)
}

View File

@@ -3771,7 +3771,7 @@ func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, wh
limitPos := len(args) + 1 limitPos := len(args) + 1
offsetPos := len(args) + 2 offsetPos := len(args) + 2
listArgs := append(append([]any{}, args...), params.Limit(), params.Offset()) listArgs := append(append([]any{}, args...), params.Limit(), params.Offset())
query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY id DESC LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, limitPos, offsetPos) query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY %s LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, usageLogOrderBy(params), limitPos, offsetPos)
logs, err := r.queryUsageLogs(ctx, query, listArgs...) logs, err := r.queryUsageLogs(ctx, query, listArgs...)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@@ -3786,7 +3786,7 @@ func (r *usageLogRepository) listUsageLogsWithFastPagination(ctx context.Context
limitPos := len(args) + 1 limitPos := len(args) + 1
offsetPos := len(args) + 2 offsetPos := len(args) + 2
listArgs := append(append([]any{}, args...), limit+1, offset) listArgs := append(append([]any{}, args...), limit+1, offset)
query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY id DESC LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, limitPos, offsetPos) query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY %s LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, usageLogOrderBy(params), limitPos, offsetPos)
logs, err := r.queryUsageLogs(ctx, query, listArgs...) logs, err := r.queryUsageLogs(ctx, query, listArgs...)
if err != nil { if err != nil {
@@ -3808,6 +3808,26 @@ func (r *usageLogRepository) listUsageLogsWithFastPagination(ctx context.Context
return logs, paginationResultFromTotal(total, params), nil return logs, paginationResultFromTotal(total, params), nil
} }
func usageLogOrderBy(params pagination.PaginationParams) string {
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := strings.ToUpper(params.NormalizedSortOrder(pagination.SortOrderDesc))
var column string
switch sortBy {
case "model":
column = "COALESCE(NULLIF(TRIM(requested_model), ''), model)"
case "created_at":
column = "created_at"
default:
column = "id"
}
if column == "id" {
return fmt.Sprintf("id %s", sortOrder)
}
return fmt.Sprintf("%s %s, id %s", column, sortOrder, sortOrder)
}
func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) (logs []service.UsageLog, err error) { func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) (logs []service.UsageLog, err error) {
rows, err := r.sql.QueryContext(ctx, query, args...) rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil { if err != nil {

View File

@@ -330,6 +330,15 @@ func TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority(t *testing.T)
"total_account_cost", "total_account_cost",
"avg_duration_ms", "avg_duration_ms",
}).AddRow(int64(1), int64(2), int64(3), int64(4), 1.2, 1.0, 1.2, 20.0)) }).AddRow(int64(1), int64(2), int64(3), int64(4), 1.2, 1.0, 1.2, 20.0))
mock.ExpectQuery("SELECT COALESCE\\(NULLIF\\(TRIM\\(inbound_endpoint\\), ''\\), 'unknown'\\) AS endpoint").
WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), requestType).
WillReturnRows(sqlmock.NewRows([]string{"endpoint", "requests", "total_tokens", "cost", "actual_cost"}))
mock.ExpectQuery("SELECT COALESCE\\(NULLIF\\(TRIM\\(upstream_endpoint\\), ''\\), 'unknown'\\) AS endpoint").
WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), requestType).
WillReturnRows(sqlmock.NewRows([]string{"endpoint", "requests", "total_tokens", "cost", "actual_cost"}))
mock.ExpectQuery("SELECT CONCAT\\(").
WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), requestType).
WillReturnRows(sqlmock.NewRows([]string{"endpoint", "requests", "total_tokens", "cost", "actual_cost"}))
stats, err := repo.GetStatsWithFilters(context.Background(), filters) stats, err := repo.GetStatsWithFilters(context.Background(), filters)
require.NoError(t, err) require.NoError(t, err)

View File

@@ -0,0 +1,61 @@
//go:build integration
package repository
import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/google/uuid"
)
func (s *UsageLogRepoSuite) TestListWithFilters_SortByModelAsc() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "usage-sort@example.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-usage-sort", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "usage-sort-account"})
first := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.New().String(),
Model: "z-model",
RequestedModel: "z-model",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now(),
}
_, err := s.repo.Create(s.ctx, first)
s.Require().NoError(err)
second := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.New().String(),
Model: "a-model",
RequestedModel: "a-model",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().Add(time.Second),
}
_, err = s.repo.Create(s.ctx, second)
s.Require().NoError(err)
logs, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
Page: 1,
PageSize: 10,
SortBy: "model",
SortOrder: "asc",
}, usagestats.UsageLogFilters{UserID: user.ID})
s.Require().NoError(err)
s.Require().Len(logs, 2)
s.Require().Equal("a-model", logs[0].RequestedModel)
s.Require().Equal("z-model", logs[1].RequestedModel)
}

View File

@@ -17,6 +17,8 @@ import (
"github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
entsql "entgo.io/ent/dialect/sql"
) )
type userRepository struct { type userRepository struct {
@@ -224,11 +226,14 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
return nil, nil, err return nil, nil, err
} }
users, err := q. usersQuery := q.
Offset(params.Offset()). Offset(params.Offset()).
Limit(params.Limit()). Limit(params.Limit())
Order(dbent.Desc(dbuser.FieldID)). for _, order := range userListOrder(params) {
All(ctx) usersQuery = usersQuery.Order(order)
}
users, err := usersQuery.All(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -281,6 +286,50 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
return outUsers, paginationResultFromTotal(int64(total), params), nil return outUsers, paginationResultFromTotal(int64(total), params), nil
} }
func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
var field string
defaultField := true
switch sortBy {
case "email":
field = dbuser.FieldEmail
defaultField = false
case "username":
field = dbuser.FieldUsername
defaultField = false
case "role":
field = dbuser.FieldRole
defaultField = false
case "balance":
field = dbuser.FieldBalance
defaultField = false
case "concurrency":
field = dbuser.FieldConcurrency
defaultField = false
case "status":
field = dbuser.FieldStatus
defaultField = false
case "created_at":
field = dbuser.FieldCreatedAt
defaultField = false
default:
field = dbuser.FieldID
}
if sortOrder == pagination.SortOrderAsc {
if defaultField && field == dbuser.FieldID {
return []func(*entsql.Selector){dbent.Asc(dbuser.FieldID)}
}
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(dbuser.FieldID)}
}
if defaultField && field == dbuser.FieldID {
return []func(*entsql.Selector){dbent.Desc(dbuser.FieldID)}
}
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(dbuser.FieldID)}
}
// filterUsersByAttributes returns user IDs that match ALL the given attribute filters // filterUsersByAttributes returns user IDs that match ALL the given attribute filters
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) ([]int64, error) { func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) ([]int64, error) {
if len(attrs) == 0 { if len(attrs) == 0 {

View File

@@ -0,0 +1,39 @@
//go:build integration
package repository
import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (s *UserRepoSuite) TestListWithFilters_SortByEmailAsc() {
s.mustCreateUser(&service.User{Email: "z-last@example.com", Username: "z-user"})
s.mustCreateUser(&service.User{Email: "a-first@example.com", Username: "a-user"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
Page: 1,
PageSize: 10,
SortBy: "email",
SortOrder: "asc",
}, service.UserListFilters{})
s.Require().NoError(err)
s.Require().Len(users, 2)
s.Require().Equal("a-first@example.com", users[0].Email)
s.Require().Equal("z-last@example.com", users[1].Email)
}
func (s *UserRepoSuite) TestList_DefaultSortByNewestFirst() {
first := s.mustCreateUser(&service.User{Email: "first@example.com"})
second := s.mustCreateUser(&service.User{Email: "second@example.com"})
users, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err)
s.Require().Len(users, 2)
s.Require().Equal(second.ID, users[0].ID)
s.Require().Equal(first.ID, users[1].ID)
}
func TestUserRepoSortSuiteSmoke(_ *testing.T) {}

View File

@@ -491,8 +491,10 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyContactInfo: "support", service.SettingKeyContactInfo: "support",
service.SettingKeyDocURL: "https://docs.example.com", service.SettingKeyDocURL: "https://docs.example.com",
service.SettingKeyDefaultConcurrency: "5", service.SettingKeyDefaultConcurrency: "5",
service.SettingKeyDefaultBalance: "1.25", service.SettingKeyDefaultBalance: "1.25",
service.SettingKeyTableDefaultPageSize: "20",
service.SettingKeyTablePageSizeOptions: "[10,20,50,100]",
service.SettingKeyOpsMonitoringEnabled: "false", service.SettingKeyOpsMonitoringEnabled: "false",
service.SettingKeyOpsRealtimeMonitoringEnabled: "true", service.SettingKeyOpsRealtimeMonitoringEnabled: "true",
@@ -576,6 +578,8 @@ func TestAPIContracts(t *testing.T) {
"hide_ccs_import_button": false, "hide_ccs_import_button": false,
"purchase_subscription_enabled": false, "purchase_subscription_enabled": false,
"purchase_subscription_url": "", "purchase_subscription_url": "",
"table_default_page_size": 20,
"table_page_size_options": [10, 20, 50, 100],
"min_claude_code_version": "", "min_claude_code_version": "",
"max_claude_code_version": "", "max_claude_code_version": "",
"allow_ungrouped_key_scheduling": false, "allow_ungrouped_key_scheduling": false,

View File

@@ -21,13 +21,13 @@ import (
// AdminService interface defines admin management operations // AdminService interface defines admin management operations
type AdminService interface { type AdminService interface {
// User management // User management
ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters, sortBy, sortOrder string) ([]User, int64, error)
GetUser(ctx context.Context, id int64) (*User, error) GetUser(ctx context.Context, id int64) (*User, error)
CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error)
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
DeleteUser(ctx context.Context, id int64) error DeleteUser(ctx context.Context, id int64) error
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user. // GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
// codeType is optional - pass empty string to return all types. // codeType is optional - pass empty string to return all types.
@@ -35,7 +35,7 @@ type AdminService interface {
GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error)
// Group management // Group management
ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error)
GetAllGroups(ctx context.Context) ([]Group, error) GetAllGroups(ctx context.Context) ([]Group, error)
GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error)
GetGroup(ctx context.Context, id int64) (*Group, error) GetGroup(ctx context.Context, id int64) (*Group, error)
@@ -55,7 +55,7 @@ type AdminService interface {
ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error)
// Account management // Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, int64, error) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, sortBy, sortOrder string) ([]Account, int64, error)
GetAccount(ctx context.Context, id int64) (*Account, error) GetAccount(ctx context.Context, id int64) (*Account, error)
GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error)
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
@@ -77,8 +77,8 @@ type AdminService interface {
CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error
// Proxy management // Proxy management
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]Proxy, int64, error)
ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]ProxyWithAccountCount, int64, error)
GetAllProxies(ctx context.Context) ([]Proxy, error) GetAllProxies(ctx context.Context) ([]Proxy, error)
GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error)
GetProxy(ctx context.Context, id int64) (*Proxy, error) GetProxy(ctx context.Context, id int64) (*Proxy, error)
@@ -93,7 +93,7 @@ type AdminService interface {
CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error) CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error)
// Redeem code management // Redeem code management
ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string, sortBy, sortOrder string) ([]RedeemCode, int64, error)
GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error)
DeleteRedeemCode(ctx context.Context, id int64) error DeleteRedeemCode(ctx context.Context, id int64) error
@@ -485,8 +485,8 @@ func NewAdminService(
} }
// User management implementations // User management implementations
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error) { func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters, sortBy, sortOrder string) ([]User, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
users, result, err := s.userRepo.ListWithFilters(ctx, params, filters) users, result, err := s.userRepo.ListWithFilters(ctx, params, filters)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
@@ -753,8 +753,8 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
return user, nil return user, nil
} }
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) { func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, APIKeyListFilters{}) keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, APIKeyListFilters{})
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
@@ -789,8 +789,8 @@ func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int
} }
// Group management implementations // Group management implementations
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) { func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, search, isExclusive) groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, search, isExclusive)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
@@ -1464,8 +1464,8 @@ func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGrou
} }
// Account management implementations // Account management implementations
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, int64, error) { func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, sortBy, sortOrder string) ([]Account, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID, privacyMode) accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID, privacyMode)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
@@ -1893,8 +1893,8 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64,
} }
// Proxy management implementations // Proxy management implementations
func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) { func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]Proxy, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search) proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
@@ -1902,8 +1902,8 @@ func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int,
return proxies, result.Total, nil return proxies, result.Total, nil
} }
func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error) { func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]ProxyWithAccountCount, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
proxies, result, err := s.proxyRepo.ListWithFiltersAndAccountCount(ctx, params, protocol, status, search) proxies, result, err := s.proxyRepo.ListWithFiltersAndAccountCount(ctx, params, protocol, status, search)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
@@ -2040,8 +2040,8 @@ func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, po
} }
// Redeem code management implementations // Redeem code management implementations
func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) { func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string, sortBy, sortOrder string) ([]RedeemCode, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search) codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err

View File

@@ -125,6 +125,22 @@ func (s *groupRepoStubForAdmin) UpdateSortOrders(_ context.Context, _ []GroupSor
return nil return nil
} }
func TestAdminService_ListGroups_PassesSortParams(t *testing.T) {
repo := &groupRepoStubForAdmin{
listWithFiltersGroups: []Group{{ID: 1, Name: "g1"}},
}
svc := &adminServiceImpl{groupRepo: repo}
_, _, err := svc.ListGroups(context.Background(), 3, 25, PlatformOpenAI, StatusActive, "needle", nil, "account_count", "ASC")
require.NoError(t, err)
require.Equal(t, pagination.PaginationParams{
Page: 3,
PageSize: 25,
SortBy: "account_count",
SortOrder: "ASC",
}, repo.listWithFiltersParams)
}
// TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递 // TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递
func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) { func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) {
repo := &groupRepoStubForAdmin{} repo := &groupRepoStubForAdmin{}
@@ -373,7 +389,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
} }
svc := &adminServiceImpl{groupRepo: repo} svc := &adminServiceImpl{groupRepo: repo}
groups, total, err := svc.ListGroups(context.Background(), 1, 20, "", "", "alpha", nil) groups, total, err := svc.ListGroups(context.Background(), 1, 20, "", "", "alpha", nil, "", "")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(1), total) require.Equal(t, int64(1), total)
require.Equal(t, []Group{{ID: 1, Name: "alpha"}}, groups) require.Equal(t, []Group{{ID: 1, Name: "alpha"}}, groups)
@@ -391,7 +407,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
} }
svc := &adminServiceImpl{groupRepo: repo} svc := &adminServiceImpl{groupRepo: repo}
groups, total, err := svc.ListGroups(context.Background(), 2, 10, "", "", "", nil) groups, total, err := svc.ListGroups(context.Background(), 2, 10, "", "", "", nil, "", "")
require.NoError(t, err) require.NoError(t, err)
require.Empty(t, groups) require.Empty(t, groups)
require.Equal(t, int64(0), total) require.Equal(t, int64(0), total)
@@ -410,7 +426,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
} }
svc := &adminServiceImpl{groupRepo: repo} svc := &adminServiceImpl{groupRepo: repo}
groups, total, err := svc.ListGroups(context.Background(), 3, 50, PlatformAntigravity, StatusActive, "beta", &isExclusive) groups, total, err := svc.ListGroups(context.Background(), 3, 50, PlatformAntigravity, StatusActive, "beta", &isExclusive, "", "")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(42), total) require.Equal(t, int64(42), total)
require.Equal(t, []Group{{ID: 2, Name: "beta"}}, groups) require.Equal(t, []Group{{ID: 2, Name: "beta"}}, groups)

View File

@@ -13,11 +13,13 @@ import (
type userRepoStubForListUsers struct { type userRepoStubForListUsers struct {
userRepoStub userRepoStub
users []User users []User
err error err error
listWithFiltersParams pagination.PaginationParams
} }
func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pagination.PaginationParams, _ UserListFilters) ([]User, *pagination.PaginationResult, error) { func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pagination.PaginationParams, _ UserListFilters) ([]User, *pagination.PaginationResult, error) {
s.listWithFiltersParams = params
if s.err != nil { if s.err != nil {
return nil, nil, s.err return nil, nil, s.err
} }
@@ -103,7 +105,7 @@ func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) {
userGroupRateRepo: rateRepo, userGroupRateRepo: rateRepo,
} }
users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{}) users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{}, "", "")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(2), total) require.Equal(t, int64(2), total)
require.Len(t, users, 2) require.Len(t, users, 2)
@@ -112,3 +114,19 @@ func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) {
require.Equal(t, 1.1, users[0].GroupRates[11]) require.Equal(t, 1.1, users[0].GroupRates[11])
require.Equal(t, 2.2, users[1].GroupRates[22]) require.Equal(t, 2.2, users[1].GroupRates[22])
} }
func TestAdminService_ListUsers_PassesSortParams(t *testing.T) {
userRepo := &userRepoStubForListUsers{
users: []User{{ID: 1, Email: "a@example.com"}},
}
svc := &adminServiceImpl{userRepo: userRepo}
_, _, err := svc.ListUsers(context.Background(), 2, 50, UserListFilters{}, "email", "ASC")
require.NoError(t, err)
require.Equal(t, pagination.PaginationParams{
Page: 2,
PageSize: 50,
SortBy: "email",
SortOrder: "ASC",
}, userRepo.listWithFiltersParams)
}

View File

@@ -170,13 +170,13 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
} }
svc := &adminServiceImpl{accountRepo: repo} svc := &adminServiceImpl{accountRepo: repo}
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0, "") accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0, "", "name", "ASC")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(10), total) require.Equal(t, int64(10), total)
require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts) require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts)
require.Equal(t, 1, repo.listWithFiltersCalls) require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams) require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20, SortBy: "name", SortOrder: "ASC"}, repo.listWithFiltersParams)
require.Equal(t, PlatformGemini, repo.listWithFiltersPlatform) require.Equal(t, PlatformGemini, repo.listWithFiltersPlatform)
require.Equal(t, AccountTypeOAuth, repo.listWithFiltersType) require.Equal(t, AccountTypeOAuth, repo.listWithFiltersType)
require.Equal(t, StatusActive, repo.listWithFiltersStatus) require.Equal(t, StatusActive, repo.listWithFiltersStatus)
@@ -192,7 +192,7 @@ func TestAdminService_ListAccounts_WithPrivacyMode(t *testing.T) {
} }
svc := &adminServiceImpl{accountRepo: repo} svc := &adminServiceImpl{accountRepo: repo}
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, StatusActive, "acc2", 0, PrivacyModeCFBlocked) accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, StatusActive, "acc2", 0, PrivacyModeCFBlocked, "", "")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(1), total) require.Equal(t, int64(1), total)
require.Equal(t, []Account{{ID: 2, Name: "acc2"}}, accounts) require.Equal(t, []Account{{ID: 2, Name: "acc2"}}, accounts)
@@ -208,13 +208,13 @@ func TestAdminService_ListProxies_WithSearch(t *testing.T) {
} }
svc := &adminServiceImpl{proxyRepo: repo} svc := &adminServiceImpl{proxyRepo: repo}
proxies, total, err := svc.ListProxies(context.Background(), 3, 50, "http", StatusActive, "p1") proxies, total, err := svc.ListProxies(context.Background(), 3, 50, "http", StatusActive, "p1", "name", "ASC")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(7), total) require.Equal(t, int64(7), total)
require.Equal(t, []Proxy{{ID: 2, Name: "p1"}}, proxies) require.Equal(t, []Proxy{{ID: 2, Name: "p1"}}, proxies)
require.Equal(t, 1, repo.listWithFiltersCalls) require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams) require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50, SortBy: "name", SortOrder: "ASC"}, repo.listWithFiltersParams)
require.Equal(t, "http", repo.listWithFiltersProtocol) require.Equal(t, "http", repo.listWithFiltersProtocol)
require.Equal(t, StatusActive, repo.listWithFiltersStatus) require.Equal(t, StatusActive, repo.listWithFiltersStatus)
require.Equal(t, "p1", repo.listWithFiltersSearch) require.Equal(t, "p1", repo.listWithFiltersSearch)
@@ -229,13 +229,13 @@ func TestAdminService_ListProxiesWithAccountCount_WithSearch(t *testing.T) {
} }
svc := &adminServiceImpl{proxyRepo: repo} svc := &adminServiceImpl{proxyRepo: repo}
proxies, total, err := svc.ListProxiesWithAccountCount(context.Background(), 2, 10, "socks5", StatusDisabled, "p2") proxies, total, err := svc.ListProxiesWithAccountCount(context.Background(), 2, 10, "socks5", StatusDisabled, "p2", "account_count", "DESC")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(9), total) require.Equal(t, int64(9), total)
require.Equal(t, []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}}, proxies) require.Equal(t, []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}}, proxies)
require.Equal(t, 1, repo.listWithFiltersAndAccountCountCalls) require.Equal(t, 1, repo.listWithFiltersAndAccountCountCalls)
require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersAndAccountCountParams) require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10, SortBy: "account_count", SortOrder: "DESC"}, repo.listWithFiltersAndAccountCountParams)
require.Equal(t, "socks5", repo.listWithFiltersAndAccountCountProtocol) require.Equal(t, "socks5", repo.listWithFiltersAndAccountCountProtocol)
require.Equal(t, StatusDisabled, repo.listWithFiltersAndAccountCountStatus) require.Equal(t, StatusDisabled, repo.listWithFiltersAndAccountCountStatus)
require.Equal(t, "p2", repo.listWithFiltersAndAccountCountSearch) require.Equal(t, "p2", repo.listWithFiltersAndAccountCountSearch)
@@ -250,13 +250,13 @@ func TestAdminService_ListRedeemCodes_WithSearch(t *testing.T) {
} }
svc := &adminServiceImpl{redeemCodeRepo: repo} svc := &adminServiceImpl{redeemCodeRepo: repo}
codes, total, err := svc.ListRedeemCodes(context.Background(), 1, 20, RedeemTypeBalance, StatusUnused, "ABC") codes, total, err := svc.ListRedeemCodes(context.Background(), 1, 20, RedeemTypeBalance, StatusUnused, "ABC", "value", "ASC")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(3), total) require.Equal(t, int64(3), total)
require.Equal(t, []RedeemCode{{ID: 4, Code: "ABC"}}, codes) require.Equal(t, []RedeemCode{{ID: 4, Code: "ABC"}}, codes)
require.Equal(t, 1, repo.listWithFiltersCalls) require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams) require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20, SortBy: "value", SortOrder: "ASC"}, repo.listWithFiltersParams)
require.Equal(t, RedeemTypeBalance, repo.listWithFiltersType) require.Equal(t, RedeemTypeBalance, repo.listWithFiltersType)
require.Equal(t, StatusUnused, repo.listWithFiltersStatus) require.Equal(t, StatusUnused, repo.listWithFiltersStatus)
require.Equal(t, "ABC", repo.listWithFiltersSearch) require.Equal(t, "ABC", repo.listWithFiltersSearch)

View File

@@ -4,6 +4,7 @@ import "time"
// APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段) // APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段)
type APIKeyAuthSnapshot struct { type APIKeyAuthSnapshot struct {
Version int `json:"version"`
APIKeyID int64 `json:"api_key_id"` APIKeyID int64 `json:"api_key_id"`
UserID int64 `json:"user_id"` UserID int64 `json:"user_id"`
GroupID *int64 `json:"group_id,omitempty"` GroupID *int64 `json:"group_id,omitempty"`
@@ -63,8 +64,9 @@ type APIKeyAuthGroupSnapshot struct {
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"` SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
// OpenAI Messages 调度配置(仅 openai 平台使用) // OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch bool `json:"allow_messages_dispatch"` AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
DefaultMappedModel string `json:"default_mapped_model,omitempty"` DefaultMappedModel string `json:"default_mapped_model,omitempty"`
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"`
} }
// APIKeyAuthCacheEntry 缓存条目,支持负缓存 // APIKeyAuthCacheEntry 缓存条目,支持负缓存

View File

@@ -13,6 +13,8 @@ import (
"github.com/dgraph-io/ristretto" "github.com/dgraph-io/ristretto"
) )
const apiKeyAuthSnapshotVersion = 3
type apiKeyAuthCacheConfig struct { type apiKeyAuthCacheConfig struct {
l1Size int l1Size int
l1TTL time.Duration l1TTL time.Duration
@@ -192,6 +194,9 @@ func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEn
if entry.Snapshot == nil { if entry.Snapshot == nil {
return nil, false, nil return nil, false, nil
} }
if entry.Snapshot.Version != apiKeyAuthSnapshotVersion {
return nil, false, nil
}
return s.snapshotToAPIKey(key, entry.Snapshot), true, nil return s.snapshotToAPIKey(key, entry.Snapshot), true, nil
} }
@@ -200,6 +205,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
return nil return nil
} }
snapshot := &APIKeyAuthSnapshot{ snapshot := &APIKeyAuthSnapshot{
Version: apiKeyAuthSnapshotVersion,
APIKeyID: apiKey.ID, APIKeyID: apiKey.ID,
UserID: apiKey.UserID, UserID: apiKey.UserID,
GroupID: apiKey.GroupID, GroupID: apiKey.GroupID,
@@ -243,6 +249,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
SupportedModelScopes: apiKey.Group.SupportedModelScopes, SupportedModelScopes: apiKey.Group.SupportedModelScopes,
AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch, AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch,
DefaultMappedModel: apiKey.Group.DefaultMappedModel, DefaultMappedModel: apiKey.Group.DefaultMappedModel,
MessagesDispatchModelConfig: apiKey.Group.MessagesDispatchModelConfig,
} }
} }
return snapshot return snapshot
@@ -298,6 +305,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
SupportedModelScopes: snapshot.Group.SupportedModelScopes, SupportedModelScopes: snapshot.Group.SupportedModelScopes,
AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch, AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch,
DefaultMappedModel: snapshot.Group.DefaultMappedModel, DefaultMappedModel: snapshot.Group.DefaultMappedModel,
MessagesDispatchModelConfig: snapshot.Group.MessagesDispatchModelConfig,
} }
} }
s.compileAPIKeyIPRules(apiKey) s.compileAPIKeyIPRules(apiKey)

View File

@@ -188,6 +188,7 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
groupID := int64(9) groupID := int64(9)
cacheEntry := &APIKeyAuthCacheEntry{ cacheEntry := &APIKeyAuthCacheEntry{
Snapshot: &APIKeyAuthSnapshot{ Snapshot: &APIKeyAuthSnapshot{
Version: apiKeyAuthSnapshotVersion,
APIKeyID: 1, APIKeyID: 1,
UserID: 2, UserID: 2,
GroupID: &groupID, GroupID: &groupID,
@@ -226,6 +227,129 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
require.Equal(t, map[string][]int64{"claude-opus-*": {1, 2}}, apiKey.Group.ModelRouting) require.Equal(t, map[string][]int64{"claude-opus-*": {1, 2}}, apiKey.Group.ModelRouting)
} }
func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t *testing.T) {
svc := NewAPIKeyService(nil, nil, nil, nil, nil, nil, &config.Config{})
groupID := int64(9)
apiKey := &APIKey{
ID: 1,
UserID: 2,
GroupID: &groupID,
Key: "k-roundtrip",
Status: StatusActive,
User: &User{
ID: 2,
Status: StatusActive,
Role: RoleUser,
Balance: 10,
Concurrency: 3,
},
Group: &Group{
ID: groupID,
Name: "openai",
Platform: PlatformOpenAI,
Status: StatusActive,
SubscriptionType: SubscriptionTypeStandard,
RateMultiplier: 1,
AllowMessagesDispatch: true,
DefaultMappedModel: "gpt-5.4",
MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{
OpusMappedModel: "gpt-5.4-nano",
SonnetMappedModel: "gpt-5.3-codex",
HaikuMappedModel: "gpt-5.4-mini",
ExactModelMappings: map[string]string{
"claude-sonnet-4.5": "gpt-5.4-nano",
},
},
},
}
snapshot := svc.snapshotFromAPIKey(apiKey)
roundTrip := svc.snapshotToAPIKey(apiKey.Key, snapshot)
require.NotNil(t, roundTrip)
require.NotNil(t, roundTrip.Group)
require.Equal(t, apiKey.Group.MessagesDispatchModelConfig, roundTrip.Group.MessagesDispatchModelConfig)
}
func TestAPIKeyService_GetByKey_IgnoresLegacyAuthCacheSnapshotWithoutMessagesDispatchConfig(t *testing.T) {
cache := &authCacheStub{}
var repoCalls int32
repo := &authRepoStub{
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
atomic.AddInt32(&repoCalls, 1)
groupID := int64(9)
return &APIKey{
ID: 1,
UserID: 2,
GroupID: &groupID,
Status: StatusActive,
User: &User{
ID: 2,
Status: StatusActive,
Role: RoleUser,
Balance: 10,
Concurrency: 3,
},
Group: &Group{
ID: groupID,
Name: "openai",
Platform: PlatformOpenAI,
Status: StatusActive,
Hydrated: true,
SubscriptionType: SubscriptionTypeStandard,
RateMultiplier: 1,
AllowMessagesDispatch: true,
DefaultMappedModel: "gpt-5.4",
MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{
OpusMappedModel: "gpt-5.4-nano",
},
},
}, nil
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L2TTLSeconds: 60,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
groupID := int64(9)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return &APIKeyAuthCacheEntry{
Snapshot: &APIKeyAuthSnapshot{
APIKeyID: 1,
UserID: 2,
GroupID: &groupID,
Status: StatusActive,
User: APIKeyAuthUserSnapshot{
ID: 2,
Status: StatusActive,
Role: RoleUser,
Balance: 10,
Concurrency: 3,
},
Group: &APIKeyAuthGroupSnapshot{
ID: groupID,
Name: "openai",
Platform: PlatformOpenAI,
Status: StatusActive,
SubscriptionType: SubscriptionTypeStandard,
RateMultiplier: 1,
AllowMessagesDispatch: true,
DefaultMappedModel: "gpt-5.4",
},
},
}, nil
}
apiKey, err := svc.GetByKey(context.Background(), "k-legacy")
require.NoError(t, err)
require.Equal(t, int32(1), atomic.LoadInt32(&repoCalls))
require.NotNil(t, apiKey.Group)
require.Equal(t, "gpt-5.4-nano", apiKey.Group.MessagesDispatchModelConfig.OpusMappedModel)
}
func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) { func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) {
cache := &authCacheStub{} cache := &authCacheStub{}
repo := &authRepoStub{ repo := &authRepoStub{

View File

@@ -143,6 +143,8 @@ const (
SettingKeyHideCcsImportButton = "hide_ccs_import_button" // 是否隐藏 API Keys 页面的导入 CCS 按钮 SettingKeyHideCcsImportButton = "hide_ccs_import_button" // 是否隐藏 API Keys 页面的导入 CCS 按钮
SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示"购买订阅"页面入口 SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示"购买订阅"页面入口
SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // "购买订阅"页面 URL作为 iframe src SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // "购买订阅"页面 URL作为 iframe src
SettingKeyTableDefaultPageSize = "table_default_page_size" // 表格默认每页条数
SettingKeyTablePageSizeOptions = "table_page_size_options" // 表格可选每页条数JSON 数组)
SettingKeyCustomMenuItems = "custom_menu_items" // 自定义菜单项JSON 数组) SettingKeyCustomMenuItems = "custom_menu_items" // 自定义菜单项JSON 数组)
SettingKeyCustomEndpoints = "custom_endpoints" // 自定义端点列表JSON 数组) SettingKeyCustomEndpoints = "custom_endpoints" // 自定义端点列表JSON 数组)

View File

@@ -492,7 +492,7 @@ func TestAdminService_ListAccounts_ExhaustedCodexExtraReturnsRateLimitedAccount(
} }
svc := &adminServiceImpl{accountRepo: repo} svc := &adminServiceImpl{accountRepo: repo}
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, "", "", 0, "") accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, "", "", 0, "", "", "")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(1), total) require.Equal(t, int64(1), total)
require.Len(t, accounts, 1) require.Len(t, accounts, 1)

View File

@@ -9,6 +9,7 @@ import (
"fmt" "fmt"
"log/slog" "log/slog"
"net/url" "net/url"
"sort"
"strconv" "strconv"
"strings" "strings"
"sync/atomic" "sync/atomic"
@@ -161,6 +162,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyHideCcsImportButton, SettingKeyHideCcsImportButton,
SettingKeyPurchaseSubscriptionEnabled, SettingKeyPurchaseSubscriptionEnabled,
SettingKeyPurchaseSubscriptionURL, SettingKeyPurchaseSubscriptionURL,
SettingKeyTableDefaultPageSize,
SettingKeyTablePageSizeOptions,
SettingKeyCustomMenuItems, SettingKeyCustomMenuItems,
SettingKeyCustomEndpoints, SettingKeyCustomEndpoints,
SettingKeyLinuxDoConnectEnabled, SettingKeyLinuxDoConnectEnabled,
@@ -200,6 +203,10 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
registrationEmailSuffixWhitelist := ParseRegistrationEmailSuffixWhitelist( registrationEmailSuffixWhitelist := ParseRegistrationEmailSuffixWhitelist(
settings[SettingKeyRegistrationEmailSuffixWhitelist], settings[SettingKeyRegistrationEmailSuffixWhitelist],
) )
tableDefaultPageSize, tablePageSizeOptions := parseTablePreferences(
settings[SettingKeyTableDefaultPageSize],
settings[SettingKeyTablePageSizeOptions],
)
return &PublicSettings{ return &PublicSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
@@ -221,6 +228,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true", HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true", PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
TableDefaultPageSize: tableDefaultPageSize,
TablePageSizeOptions: tablePageSizeOptions,
CustomMenuItems: settings[SettingKeyCustomMenuItems], CustomMenuItems: settings[SettingKeyCustomMenuItems],
CustomEndpoints: settings[SettingKeyCustomEndpoints], CustomEndpoints: settings[SettingKeyCustomEndpoints],
LinuxDoOAuthEnabled: linuxDoEnabled, LinuxDoOAuthEnabled: linuxDoEnabled,
@@ -270,6 +279,8 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
HideCcsImportButton bool `json:"hide_ccs_import_button"` HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"` PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"`
TableDefaultPageSize int `json:"table_default_page_size"`
TablePageSizeOptions []int `json:"table_page_size_options"`
CustomMenuItems json.RawMessage `json:"custom_menu_items"` CustomMenuItems json.RawMessage `json:"custom_menu_items"`
CustomEndpoints json.RawMessage `json:"custom_endpoints"` CustomEndpoints json.RawMessage `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
@@ -297,6 +308,8 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
HideCcsImportButton: settings.HideCcsImportButton, HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
TableDefaultPageSize: settings.TableDefaultPageSize,
TablePageSizeOptions: settings.TablePageSizeOptions,
CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems), CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints), CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
@@ -522,6 +535,16 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyHideCcsImportButton] = strconv.FormatBool(settings.HideCcsImportButton) updates[SettingKeyHideCcsImportButton] = strconv.FormatBool(settings.HideCcsImportButton)
updates[SettingKeyPurchaseSubscriptionEnabled] = strconv.FormatBool(settings.PurchaseSubscriptionEnabled) updates[SettingKeyPurchaseSubscriptionEnabled] = strconv.FormatBool(settings.PurchaseSubscriptionEnabled)
updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL) updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL)
tableDefaultPageSize, tablePageSizeOptions := normalizeTablePreferences(
settings.TableDefaultPageSize,
settings.TablePageSizeOptions,
)
updates[SettingKeyTableDefaultPageSize] = strconv.Itoa(tableDefaultPageSize)
tablePageSizeOptionsJSON, err := json.Marshal(tablePageSizeOptions)
if err != nil {
return fmt.Errorf("marshal table page size options: %w", err)
}
updates[SettingKeyTablePageSizeOptions] = string(tablePageSizeOptionsJSON)
updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems
updates[SettingKeyCustomEndpoints] = settings.CustomEndpoints updates[SettingKeyCustomEndpoints] = settings.CustomEndpoints
@@ -875,6 +898,8 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeySiteLogo: "", SettingKeySiteLogo: "",
SettingKeyPurchaseSubscriptionEnabled: "false", SettingKeyPurchaseSubscriptionEnabled: "false",
SettingKeyPurchaseSubscriptionURL: "", SettingKeyPurchaseSubscriptionURL: "",
SettingKeyTableDefaultPageSize: "20",
SettingKeyTablePageSizeOptions: "[10,20,50,100]",
SettingKeyCustomMenuItems: "[]", SettingKeyCustomMenuItems: "[]",
SettingKeyCustomEndpoints: "[]", SettingKeyCustomEndpoints: "[]",
SettingKeyOIDCConnectEnabled: "false", SettingKeyOIDCConnectEnabled: "false",
@@ -946,6 +971,10 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
CustomEndpoints: settings[SettingKeyCustomEndpoints], CustomEndpoints: settings[SettingKeyCustomEndpoints],
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true", BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
} }
result.TableDefaultPageSize, result.TablePageSizeOptions = parseTablePreferences(
settings[SettingKeyTableDefaultPageSize],
settings[SettingKeyTablePageSizeOptions],
)
// 解析整数类型 // 解析整数类型
if port, err := strconv.Atoi(settings[SettingKeySMTPPort]); err == nil { if port, err := strconv.Atoi(settings[SettingKeySMTPPort]); err == nil {
@@ -1221,6 +1250,50 @@ func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting {
return normalized return normalized
} }
func parseTablePreferences(defaultPageSizeRaw, optionsRaw string) (int, []int) {
defaultPageSize := 20
if v, err := strconv.Atoi(strings.TrimSpace(defaultPageSizeRaw)); err == nil {
defaultPageSize = v
}
var options []int
if strings.TrimSpace(optionsRaw) != "" {
_ = json.Unmarshal([]byte(optionsRaw), &options)
}
return normalizeTablePreferences(defaultPageSize, options)
}
func normalizeTablePreferences(defaultPageSize int, options []int) (int, []int) {
const minPageSize = 5
const maxPageSize = 1000
const fallbackPageSize = 20
seen := make(map[int]struct{}, len(options))
normalizedOptions := make([]int, 0, len(options))
for _, option := range options {
if option < minPageSize || option > maxPageSize {
continue
}
if _, ok := seen[option]; ok {
continue
}
seen[option] = struct{}{}
normalizedOptions = append(normalizedOptions, option)
}
sort.Ints(normalizedOptions)
if defaultPageSize < minPageSize || defaultPageSize > maxPageSize {
defaultPageSize = fallbackPageSize
}
if len(normalizedOptions) == 0 {
normalizedOptions = []int{10, 20, 50}
}
return defaultPageSize, normalizedOptions
}
// getStringOrDefault 获取字符串值或默认值 // getStringOrDefault 获取字符串值或默认值
func (s *SettingService) getStringOrDefault(settings map[string]string, key, defaultValue string) string { func (s *SettingService) getStringOrDefault(settings map[string]string, key, defaultValue string) string {
if value, ok := settings[key]; ok && value != "" { if value, ok := settings[key]; ok && value != "" {

View File

@@ -62,3 +62,18 @@ func TestSettingService_GetPublicSettings_ExposesRegistrationEmailSuffixWhitelis
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, []string{"@example.com", "@foo.bar"}, settings.RegistrationEmailSuffixWhitelist) require.Equal(t, []string{"@example.com", "@foo.bar"}, settings.RegistrationEmailSuffixWhitelist)
} }
func TestSettingService_GetPublicSettings_ExposesTablePreferences(t *testing.T) {
repo := &settingPublicRepoStub{
values: map[string]string{
SettingKeyTableDefaultPageSize: "50",
SettingKeyTablePageSizeOptions: "[20,50,100]",
},
}
svc := NewSettingService(repo, &config.Config{})
settings, err := svc.GetPublicSettings(context.Background())
require.NoError(t, err)
require.Equal(t, 50, settings.TableDefaultPageSize)
require.Equal(t, []int{20, 50, 100}, settings.TablePageSizeOptions)
}

View File

@@ -202,3 +202,24 @@ func TestParseDefaultSubscriptions_NormalizesValues(t *testing.T) {
{GroupID: 12, ValidityDays: MaxValidityDays}, {GroupID: 12, ValidityDays: MaxValidityDays},
}, got) }, got)
} }
func TestSettingService_UpdateSettings_TablePreferences(t *testing.T) {
repo := &settingUpdateRepoStub{}
svc := NewSettingService(repo, &config.Config{})
err := svc.UpdateSettings(context.Background(), &SystemSettings{
TableDefaultPageSize: 50,
TablePageSizeOptions: []int{20, 50, 100},
})
require.NoError(t, err)
require.Equal(t, "50", repo.updates[SettingKeyTableDefaultPageSize])
require.Equal(t, "[20,50,100]", repo.updates[SettingKeyTablePageSizeOptions])
err = svc.UpdateSettings(context.Background(), &SystemSettings{
TableDefaultPageSize: 1000,
TablePageSizeOptions: []int{20, 100},
})
require.NoError(t, err)
require.Equal(t, "1000", repo.updates[SettingKeyTableDefaultPageSize])
require.Equal(t, "[20,100]", repo.updates[SettingKeyTablePageSizeOptions])
}

View File

@@ -66,6 +66,8 @@ type SystemSettings struct {
HideCcsImportButton bool HideCcsImportButton bool
PurchaseSubscriptionEnabled bool PurchaseSubscriptionEnabled bool
PurchaseSubscriptionURL string PurchaseSubscriptionURL string
TableDefaultPageSize int
TablePageSizeOptions []int
CustomMenuItems string // JSON array of custom menu items CustomMenuItems string // JSON array of custom menu items
CustomEndpoints string // JSON array of custom endpoints CustomEndpoints string // JSON array of custom endpoints
@@ -132,6 +134,8 @@ type PublicSettings struct {
PurchaseSubscriptionEnabled bool PurchaseSubscriptionEnabled bool
PurchaseSubscriptionURL string PurchaseSubscriptionURL string
TableDefaultPageSize int
TablePageSizeOptions []int
CustomMenuItems string // JSON array of custom menu items CustomMenuItems string // JSON array of custom menu items
CustomEndpoints string // JSON array of custom endpoints CustomEndpoints string // JSON array of custom endpoints

View File

@@ -18,7 +18,7 @@
"@lobehub/icons": "^4.0.2", "@lobehub/icons": "^4.0.2",
"@tanstack/vue-virtual": "^3.13.23", "@tanstack/vue-virtual": "^3.13.23",
"@vueuse/core": "^10.7.0", "@vueuse/core": "^10.7.0",
"axios": "^1.13.5", "axios": "^1.15.0",
"chart.js": "^4.4.1", "chart.js": "^4.4.1",
"dompurify": "^3.3.1", "dompurify": "^3.3.1",
"driver.js": "^1.4.0", "driver.js": "^1.4.0",

1015
frontend/pnpm-lock.yaml generated

File diff suppressed because it is too large Load Diff

View File

@@ -38,6 +38,8 @@ export async function list(
search?: string search?: string
privacy_mode?: string privacy_mode?: string
lite?: string lite?: string
sort_by?: string
sort_order?: 'asc' | 'desc'
}, },
options?: { options?: {
signal?: AbortSignal signal?: AbortSignal
@@ -71,6 +73,8 @@ export async function listWithEtag(
search?: string search?: string
privacy_mode?: string privacy_mode?: string
lite?: string lite?: string
sort_by?: string
sort_order?: 'asc' | 'desc'
}, },
options?: { options?: {
signal?: AbortSignal signal?: AbortSignal
@@ -500,7 +504,11 @@ export async function exportData(options?: {
platform?: string platform?: string
type?: string type?: string
status?: string status?: string
group?: string
privacy_mode?: string
search?: string search?: string
sort_by?: string
sort_order?: 'asc' | 'desc'
} }
includeProxies?: boolean includeProxies?: boolean
}): Promise<AdminDataPayload> { }): Promise<AdminDataPayload> {
@@ -508,11 +516,15 @@ export async function exportData(options?: {
if (options?.ids && options.ids.length > 0) { if (options?.ids && options.ids.length > 0) {
params.ids = options.ids.join(',') params.ids = options.ids.join(',')
} else if (options?.filters) { } else if (options?.filters) {
const { platform, type, status, search } = options.filters const { platform, type, status, group, privacy_mode, search, sort_by, sort_order } = options.filters
if (platform) params.platform = platform if (platform) params.platform = platform
if (type) params.type = type if (type) params.type = type
if (status) params.status = status if (status) params.status = status
if (group) params.group = group
if (privacy_mode) params.privacy_mode = privacy_mode
if (search) params.search = search if (search) params.search = search
if (sort_by) params.sort_by = sort_by
if (sort_order) params.sort_order = sort_order
} }
if (options?.includeProxies === false) { if (options?.includeProxies === false) {
params.include_proxies = 'false' params.include_proxies = 'false'

View File

@@ -17,10 +17,16 @@ export async function list(
filters?: { filters?: {
status?: string status?: string
search?: string search?: string
sort_by?: string
sort_order?: 'asc' | 'desc'
},
options?: {
signal?: AbortSignal
} }
): Promise<BasePaginationResponse<Announcement>> { ): Promise<BasePaginationResponse<Announcement>> {
const { data } = await apiClient.get<BasePaginationResponse<Announcement>>('/admin/announcements', { const { data } = await apiClient.get<BasePaginationResponse<Announcement>>('/admin/announcements', {
params: { page, page_size: pageSize, ...filters } params: { page, page_size: pageSize, ...filters },
signal: options?.signal
}) })
return data return data
} }
@@ -49,11 +55,21 @@ export async function getReadStatus(
id: number, id: number,
page: number = 1, page: number = 1,
pageSize: number = 20, pageSize: number = 20,
search: string = '' filters?: {
search?: string
sort_by?: string
sort_order?: 'asc' | 'desc'
},
options?: {
signal?: AbortSignal
}
): Promise<BasePaginationResponse<AnnouncementUserReadStatus>> { ): Promise<BasePaginationResponse<AnnouncementUserReadStatus>> {
const { data } = await apiClient.get<BasePaginationResponse<AnnouncementUserReadStatus>>( const { data } = await apiClient.get<BasePaginationResponse<AnnouncementUserReadStatus>>(
`/admin/announcements/${id}/read-status`, `/admin/announcements/${id}/read-status`,
{ params: { page, page_size: pageSize, search } } {
params: { page, page_size: pageSize, ...filters },
signal: options?.signal
}
) )
return data return data
} }
@@ -68,4 +84,3 @@ const announcementsAPI = {
} }
export default announcementsAPI export default announcementsAPI

View File

@@ -83,6 +83,8 @@ export async function list(
filters?: { filters?: {
status?: string status?: string
search?: string search?: string
sort_by?: string
sort_order?: 'asc' | 'desc'
}, },
options?: { signal?: AbortSignal } options?: { signal?: AbortSignal }
): Promise<PaginatedResponse<Channel>> { ): Promise<PaginatedResponse<Channel>> {

View File

@@ -27,6 +27,8 @@ export async function list(
status?: 'active' | 'inactive' status?: 'active' | 'inactive'
is_exclusive?: boolean is_exclusive?: boolean
search?: string search?: string
sort_by?: string
sort_order?: 'asc' | 'desc'
}, },
options?: { options?: {
signal?: AbortSignal signal?: AbortSignal

View File

@@ -17,10 +17,16 @@ export async function list(
filters?: { filters?: {
status?: string status?: string
search?: string search?: string
sort_by?: string
sort_order?: 'asc' | 'desc'
},
options?: {
signal?: AbortSignal
} }
): Promise<BasePaginationResponse<PromoCode>> { ): Promise<BasePaginationResponse<PromoCode>> {
const { data } = await apiClient.get<BasePaginationResponse<PromoCode>>('/admin/promo-codes', { const { data } = await apiClient.get<BasePaginationResponse<PromoCode>>('/admin/promo-codes', {
params: { page, page_size: pageSize, ...filters } params: { page, page_size: pageSize, ...filters },
signal: options?.signal
}) })
return data return data
} }

View File

@@ -29,6 +29,8 @@ export async function list(
protocol?: string protocol?: string
status?: 'active' | 'inactive' status?: 'active' | 'inactive'
search?: string search?: string
sort_by?: string
sort_order?: 'asc' | 'desc'
}, },
options?: { options?: {
signal?: AbortSignal signal?: AbortSignal
@@ -227,16 +229,20 @@ export async function exportData(options?: {
protocol?: string protocol?: string
status?: 'active' | 'inactive' status?: 'active' | 'inactive'
search?: string search?: string
sort_by?: string
sort_order?: 'asc' | 'desc'
} }
}): Promise<AdminDataPayload> { }): Promise<AdminDataPayload> {
const params: Record<string, string> = {} const params: Record<string, string> = {}
if (options?.ids && options.ids.length > 0) { if (options?.ids && options.ids.length > 0) {
params.ids = options.ids.join(',') params.ids = options.ids.join(',')
} else if (options?.filters) { } else if (options?.filters) {
const { protocol, status, search } = options.filters const { protocol, status, search, sort_by, sort_order } = options.filters
if (protocol) params.protocol = protocol if (protocol) params.protocol = protocol
if (status) params.status = status if (status) params.status = status
if (search) params.search = search if (search) params.search = search
if (sort_by) params.sort_by = sort_by
if (sort_order) params.sort_order = sort_order
} }
const { data } = await apiClient.get<AdminDataPayload>('/admin/proxies/data', { params }) const { data } = await apiClient.get<AdminDataPayload>('/admin/proxies/data', { params })
return data return data

View File

@@ -25,6 +25,8 @@ export async function list(
type?: RedeemCodeType type?: RedeemCodeType
status?: 'active' | 'used' | 'expired' | 'unused' status?: 'active' | 'used' | 'expired' | 'unused'
search?: string search?: string
sort_by?: string
sort_order?: 'asc' | 'desc'
}, },
options?: { options?: {
signal?: AbortSignal signal?: AbortSignal
@@ -151,7 +153,10 @@ export async function getStats(): Promise<{
*/ */
export async function exportCodes(filters?: { export async function exportCodes(filters?: {
type?: RedeemCodeType type?: RedeemCodeType
status?: 'active' | 'used' | 'expired' status?: 'used' | 'expired' | 'unused'
search?: string
sort_by?: string
sort_order?: 'asc' | 'desc'
}): Promise<Blob> { }): Promise<Blob> {
const response = await apiClient.get('/admin/redeem-codes/export', { const response = await apiClient.get('/admin/redeem-codes/export', {
params: filters, params: filters,

View File

@@ -40,6 +40,8 @@ export interface SystemSettings {
hide_ccs_import_button: boolean hide_ccs_import_button: boolean
purchase_subscription_enabled: boolean purchase_subscription_enabled: boolean
purchase_subscription_url: string purchase_subscription_url: string
table_default_page_size: number
table_page_size_options: number[]
backend_mode_enabled: boolean backend_mode_enabled: boolean
custom_menu_items: CustomMenuItem[] custom_menu_items: CustomMenuItem[]
custom_endpoints: CustomEndpoint[] custom_endpoints: CustomEndpoint[]
@@ -138,6 +140,8 @@ export interface UpdateSettingsRequest {
hide_ccs_import_button?: boolean hide_ccs_import_button?: boolean
purchase_subscription_enabled?: boolean purchase_subscription_enabled?: boolean
purchase_subscription_url?: string purchase_subscription_url?: string
table_default_page_size?: number
table_page_size_options?: number[]
backend_mode_enabled?: boolean backend_mode_enabled?: boolean
custom_menu_items?: CustomMenuItem[] custom_menu_items?: CustomMenuItem[]
custom_endpoints?: CustomEndpoint[] custom_endpoints?: CustomEndpoint[]

View File

@@ -81,6 +81,8 @@ export interface AdminUsageQueryParams extends UsageQueryParams {
user_id?: number user_id?: number
exact_total?: boolean exact_total?: boolean
billing_mode?: string billing_mode?: string
sort_by?: string
sort_order?: 'asc' | 'desc'
} }
// ==================== API Functions ==================== // ==================== API Functions ====================

View File

@@ -24,6 +24,8 @@ export async function list(
group_name?: string // fuzzy filter by allowed group name group_name?: string // fuzzy filter by allowed group name
attributes?: Record<number, string> // attributeId -> value attributes?: Record<number, string> // attributeId -> value
include_subscriptions?: boolean include_subscriptions?: boolean
sort_by?: string
sort_order?: 'asc' | 'desc'
}, },
options?: { options?: {
signal?: AbortSignal signal?: AbortSignal
@@ -37,7 +39,9 @@ export async function list(
role: filters?.role, role: filters?.role,
search: filters?.search, search: filters?.search,
group_name: filters?.group_name, group_name: filters?.group_name,
include_subscriptions: filters?.include_subscriptions include_subscriptions: filters?.include_subscriptions,
sort_by: filters?.sort_by,
sort_order: filters?.sort_order
} }
// Add attribute filters as attr[id]=value // Add attribute filters as attr[id]=value

View File

@@ -17,7 +17,13 @@ import type { ApiKey, CreateApiKeyRequest, UpdateApiKeyRequest, PaginatedRespons
export async function list( export async function list(
page: number = 1, page: number = 1,
pageSize: number = 10, pageSize: number = 10,
filters?: { search?: string; status?: string; group_id?: number | string }, filters?: {
search?: string
status?: string
group_id?: number | string
sort_by?: string
sort_order?: 'asc' | 'desc'
},
options?: { options?: {
signal?: AbortSignal signal?: AbortSignal
} }

View File

@@ -91,7 +91,7 @@ export async function list(
* @returns Paginated list of usage logs * @returns Paginated list of usage logs
*/ */
export async function query( export async function query(
params: UsageQueryParams, params: UsageQueryParams & { sort_by?: string; sort_order?: 'asc' | 'desc' },
config: { signal?: AbortSignal } = {} config: { signal?: AbortSignal } = {}
): Promise<PaginatedResponse<UsageLog>> { ): Promise<PaginatedResponse<UsageLog>> {
const { data } = await apiClient.get<PaginatedResponse<UsageLog>>('/usage', { const { data } = await apiClient.get<PaginatedResponse<UsageLog>>('/usage', {

View File

@@ -27,7 +27,7 @@ const updatePrivacyMode = (value: string | number | boolean | null) => { emit('u
const updateGroup = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, group: value }) } const updateGroup = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, group: value }) }
const pOpts = computed(() => [{ value: '', label: t('admin.accounts.allPlatforms') }, { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, { value: 'antigravity', label: 'Antigravity' }]) const pOpts = computed(() => [{ value: '', label: t('admin.accounts.allPlatforms') }, { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, { value: 'antigravity', label: 'Antigravity' }])
const tOpts = computed(() => [{ value: '', label: t('admin.accounts.allTypes') }, { value: 'oauth', label: t('admin.accounts.oauthType') }, { value: 'setup-token', label: t('admin.accounts.setupToken') }, { value: 'apikey', label: t('admin.accounts.apiKey') }, { value: 'bedrock', label: 'AWS Bedrock' }]) const tOpts = computed(() => [{ value: '', label: t('admin.accounts.allTypes') }, { value: 'oauth', label: t('admin.accounts.oauthType') }, { value: 'setup-token', label: t('admin.accounts.setupToken') }, { value: 'apikey', label: t('admin.accounts.apiKey') }, { value: 'bedrock', label: 'AWS Bedrock' }])
const sOpts = computed(() => [{ value: '', label: t('admin.accounts.allStatus') }, { value: 'active', label: t('admin.accounts.status.active') }, { value: 'inactive', label: t('admin.accounts.status.inactive') }, { value: 'error', label: t('admin.accounts.status.error') }, { value: 'rate_limited', label: t('admin.accounts.status.rateLimited') }, { value: 'temp_unschedulable', label: t('admin.accounts.status.tempUnschedulable') }]) const sOpts = computed(() => [{ value: '', label: t('admin.accounts.allStatus') }, { value: 'active', label: t('admin.accounts.status.active') }, { value: 'inactive', label: t('admin.accounts.status.inactive') }, { value: 'error', label: t('admin.accounts.status.error') }, { value: 'rate_limited', label: t('admin.accounts.status.rateLimited') }, { value: 'temp_unschedulable', label: t('admin.accounts.status.tempUnschedulable') }, { value: 'unschedulable', label: t('admin.accounts.status.unschedulable') }])
const privacyOpts = computed(() => [ const privacyOpts = computed(() => [
{ value: '', label: t('admin.accounts.allPrivacyModes') }, { value: '', label: t('admin.accounts.allPrivacyModes') },
{ value: '__unset__', label: t('admin.accounts.privacyUnset') }, { value: '__unset__', label: t('admin.accounts.privacyUnset') },

View File

@@ -1,56 +0,0 @@
import { describe, expect, it, vi } from 'vitest'
import { mount } from '@vue/test-utils'
import AccountTableFilters from '../AccountTableFilters.vue'
vi.mock('vue-i18n', async () => {
const actual = await vi.importActual<typeof import('vue-i18n')>('vue-i18n')
return {
...actual,
useI18n: () => ({
t: (key: string) => key
})
}
})
describe('AccountTableFilters', () => {
it('renders privacy mode options and emits privacy_mode updates', async () => {
const wrapper = mount(AccountTableFilters, {
props: {
searchQuery: '',
filters: {
platform: '',
type: '',
status: '',
group: '',
privacy_mode: ''
},
groups: []
},
global: {
stubs: {
SearchInput: {
template: '<div />'
},
Select: {
props: ['modelValue', 'options'],
emits: ['update:modelValue', 'change'],
template: '<div class="select-stub" :data-options="JSON.stringify(options)" />'
}
}
}
})
const selects = wrapper.findAll('.select-stub')
expect(selects).toHaveLength(5)
const privacyOptions = JSON.parse(selects[3].attributes('data-options'))
expect(privacyOptions).toEqual([
{ value: '', label: 'admin.accounts.allPrivacyModes' },
{ value: '__unset__', label: 'admin.accounts.privacyUnset' },
{ value: 'training_off', label: 'Privacy' },
{ value: 'training_set_cf_blocked', label: 'CF' },
{ value: 'training_set_failed', label: 'Fail' }
])
})
})

View File

@@ -21,7 +21,15 @@
</button> </button>
</div> </div>
<DataTable :columns="columns" :data="items" :loading="loading"> <DataTable
:columns="columns"
:data="items"
:loading="loading"
:server-side-sort="true"
default-sort-key="email"
default-sort-order="asc"
@sort="handleSort"
>
<template #cell-email="{ value }"> <template #cell-email="{ value }">
<span class="font-medium text-gray-900 dark:text-white">{{ value }}</span> <span class="font-medium text-gray-900 dark:text-white">{{ value }}</span>
</template> </template>
@@ -62,7 +70,7 @@
</template> </template>
<script setup lang="ts"> <script setup lang="ts">
import { computed, onMounted, reactive, ref, watch } from 'vue' import { computed, onUnmounted, reactive, ref, watch } from 'vue'
import { useI18n } from 'vue-i18n' import { useI18n } from 'vue-i18n'
import { useAppStore } from '@/stores/app' import { useAppStore } from '@/stores/app'
import { adminAPI } from '@/api/admin' import { adminAPI } from '@/api/admin'
@@ -98,23 +106,54 @@ const pagination = reactive({
pages: 0 pages: 0
}) })
const sortState = reactive({
sort_by: 'email',
sort_order: 'asc' as 'asc' | 'desc'
})
const items = ref<AnnouncementUserReadStatus[]>([]) const items = ref<AnnouncementUserReadStatus[]>([])
const columns = computed<Column[]>(() => [ const columns = computed<Column[]>(() => [
{ key: 'email', label: t('common.email') }, { key: 'email', label: t('common.email'), sortable: true },
{ key: 'username', label: t('admin.users.columns.username') }, { key: 'username', label: t('admin.users.columns.username'), sortable: true },
{ key: 'balance', label: t('common.balance') }, { key: 'balance', label: t('common.balance'), sortable: true },
{ key: 'eligible', label: t('admin.announcements.eligible') }, { key: 'eligible', label: t('admin.announcements.eligible') },
{ key: 'read_at', label: t('admin.announcements.readAt') } { key: 'read_at', label: t('admin.announcements.readAt') }
]) ])
let currentController: AbortController | null = null let currentController: AbortController | null = null
let searchDebounceTimer: number | null = null
function resetDialogState() {
loading.value = false
search.value = ''
items.value = []
pagination.page = 1
pagination.total = 0
pagination.pages = 0
sortState.sort_by = 'email'
sortState.sort_order = 'asc'
}
function cancelPendingLoad(resetState = false) {
if (searchDebounceTimer) {
window.clearTimeout(searchDebounceTimer)
searchDebounceTimer = null
}
currentController?.abort()
currentController = null
if (resetState) {
resetDialogState()
}
}
async function load() { async function load() {
if (!props.show || !props.announcementId) return if (!props.show || !props.announcementId) return
if (currentController) currentController.abort() currentController?.abort()
currentController = new AbortController() const requestController = new AbortController()
currentController = requestController
const { signal } = requestController
try { try {
loading.value = true loading.value = true
@@ -122,20 +161,37 @@ async function load() {
props.announcementId, props.announcementId,
pagination.page, pagination.page,
pagination.page_size, pagination.page_size,
search.value {
search: search.value,
sort_by: sortState.sort_by,
sort_order: sortState.sort_order
},
{ signal }
) )
if (signal.aborted || currentController !== requestController) return
items.value = res.items items.value = res.items
pagination.total = res.total pagination.total = res.total
pagination.pages = res.pages pagination.pages = res.pages
pagination.page = res.page pagination.page = res.page
pagination.page_size = res.page_size pagination.page_size = res.page_size
} catch (error: any) { } catch (error: any) {
if (currentController.signal.aborted || error?.name === 'AbortError') return if (
signal.aborted ||
currentController !== requestController ||
error?.name === 'AbortError' ||
error?.code === 'ERR_CANCELED'
) {
return
}
console.error('Failed to load read status:', error) console.error('Failed to load read status:', error)
appStore.showError(error.response?.data?.detail || t('admin.announcements.failedToLoadReadStatus')) appStore.showError(error.response?.data?.detail || t('admin.announcements.failedToLoadReadStatus'))
} finally { } finally {
loading.value = false if (currentController === requestController) {
loading.value = false
currentController = null
}
} }
} }
@@ -150,7 +206,13 @@ function handlePageSizeChange(pageSize: number) {
load() load()
} }
let searchDebounceTimer: number | null = null function handleSort(key: string, order: 'asc' | 'desc') {
sortState.sort_by = key
sortState.sort_order = order
pagination.page = 1
load()
}
function handleSearch() { function handleSearch() {
if (searchDebounceTimer) window.clearTimeout(searchDebounceTimer) if (searchDebounceTimer) window.clearTimeout(searchDebounceTimer)
searchDebounceTimer = window.setTimeout(() => { searchDebounceTimer = window.setTimeout(() => {
@@ -160,13 +222,17 @@ function handleSearch() {
} }
function handleClose() { function handleClose() {
cancelPendingLoad(true)
emit('close') emit('close')
} }
watch( watch(
() => props.show, () => props.show,
(v) => { (v) => {
if (!v) return if (!v) {
cancelPendingLoad(true)
return
}
pagination.page = 1 pagination.page = 1
load() load()
} }
@@ -181,7 +247,7 @@ watch(
} }
) )
onMounted(() => { onUnmounted(() => {
// noop cancelPendingLoad()
}) })
</script> </script>

View File

@@ -0,0 +1,95 @@
import { describe, it, expect, vi, beforeEach } from 'vitest'
import { flushPromises, mount } from '@vue/test-utils'
import AnnouncementReadStatusDialog from '../AnnouncementReadStatusDialog.vue'
const { getReadStatus, showError } = vi.hoisted(() => ({
getReadStatus: vi.fn(),
showError: vi.fn(),
}))
vi.mock('@/api/admin', () => ({
adminAPI: {
announcements: {
getReadStatus,
},
},
}))
vi.mock('@/stores/app', () => ({
useAppStore: () => ({
showError,
}),
}))
vi.mock('vue-i18n', async () => {
const actual = await vi.importActual<typeof import('vue-i18n')>('vue-i18n')
return {
...actual,
useI18n: () => ({
t: (key: string) => key,
}),
}
})
vi.mock('@/composables/usePersistedPageSize', () => ({
getPersistedPageSize: () => 20,
}))
const BaseDialogStub = {
props: ['show', 'title', 'width'],
emits: ['close'],
template: '<div><slot /><slot name="footer" /></div>',
}
describe('AnnouncementReadStatusDialog', () => {
beforeEach(() => {
getReadStatus.mockReset()
showError.mockReset()
vi.useFakeTimers()
})
it('closes by aborting active requests and clearing debounced reloads', async () => {
let activeSignal: AbortSignal | undefined
getReadStatus.mockImplementation(async (...args: any[]) => {
activeSignal = args[4]?.signal
return new Promise(() => {})
})
const wrapper = mount(AnnouncementReadStatusDialog, {
props: {
show: false,
announcementId: 1,
},
global: {
stubs: {
BaseDialog: BaseDialogStub,
DataTable: true,
Pagination: true,
Icon: true,
},
},
})
await wrapper.setProps({ show: true })
await flushPromises()
expect(getReadStatus).toHaveBeenCalledTimes(1)
expect(activeSignal?.aborted).toBe(false)
const setupState = (wrapper.vm as any).$?.setupState
setupState.search = 'alice'
setupState.handleSearch()
setupState.handleClose()
await flushPromises()
expect(activeSignal?.aborted).toBe(true)
expect(wrapper.emitted('close')).toHaveLength(1)
vi.advanceTimersByTime(350)
await flushPromises()
expect(getReadStatus).toHaveBeenCalledTimes(1)
})
})

View File

@@ -196,7 +196,6 @@
:total="localEntries.length" :total="localEntries.length"
:page="currentPage" :page="currentPage"
:page-size="pageSize" :page-size="pageSize"
:page-size-options="[10, 20, 50]"
@update:page="currentPage = $event" @update:page="currentPage = $event"
@update:pageSize="handlePageSizeChange" @update:pageSize="handlePageSizeChange"
/> />

View File

@@ -1,7 +1,15 @@
<template> <template>
<div class="card overflow-hidden"> <div class="card overflow-hidden">
<div class="overflow-auto"> <div class="overflow-auto">
<DataTable :columns="columns" :data="data" :loading="loading"> <DataTable
:columns="columns"
:data="data"
:loading="loading"
:server-side-sort="serverSideSort"
:default-sort-key="defaultSortKey"
:default-sort-order="defaultSortOrder"
@sort="(key, order) => $emit('sort', key, order)"
>
<template #cell-user="{ row }"> <template #cell-user="{ row }">
<div class="text-sm"> <div class="text-sm">
<button <button
@@ -334,9 +342,27 @@ import DataTable from '@/components/common/DataTable.vue'
import EmptyState from '@/components/common/EmptyState.vue' import EmptyState from '@/components/common/EmptyState.vue'
import Icon from '@/components/icons/Icon.vue' import Icon from '@/components/icons/Icon.vue'
import type { AdminUsageLog } from '@/types' import type { AdminUsageLog } from '@/types'
import type { Column } from '@/components/common/types'
defineProps(['data', 'loading', 'columns']) interface Props {
defineEmits(['userClick']) data: AdminUsageLog[]
loading?: boolean
columns: Column[]
serverSideSort?: boolean
defaultSortKey?: string
defaultSortOrder?: 'asc' | 'desc'
}
withDefaults(defineProps<Props>(), {
loading: false,
serverSideSort: false,
defaultSortKey: '',
defaultSortOrder: 'asc'
})
defineEmits<{
userClick: [userID: number, email?: string]
sort: [key: string, order: 'asc' | 'desc']
}>()
const { t } = useI18n() const { t } = useI18n()
// Tooltip state - cost // Tooltip state - cost

View File

@@ -68,7 +68,7 @@
'is-scrollable': isScrollable 'is-scrollable': isScrollable
}" }"
> >
<table class="min-w-full divide-y divide-gray-200 dark:divide-dark-700"> <table class="w-full min-w-max divide-y divide-gray-200 dark:divide-dark-700">
<thead class="table-header bg-gray-50 dark:bg-dark-800"> <thead class="table-header bg-gray-50 dark:bg-dark-800">
<tr> <tr>
<th <th
@@ -797,3 +797,62 @@ tbody tr:hover .sticky-col {
background: linear-gradient(to left, rgba(0, 0, 0, 0.2), transparent); background: linear-gradient(to left, rgba(0, 0, 0, 0.2), transparent);
} }
</style> </style>
<style>
/* ==========================================================================
终极悬浮滚动条防丢器 (Sledgehammer Override)
绕过 style.css 中 `* { scrollbar-color: transparent }` 的全局悬停隐身诅咒!
========================================================================== */
/* 1. 废除全局针对所有元素的 scrollbar-width 设定,拿回 Chrome/Safari 下 Webkit 滚动条规则的控制权! */
.table-wrapper {
scrollbar-width: auto !important; /* 阻止 Chrome 121 退化到原生 Mac 闪隐滚动条 */
}
/* 2. 重写 Webkit 滚动层,全部加上 !important 强制覆盖透明悬停陷阱 */
.table-wrapper::-webkit-scrollbar {
height: 12px !important;
width: 12px !important;
display: block !important;
background-color: transparent !important;
}
.table-wrapper::-webkit-scrollbar-track {
background-color: rgba(0, 0, 0, 0.03) !important;
border-radius: 6px !important;
margin: 0 4px !important;
}
.dark .table-wrapper::-webkit-scrollbar-track {
background-color: rgba(255, 255, 255, 0.05) !important;
}
/* 常驻、不透明的滑块,无视鼠标是否 hover 都在那! */
.table-wrapper::-webkit-scrollbar-thumb {
background-color: rgba(107, 114, 128, 0.75) !important;
border-radius: 6px !important;
border: 2px solid transparent !important;
background-clip: padding-box !important;
-webkit-appearance: none !important;
}
.table-wrapper::-webkit-scrollbar-thumb:hover {
background-color: rgba(75, 85, 99, 0.9) !important;
}
.dark .table-wrapper::-webkit-scrollbar-thumb {
background-color: rgba(156, 163, 175, 0.75) !important;
}
.dark .table-wrapper::-webkit-scrollbar-thumb:hover {
background-color: rgba(209, 213, 219, 0.9) !important;
}
/* 3. 仅给真正的 Firefox 留的后路 */
@supports (-moz-appearance:none) {
.table-wrapper {
scrollbar-width: thin !important;
scrollbar-color: rgba(156, 163, 175, 0.5) rgba(0, 0, 0, 0.03) !important;
}
.dark .table-wrapper {
scrollbar-color: rgba(75, 85, 99, 0.5) rgba(255, 255, 255, 0.05) !important;
}
}
</style>

View File

@@ -122,7 +122,7 @@ import { computed, ref } from 'vue'
import { useI18n } from 'vue-i18n' import { useI18n } from 'vue-i18n'
import Icon from '@/components/icons/Icon.vue' import Icon from '@/components/icons/Icon.vue'
import Select from './Select.vue' import Select from './Select.vue'
import { setPersistedPageSize } from '@/composables/usePersistedPageSize' import { getConfiguredTablePageSizeOptions, normalizeTablePageSize } from '@/utils/tablePreferences'
const { t } = useI18n() const { t } = useI18n()
@@ -141,7 +141,7 @@ interface Emits {
} }
const props = withDefaults(defineProps<Props>(), { const props = withDefaults(defineProps<Props>(), {
pageSizeOptions: () => [10, 20, 50, 100], pageSizeOptions: () => getConfiguredTablePageSizeOptions(),
showPageSizeSelector: true, showPageSizeSelector: true,
showJump: false showJump: false
}) })
@@ -161,7 +161,14 @@ const toItem = computed(() => {
}) })
const pageSizeSelectOptions = computed(() => { const pageSizeSelectOptions = computed(() => {
return props.pageSizeOptions.map((size) => ({ const options = Array.from(
new Set([
...getConfiguredTablePageSizeOptions(),
normalizeTablePageSize(props.pageSize)
])
).sort((a, b) => a - b)
return options.map((size) => ({
value: size, value: size,
label: String(size) label: String(size)
})) }))
@@ -216,8 +223,7 @@ const goToPage = (newPage: number) => {
const handlePageSizeChange = (value: string | number | boolean | null) => { const handlePageSizeChange = (value: string | number | boolean | null) => {
if (value === null || typeof value === 'boolean') return if (value === null || typeof value === 'boolean') return
const newPageSize = typeof value === 'string' ? parseInt(value) : value const newPageSize = normalizeTablePageSize(typeof value === 'string' ? parseInt(value, 10) : value)
setPersistedPageSize(newPageSize)
emit('update:pageSize', newPageSize) emit('update:pageSize', newPageSize)
} }

View File

@@ -669,11 +669,14 @@ onMounted(() => {
opacity: 0; opacity: 0;
} }
/* Custom SVG icon in sidebar: inherit color, constrain size */ /* Custom SVG icon in sidebar: constrain size without overriding uploaded SVG colors */
.sidebar-svg-icon {
color: currentColor;
}
.sidebar-svg-icon :deep(svg) { .sidebar-svg-icon :deep(svg) {
display: block;
width: 1.25rem; width: 1.25rem;
height: 1.25rem; height: 1.25rem;
stroke: currentColor;
fill: none;
} }
</style> </style>

View File

@@ -0,0 +1,18 @@
import { readFileSync } from 'node:fs'
import { dirname, resolve } from 'node:path'
import { fileURLToPath } from 'node:url'
import { describe, expect, it } from 'vitest'
const componentPath = resolve(dirname(fileURLToPath(import.meta.url)), '../AppSidebar.vue')
const componentSource = readFileSync(componentPath, 'utf8')
describe('AppSidebar custom SVG styles', () => {
it('does not override uploaded SVG fill or stroke colors', () => {
expect(componentSource).toContain('.sidebar-svg-icon {')
expect(componentSource).toContain('color: currentColor;')
expect(componentSource).toContain('display: block;')
expect(componentSource).not.toContain('stroke: currentColor;')
expect(componentSource).not.toContain('fill: none;')
})
})

View File

@@ -0,0 +1,21 @@
import { afterEach, describe, expect, it } from 'vitest'
import { getPersistedPageSize } from '@/composables/usePersistedPageSize'
describe('usePersistedPageSize', () => {
afterEach(() => {
localStorage.clear()
delete window.__APP_CONFIG__
})
it('uses the system table default instead of stale localStorage state', () => {
window.__APP_CONFIG__ = {
table_default_page_size: 1000,
table_page_size_options: [20, 50, 1000]
} as any
localStorage.setItem('table-page-size', '50')
localStorage.setItem('table-page-size-source', 'user')
expect(getPersistedPageSize()).toBe(1000)
})
})

View File

@@ -1,27 +1,9 @@
const STORAGE_KEY = 'table-page-size' import { getConfiguredTableDefaultPageSize, normalizeTablePageSize } from '@/utils/tablePreferences'
const DEFAULT_PAGE_SIZE = 20
/** /**
* 从 localStorage 读取/写入 pageSize * 读取当前系统配置的表格默认每页条数。
* 全局共享一个 key所有表格统一偏好 * 不再使用本地持久化缓存,所有页面统一以通用表格设置为准。
*/ */
export function getPersistedPageSize(fallback = DEFAULT_PAGE_SIZE): number { export function getPersistedPageSize(fallback = getConfiguredTableDefaultPageSize()): number {
try { return normalizeTablePageSize(getConfiguredTableDefaultPageSize() || fallback)
const stored = localStorage.getItem(STORAGE_KEY)
if (stored) {
const parsed = Number(stored)
if (Number.isFinite(parsed) && parsed > 0) return parsed
}
} catch {
// localStorage 不可用(隐私模式等)
}
return fallback
}
export function setPersistedPageSize(size: number): void {
try {
localStorage.setItem(STORAGE_KEY, String(size))
} catch {
// 静默失败
}
} }

View File

@@ -1,7 +1,7 @@
import { ref, reactive, onUnmounted, toRaw } from 'vue' import { ref, reactive, onUnmounted, toRaw } from 'vue'
import { useDebounceFn } from '@vueuse/core' import { useDebounceFn } from '@vueuse/core'
import type { BasePaginationResponse, FetchOptions } from '@/types' import type { BasePaginationResponse, FetchOptions } from '@/types'
import { getPersistedPageSize, setPersistedPageSize } from './usePersistedPageSize' import { getPersistedPageSize } from './usePersistedPageSize'
interface PaginationState { interface PaginationState {
page: number page: number
@@ -88,7 +88,6 @@ export function useTableLoader<T, P extends Record<string, any>>(options: TableL
const handlePageSizeChange = (size: number) => { const handlePageSizeChange = (size: number) => {
pagination.page_size = size pagination.page_size = size
pagination.page = 1 pagination.page = 1
setPersistedPageSize(size)
load() load()
} }

View File

@@ -2051,6 +2051,7 @@ export default {
rateLimited: 'Rate Limited', rateLimited: 'Rate Limited',
overloaded: 'Overloaded', overloaded: 'Overloaded',
tempUnschedulable: 'Temp Unschedulable', tempUnschedulable: 'Temp Unschedulable',
unschedulable: 'Unschedulable',
rateLimitedUntil: 'Rate limited and removed from scheduling. Auto resumes at {time}', rateLimitedUntil: 'Rate limited and removed from scheduling. Auto resumes at {time}',
rateLimitedAutoResume: 'Auto resumes in {time}', rateLimitedAutoResume: 'Auto resumes in {time}',
modelRateLimitedUntil: '{model} rate limited until {time}', modelRateLimitedUntil: '{model} rate limited until {time}',
@@ -4353,6 +4354,15 @@ export default {
apiBaseUrlPlaceholder: 'https://api.example.com', apiBaseUrlPlaceholder: 'https://api.example.com',
apiBaseUrlHint: apiBaseUrlHint:
'Used for "Use Key" and "Import to CC Switch" features. Leave empty to use current site URL.', 'Used for "Use Key" and "Import to CC Switch" features. Leave empty to use current site URL.',
tablePreferencesTitle: 'Global Table Preferences',
tablePreferencesDescription: 'Configure default pagination behavior for shared table components',
tableDefaultPageSize: 'Default Rows Per Page',
tableDefaultPageSizeHint: 'Must be an integer between 5 and 1000',
tablePageSizeOptions: 'Rows Per Page Options',
tablePageSizeOptionsPlaceholder: '10, 20, 50, 100',
tablePageSizeOptionsHint: 'Use commas to separate integers between 5 and 1000; values are deduplicated and sorted on save',
tableDefaultPageSizeRangeError: 'Default rows per page must be between {min} and {max}',
tablePageSizeOptionsFormatError: 'Invalid options format. Enter comma-separated integers between {min} and {max}',
customEndpoints: { customEndpoints: {
title: 'Custom Endpoints', title: 'Custom Endpoints',
description: 'Add additional API endpoint URLs for users to quickly copy on the API Keys page', description: 'Add additional API endpoint URLs for users to quickly copy on the API Keys page',

View File

@@ -2234,6 +2234,7 @@ export default {
rateLimited: '限流中', rateLimited: '限流中',
overloaded: '过载中', overloaded: '过载中',
tempUnschedulable: '临时不可调度', tempUnschedulable: '临时不可调度',
unschedulable: '不可调度',
rateLimitedUntil: '限流中,当前不参与调度,预计 {time} 自动恢复', rateLimitedUntil: '限流中,当前不参与调度,预计 {time} 自动恢复',
rateLimitedAutoResume: '{time} 自动恢复', rateLimitedAutoResume: '{time} 自动恢复',
modelRateLimitedUntil: '{model} 限流至 {time}', modelRateLimitedUntil: '{model} 限流至 {time}',
@@ -4514,6 +4515,15 @@ export default {
apiBaseUrl: 'API 端点地址', apiBaseUrl: 'API 端点地址',
apiBaseUrlHint: '用于"使用密钥"和"导入到 CC Switch"功能,留空则使用当前站点地址', apiBaseUrlHint: '用于"使用密钥"和"导入到 CC Switch"功能,留空则使用当前站点地址',
apiBaseUrlPlaceholder: 'https://api.example.com', apiBaseUrlPlaceholder: 'https://api.example.com',
tablePreferencesTitle: '通用表格设置',
tablePreferencesDescription: '设置后台与用户侧表格组件的默认分页行为',
tableDefaultPageSize: '默认每页条数',
tableDefaultPageSizeHint: '必须为 5-1000 之间的整数',
tablePageSizeOptions: '可选每页条数列表',
tablePageSizeOptionsPlaceholder: '10, 20, 50, 100',
tablePageSizeOptionsHint: '使用英文逗号分隔,取值范围 5-1000保存时会自动去重并排序',
tableDefaultPageSizeRangeError: '默认每页条数必须在 {min}-{max} 之间',
tablePageSizeOptionsFormatError: '可选每页条数格式无效,请输入 {min}-{max} 之间的整数并用英文逗号分隔',
customEndpoints: { customEndpoints: {
title: '自定义端点', title: '自定义端点',
description: '添加额外的 API 端点地址用户可在「API Keys」页面快速复制', description: '添加额外的 API 端点地址用户可在「API Keys」页面快速复制',

View File

@@ -1,6 +1,7 @@
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'
import { setActivePinia, createPinia } from 'pinia' import { setActivePinia, createPinia } from 'pinia'
import { useAppStore } from '@/stores/app' import { useAppStore } from '@/stores/app'
import { getPublicSettings } from '@/api/auth'
// Mock API 模块 // Mock API 模块
vi.mock('@/api/admin/system', () => ({ vi.mock('@/api/admin/system', () => ({
@@ -15,12 +16,14 @@ describe('useAppStore', () => {
beforeEach(() => { beforeEach(() => {
setActivePinia(createPinia()) setActivePinia(createPinia())
vi.useFakeTimers() vi.useFakeTimers()
localStorage.clear()
// 清除 window.__APP_CONFIG__ // 清除 window.__APP_CONFIG__
delete (window as any).__APP_CONFIG__ delete (window as any).__APP_CONFIG__
}) })
afterEach(() => { afterEach(() => {
vi.useRealTimers() vi.useRealTimers()
localStorage.clear()
}) })
// --- Toast 消息管理 --- // --- Toast 消息管理 ---
@@ -291,5 +294,43 @@ describe('useAppStore', () => {
expect(store.publicSettingsLoaded).toBe(false) expect(store.publicSettingsLoaded).toBe(false)
expect(store.cachedPublicSettings).toBeNull() expect(store.cachedPublicSettings).toBeNull()
}) })
it('fetchPublicSettings(force) 会同步更新运行时注入配置', async () => {
vi.mocked(getPublicSettings).mockResolvedValue({
registration_enabled: false,
email_verify_enabled: false,
registration_email_suffix_whitelist: [],
promo_code_enabled: true,
password_reset_enabled: false,
invitation_code_enabled: false,
turnstile_enabled: false,
turnstile_site_key: '',
site_name: 'Updated Site',
site_logo: '',
site_subtitle: '',
api_base_url: '',
contact_info: '',
doc_url: '',
home_content: '',
hide_ccs_import_button: false,
purchase_subscription_enabled: false,
purchase_subscription_url: '',
table_default_page_size: 1000,
table_page_size_options: [20, 100, 1000],
custom_menu_items: [],
custom_endpoints: [],
linuxdo_oauth_enabled: false,
backend_mode_enabled: false,
version: '1.0.0'
})
const store = useAppStore()
await store.fetchPublicSettings(true)
expect((window as any).__APP_CONFIG__.table_default_page_size).toBe(1000)
expect((window as any).__APP_CONFIG__.table_page_size_options).toEqual([20, 100, 1000])
expect(localStorage.getItem('table-page-size')).toBeNull()
expect(localStorage.getItem('table-page-size-source')).toBeNull()
})
}) })
}) })

View File

@@ -284,6 +284,9 @@ export const useAppStore = defineStore('app', () => {
* Apply settings to store state (internal helper to avoid code duplication) * Apply settings to store state (internal helper to avoid code duplication)
*/ */
function applySettings(config: PublicSettings): void { function applySettings(config: PublicSettings): void {
if (typeof window !== 'undefined') {
window.__APP_CONFIG__ = { ...config }
}
cachedPublicSettings.value = config cachedPublicSettings.value = config
siteName.value = config.site_name || 'Sub2API' siteName.value = config.site_name || 'Sub2API'
siteLogo.value = config.site_logo || '' siteLogo.value = config.site_logo || ''
@@ -329,6 +332,8 @@ export const useAppStore = defineStore('app', () => {
hide_ccs_import_button: false, hide_ccs_import_button: false,
purchase_subscription_enabled: false, purchase_subscription_enabled: false,
purchase_subscription_url: '', purchase_subscription_url: '',
table_default_page_size: 20,
table_page_size_options: [10, 20, 50, 100],
custom_menu_items: [], custom_menu_items: [],
custom_endpoints: [], custom_endpoints: [],
linuxdo_oauth_enabled: false, linuxdo_oauth_enabled: false,

View File

@@ -16,20 +16,22 @@
@apply min-h-screen; @apply min-h-screen;
} }
/* 自定义滚动条 - 默认隐藏,悬停或滚动时显示 */ /* 自定义滚动条 - 仅针对 Firefox避免 Chrome 取消 webkit 的全局定制 */
* { @supports (-moz-appearance:none) {
scrollbar-width: thin; * {
scrollbar-color: transparent transparent; scrollbar-width: thin;
} scrollbar-color: transparent transparent;
}
*:hover, *:hover,
*:focus-within { *:focus-within {
scrollbar-color: rgba(156, 163, 175, 0.5) transparent; scrollbar-color: rgba(156, 163, 175, 0.5) transparent;
} }
.dark *:hover, .dark *:hover,
.dark *:focus-within { .dark *:focus-within {
scrollbar-color: rgba(75, 85, 99, 0.5) transparent; scrollbar-color: rgba(75, 85, 99, 0.5) transparent;
}
} }
::-webkit-scrollbar { ::-webkit-scrollbar {
@@ -58,36 +60,7 @@
@apply bg-primary-500/20 text-primary-900 dark:text-primary-100; @apply bg-primary-500/20 text-primary-900 dark:text-primary-100;
} }
/*
* 表格滚动容器:始终显示滚动条,不跟随全局悬停策略。
*
* 浏览器兼容性说明:
* - Chrome 121+ 原生支持 scrollbar-color / scrollbar-width。
* 一旦元素匹配了这两个标准属性,::-webkit-scrollbar-* 被完全忽略。
* 全局 * { scrollbar-width: thin } 使所有元素都走标准属性,
* 因此 Chrome 121+ 只看 scrollbar-color。
* - Chrome < 121 不认识标准属性,只看 ::-webkit-scrollbar-*
* 所以保留 ::-webkit-scrollbar-thumb 作为回退。
* - Firefox 始终只看 scrollbar-color / scrollbar-width。
*/
.table-wrapper {
scrollbar-width: auto;
scrollbar-color: rgba(156, 163, 175, 0.7) transparent;
}
.dark .table-wrapper {
scrollbar-color: rgba(75, 85, 99, 0.7) transparent;
}
/* 旧版 Chrome (< 121) 兼容回退 */
.table-wrapper::-webkit-scrollbar {
width: 10px;
height: 10px;
}
.table-wrapper::-webkit-scrollbar-thumb {
@apply rounded-full bg-gray-400/70;
}
.dark .table-wrapper::-webkit-scrollbar-thumb {
@apply rounded-full bg-gray-500/70;
}
} }
@layer components { @layer components {

View File

@@ -106,6 +106,8 @@ export interface PublicSettings {
hide_ccs_import_button: boolean hide_ccs_import_button: boolean
purchase_subscription_enabled: boolean purchase_subscription_enabled: boolean
purchase_subscription_url: string purchase_subscription_url: string
table_default_page_size: number
table_page_size_options: number[]
custom_menu_items: CustomMenuItem[] custom_menu_items: CustomMenuItem[]
custom_endpoints: CustomEndpoint[] custom_endpoints: CustomEndpoint[]
linuxdo_oauth_enabled: boolean linuxdo_oauth_enabled: boolean
@@ -1363,6 +1365,8 @@ export interface UsageQueryParams {
billing_type?: number | null billing_type?: number | null
start_date?: string start_date?: string
end_date?: string end_date?: string
sort_by?: string
sort_order?: 'asc' | 'desc'
} }
// ==================== Account Usage Statistics ==================== // ==================== Account Usage Statistics ====================

View File

@@ -0,0 +1,74 @@
import { afterEach, describe, expect, it } from 'vitest'
import {
DEFAULT_TABLE_PAGE_SIZE,
DEFAULT_TABLE_PAGE_SIZE_OPTIONS,
getConfiguredTableDefaultPageSize,
getConfiguredTablePageSizeOptions,
normalizeTablePageSize
} from '@/utils/tablePreferences'
describe('tablePreferences', () => {
afterEach(() => {
delete window.__APP_CONFIG__
})
it('returns built-in defaults when app config is missing', () => {
expect(getConfiguredTableDefaultPageSize()).toBe(DEFAULT_TABLE_PAGE_SIZE)
expect(getConfiguredTablePageSizeOptions()).toEqual(DEFAULT_TABLE_PAGE_SIZE_OPTIONS)
})
it('uses configured defaults when app config is valid', () => {
window.__APP_CONFIG__ = {
table_default_page_size: 50,
table_page_size_options: [20, 50, 100]
} as any
expect(getConfiguredTableDefaultPageSize()).toBe(50)
expect(getConfiguredTablePageSizeOptions()).toEqual([20, 50, 100])
})
it('allows default page size outside selectable options', () => {
window.__APP_CONFIG__ = {
table_default_page_size: 1000,
table_page_size_options: [20, 50, 100]
} as any
expect(getConfiguredTableDefaultPageSize()).toBe(1000)
expect(getConfiguredTablePageSizeOptions()).toEqual([20, 50, 100])
expect(normalizeTablePageSize(1000)).toBe(100)
expect(normalizeTablePageSize(35)).toBe(50)
})
it('normalizes invalid options without rewriting the configured default itself', () => {
window.__APP_CONFIG__ = {
table_default_page_size: 35,
table_page_size_options: [1001, 50, 10, 10, 2, 0]
} as any
expect(getConfiguredTableDefaultPageSize()).toBe(35)
expect(getConfiguredTablePageSizeOptions()).toEqual([10, 50])
expect(normalizeTablePageSize(undefined)).toBe(50)
})
it('normalizes page size against configured options by rounding up', () => {
window.__APP_CONFIG__ = {
table_default_page_size: 20,
table_page_size_options: [20, 50, 1000]
} as any
expect(normalizeTablePageSize(20)).toBe(20)
expect(normalizeTablePageSize(30)).toBe(50)
expect(normalizeTablePageSize(100)).toBe(1000)
expect(normalizeTablePageSize(1500)).toBe(1000)
expect(normalizeTablePageSize(undefined)).toBe(20)
})
it('keeps built-in selectable defaults at 10, 20, 50, 100', () => {
window.__APP_CONFIG__ = {
table_default_page_size: 1000
} as any
expect(getConfiguredTablePageSizeOptions()).toEqual([10, 20, 50, 100])
})
})

Some files were not shown because too many files have changed in this diff Show More