diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index f632bff3..48f15b5c 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -110,7 +110,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) adminUserHandler := admin.NewUserHandler(adminService, concurrencyService) - groupHandler := admin.NewGroupHandler(adminService) claudeOAuthClient := repository.NewClaudeOAuthClient() oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient) openAIOAuthClient := repository.NewOpenAIOAuthClient() @@ -143,6 +142,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig) rpmCache := repository.NewRPMCache(redisClient) + groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache) + groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService) accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator) adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService) dataManagementService := service.NewDataManagementService() diff --git a/backend/internal/handler/admin/admin_basic_handlers_test.go b/backend/internal/handler/admin/admin_basic_handlers_test.go index 4de10d3e..cba3ae21 100644 --- a/backend/internal/handler/admin/admin_basic_handlers_test.go +++ b/backend/internal/handler/admin/admin_basic_handlers_test.go @@ -17,7 +17,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) { adminSvc := newStubAdminService() userHandler := NewUserHandler(adminSvc, nil) - groupHandler := NewGroupHandler(adminSvc) + groupHandler := NewGroupHandler(adminSvc, nil, nil) proxyHandler := NewProxyHandler(adminSvc) redeemHandler := NewRedeemHandler(adminSvc, nil) diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index 4ffe64ee..459fd949 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -9,6 +9,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -16,7 +17,9 @@ import ( // GroupHandler handles admin group management type GroupHandler struct { - adminService service.AdminService + adminService service.AdminService + dashboardService *service.DashboardService + groupCapacityService *service.GroupCapacityService } type optionalLimitField struct { @@ -69,9 +72,11 @@ func (f optionalLimitField) ToServiceInput() *float64 { } // NewGroupHandler creates a new admin group handler -func NewGroupHandler(adminService service.AdminService) *GroupHandler { +func NewGroupHandler(adminService service.AdminService, dashboardService *service.DashboardService, groupCapacityService *service.GroupCapacityService) *GroupHandler { return &GroupHandler{ - adminService: adminService, + adminService: adminService, + dashboardService: dashboardService, + groupCapacityService: groupCapacityService, } } @@ -363,6 +368,33 @@ func (h *GroupHandler) GetStats(c *gin.Context) { _ = groupID // TODO: implement actual stats } +// GetUsageSummary returns today's and cumulative cost for all groups. +// GET /api/v1/admin/groups/usage-summary?timezone=Asia/Shanghai +func (h *GroupHandler) GetUsageSummary(c *gin.Context) { + userTZ := c.Query("timezone") + now := timezone.NowInUserLocation(userTZ) + todayStart := timezone.StartOfDayInUserLocation(now, userTZ) + + results, err := h.dashboardService.GetGroupUsageSummary(c.Request.Context(), todayStart) + if err != nil { + response.Error(c, 500, "Failed to get group usage summary") + return + } + + response.Success(c, results) +} + +// GetCapacitySummary returns aggregated capacity (concurrency/sessions/RPM) for all active groups. +// GET /api/v1/admin/groups/capacity-summary +func (h *GroupHandler) GetCapacitySummary(c *gin.Context) { + results, err := h.groupCapacityService.GetAllGroupCapacity(c.Request.Context()) + if err != nil { + response.Error(c, 500, "Failed to get group capacity summary") + return + } + response.Success(c, results) +} + // GetGroupAPIKeys handles getting API keys in a group // GET /api/v1/admin/groups/:id/api-keys func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) { diff --git a/backend/internal/handler/admin/subscription_handler.go b/backend/internal/handler/admin/subscription_handler.go index 342964b6..611666de 100644 --- a/backend/internal/handler/admin/subscription_handler.go +++ b/backend/internal/handler/admin/subscription_handler.go @@ -77,12 +77,13 @@ func (h *SubscriptionHandler) List(c *gin.Context) { } } status := c.Query("status") + platform := c.Query("platform") // Parse sorting parameters sortBy := c.DefaultQuery("sort_by", "created_at") sortOrder := c.DefaultQuery("sort_order", "desc") - subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, sortBy, sortOrder) + subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, platform, sortBy, sortOrder) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index cc25f7c3..d1d867ee 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -135,14 +135,16 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup { return nil } out := &AdminGroup{ - Group: groupFromServiceBase(g), - ModelRouting: g.ModelRouting, - ModelRoutingEnabled: g.ModelRoutingEnabled, - MCPXMLInject: g.MCPXMLInject, - DefaultMappedModel: g.DefaultMappedModel, - SupportedModelScopes: g.SupportedModelScopes, - AccountCount: g.AccountCount, - SortOrder: g.SortOrder, + Group: groupFromServiceBase(g), + ModelRouting: g.ModelRouting, + ModelRoutingEnabled: g.ModelRoutingEnabled, + MCPXMLInject: g.MCPXMLInject, + DefaultMappedModel: g.DefaultMappedModel, + SupportedModelScopes: g.SupportedModelScopes, + AccountCount: g.AccountCount, + ActiveAccountCount: g.ActiveAccountCount, + RateLimitedAccountCount: g.RateLimitedAccountCount, + SortOrder: g.SortOrder, } if len(g.AccountGroups) > 0 { out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups)) diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index fa360804..7b3443be 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -122,9 +122,11 @@ type AdminGroup struct { DefaultMappedModel string `json:"default_mapped_model"` // 支持的模型系列(仅 antigravity 平台使用) - SupportedModelScopes []string `json:"supported_model_scopes"` - AccountGroups []AccountGroup `json:"account_groups,omitempty"` - AccountCount int64 `json:"account_count,omitempty"` + SupportedModelScopes []string `json:"supported_model_scopes"` + AccountGroups []AccountGroup `json:"account_groups,omitempty"` + AccountCount int64 `json:"account_count,omitempty"` + ActiveAccountCount int64 `json:"active_account_count,omitempty"` + RateLimitedAccountCount int64 `json:"rate_limited_account_count,omitempty"` // 分组排序 SortOrder int `json:"sort_order"` diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go index 6bcc0003..b9dbe0ce 100644 --- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go +++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go @@ -76,7 +76,7 @@ func (f *fakeGroupRepo) ListActiveByPlatform(context.Context, string) ([]service return nil, nil } func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil } -func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, error) { return 0, nil } +func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, int64, error) { return 0, 0, nil } func (f *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { return 0, nil } diff --git a/backend/internal/handler/gateway_helper_hotpath_test.go b/backend/internal/handler/gateway_helper_hotpath_test.go index 9e904107..4a677199 100644 --- a/backend/internal/handler/gateway_helper_hotpath_test.go +++ b/backend/internal/handler/gateway_helper_hotpath_test.go @@ -136,7 +136,7 @@ func validClaudeCodeBodyJSON() []byte { return []byte(`{ "model":"claude-3-5-sonnet-20241022", "system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}], - "metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"} + "metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"} }`) } @@ -190,7 +190,7 @@ func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing System: []any{ map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."}, }, - MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123", + MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", } // body 非法 JSON,如果函数复用 parsedReq 成功则仍应判定为 Claude Code。 @@ -209,7 +209,7 @@ func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing "system": []any{ map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."}, }, - "metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"}, + "metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"}, }) SetClaudeCodeClientContext(c, []byte(`{invalid`), nil) diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go index 06b09437..5c631132 100644 --- a/backend/internal/handler/sora_gateway_handler_test.go +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -273,8 +273,8 @@ func (r *stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform strin func (r *stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) { return false, nil } -func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { - return 0, nil +func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) { + return 0, 0, nil } func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { return 0, nil @@ -348,6 +348,9 @@ func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTi func (s *stubUsageLogRepo) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) { return nil, nil } +func (s *stubUsageLogRepo) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) { + return nil, nil +} func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) { return nil, nil } diff --git a/backend/internal/pkg/usagestats/usage_log_types.go b/backend/internal/pkg/usagestats/usage_log_types.go index de3ad378..44cddb6a 100644 --- a/backend/internal/pkg/usagestats/usage_log_types.go +++ b/backend/internal/pkg/usagestats/usage_log_types.go @@ -112,6 +112,13 @@ type EndpointStat struct { ActualCost float64 `json:"actual_cost"` // 实际扣除 } +// GroupUsageSummary represents today's and cumulative cost for a single group. +type GroupUsageSummary struct { + GroupID int64 `json:"group_id"` + TodayCost float64 `json:"today_cost"` + TotalCost float64 `json:"total_cost"` +} + // GroupStat represents usage statistics for a single group type GroupStat struct { GroupID int64 `json:"group_id"` diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index c195f1f1..674c655b 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -88,8 +88,9 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group if err != nil { return nil, err } - count, _ := r.GetAccountCount(ctx, out.ID) - out.AccountCount = count + total, active, _ := r.GetAccountCount(ctx, out.ID) + out.AccountCount = total + out.ActiveAccountCount = active return out, nil } @@ -256,7 +257,10 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination counts, err := r.loadAccountCounts(ctx, groupIDs) if err == nil { for i := range outGroups { - outGroups[i].AccountCount = counts[outGroups[i].ID] + c := counts[outGroups[i].ID] + outGroups[i].AccountCount = c.Total + outGroups[i].ActiveAccountCount = c.Active + outGroups[i].RateLimitedAccountCount = c.RateLimited } } @@ -283,7 +287,10 @@ func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, erro counts, err := r.loadAccountCounts(ctx, groupIDs) if err == nil { for i := range outGroups { - outGroups[i].AccountCount = counts[outGroups[i].ID] + c := counts[outGroups[i].ID] + outGroups[i].AccountCount = c.Total + outGroups[i].ActiveAccountCount = c.Active + outGroups[i].RateLimitedAccountCount = c.RateLimited } } @@ -310,7 +317,10 @@ func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform str counts, err := r.loadAccountCounts(ctx, groupIDs) if err == nil { for i := range outGroups { - outGroups[i].AccountCount = counts[outGroups[i].ID] + c := counts[outGroups[i].ID] + outGroups[i].AccountCount = c.Total + outGroups[i].ActiveAccountCount = c.Active + outGroups[i].RateLimitedAccountCount = c.RateLimited } } @@ -369,12 +379,20 @@ func (r *groupRepository) ExistsByIDs(ctx context.Context, ids []int64) (map[int return result, nil } -func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { - var count int64 - if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", []any{groupID}, &count); err != nil { - return 0, err - } - return count, nil +func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (total int64, active int64, err error) { + var rateLimited int64 + err = scanSingleRow(ctx, r.sql, + `SELECT COUNT(*), + COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true), + COUNT(*) FILTER (WHERE a.status = 'active' AND ( + a.rate_limit_reset_at > NOW() OR + a.overload_until > NOW() OR + a.temp_unschedulable_until > NOW() + )) + FROM account_groups ag JOIN accounts a ON a.id = ag.account_id + WHERE ag.group_id = $1`, + []any{groupID}, &total, &active, &rateLimited) + return } func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { @@ -500,15 +518,32 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, return affectedUserIDs, nil } -func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]int64, err error) { - counts = make(map[int64]int64, len(groupIDs)) +type groupAccountCounts struct { + Total int64 + Active int64 + RateLimited int64 +} + +func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]groupAccountCounts, err error) { + counts = make(map[int64]groupAccountCounts, len(groupIDs)) if len(groupIDs) == 0 { return counts, nil } rows, err := r.sql.QueryContext( ctx, - "SELECT group_id, COUNT(*) FROM account_groups WHERE group_id = ANY($1) GROUP BY group_id", + `SELECT ag.group_id, + COUNT(*) AS total, + COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true) AS active, + COUNT(*) FILTER (WHERE a.status = 'active' AND ( + a.rate_limit_reset_at > NOW() OR + a.overload_until > NOW() OR + a.temp_unschedulable_until > NOW() + )) AS rate_limited + FROM account_groups ag + JOIN accounts a ON a.id = ag.account_id + WHERE ag.group_id = ANY($1) + GROUP BY ag.group_id`, pq.Array(groupIDs), ) if err != nil { @@ -523,11 +558,11 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6 for rows.Next() { var groupID int64 - var count int64 - if err = rows.Scan(&groupID, &count); err != nil { + var c groupAccountCounts + if err = rows.Scan(&groupID, &c.Total, &c.Active, &c.RateLimited); err != nil { return nil, err } - counts[groupID] = count + counts[groupID] = c } if err = rows.Err(); err != nil { return nil, err diff --git a/backend/internal/repository/group_repo_integration_test.go b/backend/internal/repository/group_repo_integration_test.go index 4a849a46..eccf5cea 100644 --- a/backend/internal/repository/group_repo_integration_test.go +++ b/backend/internal/repository/group_repo_integration_test.go @@ -603,7 +603,7 @@ func (s *GroupRepoSuite) TestGetAccountCount() { _, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a2, group.ID, 2) s.Require().NoError(err) - count, err := s.repo.GetAccountCount(s.ctx, group.ID) + count, _, err := s.repo.GetAccountCount(s.ctx, group.ID) s.Require().NoError(err, "GetAccountCount") s.Require().Equal(int64(2), count) } @@ -619,7 +619,7 @@ func (s *GroupRepoSuite) TestGetAccountCount_Empty() { } s.Require().NoError(s.repo.Create(s.ctx, group)) - count, err := s.repo.GetAccountCount(s.ctx, group.ID) + count, _, err := s.repo.GetAccountCount(s.ctx, group.ID) s.Require().NoError(err) s.Require().Zero(count) } @@ -651,7 +651,7 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() { s.Require().NoError(err, "DeleteAccountGroupsByGroupID") s.Require().Equal(int64(1), affected, "expected 1 affected row") - count, err := s.repo.GetAccountCount(s.ctx, g.ID) + count, _, err := s.repo.GetAccountCount(s.ctx, g.ID) s.Require().NoError(err, "GetAccountCount") s.Require().Equal(int64(0), count, "expected 0 account groups") } @@ -692,7 +692,7 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() { s.Require().NoError(err) s.Require().Equal(int64(3), affected) - count, _ := s.repo.GetAccountCount(s.ctx, g.ID) + count, _, _ := s.repo.GetAccountCount(s.ctx, g.ID) s.Require().Zero(count) } diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 61a54267..ca454606 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -3089,6 +3089,41 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim return results, nil } +// GetAllGroupUsageSummary returns today's and cumulative actual_cost for every group. +// todayStart is the start-of-day in the caller's timezone (UTC-based). +// TODO(perf): This query scans ALL usage_logs rows for total_cost aggregation. +// When usage_logs exceeds ~1M rows, consider adding a short-lived cache (30s) +// or a materialized view / pre-aggregation table for cumulative costs. +func (r *usageLogRepository) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) { + query := ` + SELECT + g.id AS group_id, + COALESCE(SUM(ul.actual_cost), 0) AS total_cost, + COALESCE(SUM(CASE WHEN ul.created_at >= $1 THEN ul.actual_cost ELSE 0 END), 0) AS today_cost + FROM groups g + LEFT JOIN usage_logs ul ON ul.group_id = g.id + GROUP BY g.id + ` + + rows, err := r.sql.QueryContext(ctx, query, todayStart) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + var results []usagestats.GroupUsageSummary + for rows.Next() { + var row usagestats.GroupUsageSummary + if err := rows.Scan(&row.GroupID, &row.TotalCost, &row.TodayCost); err != nil { + return nil, err + } + results = append(results, row) + } + if err := rows.Err(); err != nil { + return nil, err + } + return results, nil +} + // resolveModelDimensionExpression maps model source type to a safe SQL expression. func resolveModelDimensionExpression(modelType string) string { switch usagestats.NormalizeModelSource(modelType) { diff --git a/backend/internal/repository/user_subscription_repo.go b/backend/internal/repository/user_subscription_repo.go index 5a649846..e3f64a5f 100644 --- a/backend/internal/repository/user_subscription_repo.go +++ b/backend/internal/repository/user_subscription_repo.go @@ -5,6 +5,7 @@ import ( "time" dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" @@ -190,7 +191,7 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil } -func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { +func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { client := clientFromContext(ctx, r.client) q := client.UserSubscription.Query() if userID != nil { @@ -199,6 +200,9 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination if groupID != nil { q = q.Where(usersubscription.GroupIDEQ(*groupID)) } + if platform != "" { + q = q.Where(usersubscription.HasGroupWith(group.PlatformEQ(platform))) + } // Status filtering with real-time expiration check now := time.Now() diff --git a/backend/internal/repository/user_subscription_repo_integration_test.go b/backend/internal/repository/user_subscription_repo_integration_test.go index 60a5a378..a74860e3 100644 --- a/backend/internal/repository/user_subscription_repo_integration_test.go +++ b/backend/internal/repository/user_subscription_repo_integration_test.go @@ -271,7 +271,7 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() { group := s.mustCreateGroup("g-list") s.mustCreateSubscription(user.ID, group.ID, nil) - subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "") + subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "", "") s.Require().NoError(err, "List") s.Require().Len(subs, 1) s.Require().Equal(int64(1), page.Total) @@ -285,7 +285,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() { s.mustCreateSubscription(user1.ID, group.ID, nil) s.mustCreateSubscription(user2.ID, group.ID, nil) - subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "") + subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "", "") s.Require().NoError(err) s.Require().Len(subs, 1) s.Require().Equal(user1.ID, subs[0].UserID) @@ -299,7 +299,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() { s.mustCreateSubscription(user.ID, g1.ID, nil) s.mustCreateSubscription(user.ID, g2.ID, nil) - subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "") + subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "", "") s.Require().NoError(err) s.Require().Len(subs, 1) s.Require().Equal(g1.ID, subs[0].GroupID) @@ -320,7 +320,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() { c.SetExpiresAt(time.Now().Add(-24 * time.Hour)) }) - subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "") + subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "", "") s.Require().NoError(err) s.Require().Len(subs, 1) s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status) diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 309dcf4e..4ae5c272 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -924,8 +924,8 @@ func (stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error return false, errors.New("not implemented") } -func (stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { - return 0, errors.New("not implemented") +func (stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) { + return 0, 0, errors.New("not implemented") } func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { @@ -1289,7 +1289,7 @@ func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userI func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } -func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { +func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } func (stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) { @@ -1786,6 +1786,9 @@ func (r *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID i func (r *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) { return nil, errors.New("not implemented") } +func (r *stubUsageLogRepo) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) { + return nil, errors.New("not implemented") +} type stubSettingRepo struct { all map[string]string diff --git a/backend/internal/server/middleware/api_key_auth_google_test.go b/backend/internal/server/middleware/api_key_auth_google_test.go index 49db5f19..9f9bba13 100644 --- a/backend/internal/server/middleware/api_key_auth_google_test.go +++ b/backend/internal/server/middleware/api_key_auth_google_test.go @@ -135,7 +135,7 @@ func (f fakeGoogleSubscriptionRepo) ListActiveByUserID(ctx context.Context, user func (f fakeGoogleSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } -func (f fakeGoogleSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { +func (f fakeGoogleSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } func (f fakeGoogleSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) { diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go index 22befa2a..a633ffdd 100644 --- a/backend/internal/server/middleware/api_key_auth_test.go +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -646,7 +646,7 @@ func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID in return nil, nil, errors.New("not implemented") } -func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { +func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 67d7cb45..89faf6dc 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -227,6 +227,8 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) { { groups.GET("", h.Admin.Group.List) groups.GET("/all", h.Admin.Group.GetAll) + groups.GET("/usage-summary", h.Admin.Group.GetUsageSummary) + groups.GET("/capacity-summary", h.Admin.Group.GetCapacitySummary) groups.PUT("/sort-order", h.Admin.Group.UpdateSortOrder) groups.GET("/:id", h.Admin.Group.GetByID) groups.POST("", h.Admin.Group.Create) diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 482d22b1..d30b670d 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -113,15 +113,18 @@ func (s *AccountTestService) validateUpstreamBaseURL(raw string) (string, error) return normalized, nil } -// generateSessionString generates a Claude Code style session string +// generateSessionString generates a Claude Code style session string. +// The output format is determined by the UA version in claude.DefaultHeaders, +// ensuring consistency between the user_id format and the UA sent to upstream. func generateSessionString() (string, error) { - bytes := make([]byte, 32) - if _, err := rand.Read(bytes); err != nil { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { return "", err } - hex64 := hex.EncodeToString(bytes) + hex64 := hex.EncodeToString(b) sessionUUID := uuid.New().String() - return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID), nil + uaVersion := ExtractCLIVersion(claude.DefaultHeaders["User-Agent"]) + return FormatMetadataUserID(hex64, "", sessionUUID, uaVersion), nil } // createTestPayload creates a Claude Code style test request payload diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index 4a05c64a..74142700 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -49,6 +49,7 @@ type UsageLogRepository interface { GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) + GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go index 88d2f492..7588c16d 100644 --- a/backend/internal/service/admin_service_apikey_test.go +++ b/backend/internal/service/admin_service_apikey_test.go @@ -194,7 +194,7 @@ func (s *groupRepoStubForGroupUpdate) ListActiveByPlatform(context.Context, stri func (s *groupRepoStubForGroupUpdate) ExistsByName(context.Context, string) (bool, error) { panic("unexpected") } -func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, error) { +func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, int64, error) { panic("unexpected") } func (s *groupRepoStubForGroupUpdate) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go index 2e0f7d90..662b4771 100644 --- a/backend/internal/service/admin_service_delete_test.go +++ b/backend/internal/service/admin_service_delete_test.go @@ -160,7 +160,7 @@ func (s *groupRepoStub) ExistsByName(ctx context.Context, name string) (bool, er panic("unexpected ExistsByName call") } -func (s *groupRepoStub) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { +func (s *groupRepoStub) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) { panic("unexpected GetAccountCount call") } diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go index ef77a980..536be0b5 100644 --- a/backend/internal/service/admin_service_group_test.go +++ b/backend/internal/service/admin_service_group_test.go @@ -100,7 +100,7 @@ func (s *groupRepoStubForAdmin) ExistsByName(_ context.Context, _ string) (bool, panic("unexpected ExistsByName call") } -func (s *groupRepoStubForAdmin) GetAccountCount(_ context.Context, _ int64) (int64, error) { +func (s *groupRepoStubForAdmin) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) { panic("unexpected GetAccountCount call") } @@ -383,7 +383,7 @@ func (s *groupRepoStubForFallbackCycle) ExistsByName(_ context.Context, _ string panic("unexpected ExistsByName call") } -func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, error) { +func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) { panic("unexpected GetAccountCount call") } @@ -458,7 +458,7 @@ func (s *groupRepoStubForInvalidRequestFallback) ExistsByName(_ context.Context, panic("unexpected ExistsByName call") } -func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, error) { +func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) { panic("unexpected GetAccountCount call") } diff --git a/backend/internal/service/claude_code_validator.go b/backend/internal/service/claude_code_validator.go index f71098b1..4e8ced67 100644 --- a/backend/internal/service/claude_code_validator.go +++ b/backend/internal/service/claude_code_validator.go @@ -21,9 +21,6 @@ var ( // 带捕获组的版本提取正则 claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-cli/(\d+\.\d+\.\d+)`) - // metadata.user_id 格式: user_{64位hex}_account__session_{uuid} - userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[\w-]+$`) - // System prompt 相似度阈值(默认 0.5,和 claude-relay-service 一致) systemPromptThreshold = 0.5 ) @@ -124,7 +121,7 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo return false } - if !userIDPattern.MatchString(userID) { + if ParseMetadataUserID(userID) == nil { return false } @@ -278,11 +275,7 @@ func SetClaudeCodeClient(ctx context.Context, isClaudeCode bool) context.Context // ExtractVersion 从 User-Agent 中提取 Claude Code 版本号 // 返回 "2.1.22" 形式的版本号,如果不匹配返回空字符串 func (v *ClaudeCodeValidator) ExtractVersion(ua string) string { - matches := claudeCodeUAVersionPattern.FindStringSubmatch(ua) - if len(matches) >= 2 { - return matches[1] - } - return "" + return ExtractCLIVersion(ua) } // SetClaudeCodeVersion 将 Claude Code 版本号设置到 context 中 diff --git a/backend/internal/service/dashboard_service.go b/backend/internal/service/dashboard_service.go index 1c960fdf..3e059e30 100644 --- a/backend/internal/service/dashboard_service.go +++ b/backend/internal/service/dashboard_service.go @@ -169,6 +169,15 @@ func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTi return stats, nil } +// GetGroupUsageSummary returns today's and cumulative cost for all groups. +func (s *DashboardService) GetGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) { + results, err := s.usageRepo.GetAllGroupUsageSummary(ctx, todayStart) + if err != nil { + return nil, fmt.Errorf("get group usage summary: %w", err) + } + return results, nil +} + func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usagestats.DashboardStats, bool, error) { data, err := s.cache.GetDashboardStats(ctx) if err != nil { diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index ea8fa784..718cd42a 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -278,8 +278,8 @@ func (m *mockGroupRepoForGateway) ListActiveByPlatform(ctx context.Context, plat func (m *mockGroupRepoForGateway) ExistsByName(ctx context.Context, name string) (bool, error) { return false, nil } -func (m *mockGroupRepoForGateway) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { - return 0, nil +func (m *mockGroupRepoForGateway) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) { + return 0, 0, nil } func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { return 0, nil diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 4544ec82..1f6df629 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -326,7 +326,6 @@ func isClaudeCodeCredentialScopeError(msg string) bool { // Some upstream APIs return non-standard "data:" without space (should be "data: "). var ( sseDataRe = regexp.MustCompile(`^data:\s*`) - sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`) claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`) // claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表 @@ -645,8 +644,8 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { // 1. 最高优先级:从 metadata.user_id 提取 session_xxx if parsed.MetadataUserID != "" { - if match := sessionIDRegex.FindStringSubmatch(parsed.MetadataUserID); len(match) > 1 { - return match[1] + if uid := ParseMetadataUserID(parsed.MetadataUserID); uid != nil && uid.SessionID != "" { + return uid.SessionID } } @@ -1027,13 +1026,13 @@ func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account sessionID = generateSessionUUID(seed) } - // Prefer the newer format that includes account_uuid (if present), - // otherwise fall back to the legacy Claude Code format. - accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid")) - if accountUUID != "" { - return fmt.Sprintf("user_%s_account_%s_session_%s", userID, accountUUID, sessionID) + // 根据指纹 UA 版本选择输出格式 + var uaVersion string + if fp != nil { + uaVersion = ExtractCLIVersion(fp.UserAgent) } - return fmt.Sprintf("user_%s_account__session_%s", userID, sessionID) + accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid")) + return FormatMetadataUserID(userID, accountUUID, sessionID, uaVersion) } // GenerateSessionUUID creates a deterministic UUID4 from a seed string. @@ -5567,7 +5566,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex // 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值 accountUUID := account.GetExtraString("account_uuid") if accountUUID != "" && fp.ClientID != "" { - if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 { + if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID, fp.UserAgent); err == nil && len(newBody) > 0 { body = newBody } } @@ -8197,7 +8196,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con if err == nil { accountUUID := account.GetExtraString("account_uuid") if accountUUID != "" && fp.ClientID != "" { - if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 { + if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID, fp.UserAgent); err == nil && len(newBody) > 0 { body = newBody } } diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index b0b804eb..a78c56e7 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -230,8 +230,8 @@ func (m *mockGroupRepoForGemini) ListActiveByPlatform(ctx context.Context, platf func (m *mockGroupRepoForGemini) ExistsByName(ctx context.Context, name string) (bool, error) { return false, nil } -func (m *mockGroupRepoForGemini) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { - return 0, nil +func (m *mockGroupRepoForGemini) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) { + return 0, 0, nil } func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { return 0, nil diff --git a/backend/internal/service/generate_session_hash_test.go b/backend/internal/service/generate_session_hash_test.go index 8aa358a5..f91fb4c9 100644 --- a/backend/internal/service/generate_session_hash_test.go +++ b/backend/internal/service/generate_session_hash_test.go @@ -24,7 +24,7 @@ func TestGenerateSessionHash_MetadataHasHighestPriority(t *testing.T) { svc := &GatewayService{} parsed := &ParsedRequest{ - MetadataUserID: "session_123e4567-e89b-12d3-a456-426614174000", + MetadataUserID: "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account__session_123e4567-e89b-12d3-a456-426614174000", System: "You are a helpful assistant.", HasSystem: true, Messages: []any{ @@ -196,7 +196,7 @@ func TestGenerateSessionHash_MetadataOverridesSessionContext(t *testing.T) { svc := &GatewayService{} parsed := &ParsedRequest{ - MetadataUserID: "session_123e4567-e89b-12d3-a456-426614174000", + MetadataUserID: "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account__session_123e4567-e89b-12d3-a456-426614174000", Messages: []any{ map[string]any{"role": "user", "content": "hello"}, }, @@ -212,6 +212,22 @@ func TestGenerateSessionHash_MetadataOverridesSessionContext(t *testing.T) { "metadata session_id should take priority over SessionContext") } +func TestGenerateSessionHash_MetadataJSON_HasHighestPriority(t *testing.T) { + svc := &GatewayService{} + + parsed := &ParsedRequest{ + MetadataUserID: `{"device_id":"a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2","account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`, + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + + hash := svc.GenerateSessionHash(parsed) + require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", hash, "JSON format metadata session_id should have highest priority") +} + func TestGenerateSessionHash_NilSessionContextBackwardCompatible(t *testing.T) { svc := &GatewayService{} diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index 537b5a3b..e17032e0 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -64,8 +64,10 @@ type Group struct { CreatedAt time.Time UpdatedAt time.Time - AccountGroups []AccountGroup - AccountCount int64 + AccountGroups []AccountGroup + AccountCount int64 + ActiveAccountCount int64 + RateLimitedAccountCount int64 } func (g *Group) IsActive() bool { diff --git a/backend/internal/service/group_capacity_service.go b/backend/internal/service/group_capacity_service.go new file mode 100644 index 00000000..459084dc --- /dev/null +++ b/backend/internal/service/group_capacity_service.go @@ -0,0 +1,131 @@ +package service + +import ( + "context" + "time" +) + +// GroupCapacitySummary holds aggregated capacity for a single group. +type GroupCapacitySummary struct { + GroupID int64 `json:"group_id"` + ConcurrencyUsed int `json:"concurrency_used"` + ConcurrencyMax int `json:"concurrency_max"` + SessionsUsed int `json:"sessions_used"` + SessionsMax int `json:"sessions_max"` + RPMUsed int `json:"rpm_used"` + RPMMax int `json:"rpm_max"` +} + +// GroupCapacityService aggregates per-group capacity from runtime data. +type GroupCapacityService struct { + accountRepo AccountRepository + groupRepo GroupRepository + concurrencyService *ConcurrencyService + sessionLimitCache SessionLimitCache + rpmCache RPMCache +} + +// NewGroupCapacityService creates a new GroupCapacityService. +func NewGroupCapacityService( + accountRepo AccountRepository, + groupRepo GroupRepository, + concurrencyService *ConcurrencyService, + sessionLimitCache SessionLimitCache, + rpmCache RPMCache, +) *GroupCapacityService { + return &GroupCapacityService{ + accountRepo: accountRepo, + groupRepo: groupRepo, + concurrencyService: concurrencyService, + sessionLimitCache: sessionLimitCache, + rpmCache: rpmCache, + } +} + +// GetAllGroupCapacity returns capacity summary for all active groups. +func (s *GroupCapacityService) GetAllGroupCapacity(ctx context.Context) ([]GroupCapacitySummary, error) { + groups, err := s.groupRepo.ListActive(ctx) + if err != nil { + return nil, err + } + + results := make([]GroupCapacitySummary, 0, len(groups)) + for i := range groups { + cap, err := s.getGroupCapacity(ctx, groups[i].ID) + if err != nil { + // Skip groups with errors, return partial results + continue + } + cap.GroupID = groups[i].ID + results = append(results, cap) + } + return results, nil +} + +func (s *GroupCapacityService) getGroupCapacity(ctx context.Context, groupID int64) (GroupCapacitySummary, error) { + accounts, err := s.accountRepo.ListSchedulableByGroupID(ctx, groupID) + if err != nil { + return GroupCapacitySummary{}, err + } + if len(accounts) == 0 { + return GroupCapacitySummary{}, nil + } + + // Collect account IDs and config values + accountIDs := make([]int64, 0, len(accounts)) + sessionTimeouts := make(map[int64]time.Duration) + var concurrencyMax, sessionsMax, rpmMax int + + for i := range accounts { + acc := &accounts[i] + accountIDs = append(accountIDs, acc.ID) + concurrencyMax += acc.Concurrency + + if ms := acc.GetMaxSessions(); ms > 0 { + sessionsMax += ms + timeout := time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute + if timeout <= 0 { + timeout = 5 * time.Minute + } + sessionTimeouts[acc.ID] = timeout + } + + if rpm := acc.GetBaseRPM(); rpm > 0 { + rpmMax += rpm + } + } + + // Batch query runtime data from Redis + concurrencyMap, _ := s.concurrencyService.GetAccountConcurrencyBatch(ctx, accountIDs) + + var sessionsMap map[int64]int + if sessionsMax > 0 && s.sessionLimitCache != nil { + sessionsMap, _ = s.sessionLimitCache.GetActiveSessionCountBatch(ctx, accountIDs, sessionTimeouts) + } + + var rpmMap map[int64]int + if rpmMax > 0 && s.rpmCache != nil { + rpmMap, _ = s.rpmCache.GetRPMBatch(ctx, accountIDs) + } + + // Aggregate + var concurrencyUsed, sessionsUsed, rpmUsed int + for _, id := range accountIDs { + concurrencyUsed += concurrencyMap[id] + if sessionsMap != nil { + sessionsUsed += sessionsMap[id] + } + if rpmMap != nil { + rpmUsed += rpmMap[id] + } + } + + return GroupCapacitySummary{ + ConcurrencyUsed: concurrencyUsed, + ConcurrencyMax: concurrencyMax, + SessionsUsed: sessionsUsed, + SessionsMax: sessionsMax, + RPMUsed: rpmUsed, + RPMMax: rpmMax, + }, nil +} diff --git a/backend/internal/service/group_service.go b/backend/internal/service/group_service.go index 22a67eda..87174e03 100644 --- a/backend/internal/service/group_service.go +++ b/backend/internal/service/group_service.go @@ -27,7 +27,7 @@ type GroupRepository interface { ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) ExistsByName(ctx context.Context, name string) (bool, error) - GetAccountCount(ctx context.Context, groupID int64) (int64, error) + GetAccountCount(ctx context.Context, groupID int64) (total int64, active int64, err error) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) // GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID(去重) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) @@ -202,7 +202,7 @@ func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]any, } // 获取账号数量 - accountCount, err := s.groupRepo.GetAccountCount(ctx, id) + accountCount, _, err := s.groupRepo.GetAccountCount(ctx, id) if err != nil { return nil, fmt.Errorf("get account count: %w", err) } diff --git a/backend/internal/service/identity_service.go b/backend/internal/service/identity_service.go index f6a94d15..8d464a8b 100644 --- a/backend/internal/service/identity_service.go +++ b/backend/internal/service/identity_service.go @@ -19,10 +19,6 @@ import ( // 预编译正则表达式(避免每次调用重新编译) var ( - // 匹配 user_id 格式: - // 旧格式: user_{64位hex}_account__session_{uuid} (account 后无 UUID) - // 新格式: user_{64位hex}_account_{uuid}_session_{uuid} (account 后有 UUID) - userIDRegex = regexp.MustCompile(`^user_[a-f0-9]{64}_account_([a-f0-9-]*)_session_([a-f0-9-]{36})$`) // 匹配 User-Agent 版本号: xxx/x.y.z userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`) ) @@ -209,12 +205,12 @@ func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) { } // RewriteUserID 重写body中的metadata.user_id -// 输入格式:user_{clientId}_account__session_{sessionUUID} -// 输出格式:user_{cachedClientID}_account_{accountUUID}_session_{newHash} +// 支持旧拼接格式和新 JSON 格式的 user_id 解析, +// 根据 fingerprintUA 版本选择输出格式。 // // 重要:此函数使用 json.RawMessage 保留其他字段的原始字节, // 避免重新序列化导致 thinking 块等内容被修改。 -func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID string) ([]byte, error) { +func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID, fingerprintUA string) ([]byte, error) { if len(body) == 0 || accountUUID == "" || cachedClientID == "" { return body, nil } @@ -241,24 +237,21 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI return body, nil } - // 匹配格式: - // 旧格式: user_{64位hex}_account__session_{uuid} - // 新格式: user_{64位hex}_account_{uuid}_session_{uuid} - matches := userIDRegex.FindStringSubmatch(userID) - if matches == nil { + // 解析 user_id(兼容旧拼接格式和新 JSON 格式) + parsed := ParseMetadataUserID(userID) + if parsed == nil { return body, nil } - // matches[1] = account UUID (可能为空), matches[2] = session UUID - sessionTail := matches[2] // 原始session UUID + sessionTail := parsed.SessionID // 原始session UUID // 生成新的session hash: SHA256(accountID::sessionTail) -> UUID格式 seed := fmt.Sprintf("%d::%s", accountID, sessionTail) newSessionHash := generateUUIDFromSeed(seed) - // 构建新的user_id - // 格式: user_{cachedClientID}_account_{account_uuid}_session_{newSessionHash} - newUserID := fmt.Sprintf("user_%s_account_%s_session_%s", cachedClientID, accountUUID, newSessionHash) + // 根据客户端版本选择输出格式 + version := ExtractCLIVersion(fingerprintUA) + newUserID := FormatMetadataUserID(cachedClientID, accountUUID, newSessionHash, version) metadata["user_id"] = newUserID @@ -278,9 +271,9 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI // // 重要:此函数使用 json.RawMessage 保留其他字段的原始字节, // 避免重新序列化导致 thinking 块等内容被修改。 -func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID string) ([]byte, error) { +func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID, fingerprintUA string) ([]byte, error) { // 先执行常规的 RewriteUserID 逻辑 - newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID) + newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID, fingerprintUA) if err != nil { return newBody, err } @@ -312,10 +305,9 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b return newBody, nil } - // 查找 _session_ 的位置,替换其后的内容 - const sessionMarker = "_session_" - idx := strings.LastIndex(userID, sessionMarker) - if idx == -1 { + // 解析已重写的 user_id + uidParsed := ParseMetadataUserID(userID) + if uidParsed == nil { return newBody, nil } @@ -337,8 +329,9 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b logger.LegacyPrintf("service.identity", "Warning: failed to set masked session ID for account %d: %v", account.ID, err) } - // 替换 session 部分:保留 _session_ 之前的内容,替换之后的内容 - newUserID := userID[:idx+len(sessionMarker)] + maskedSessionID + // 用 FormatMetadataUserID 重建(保持与 RewriteUserID 相同的格式) + version := ExtractCLIVersion(fingerprintUA) + newUserID := FormatMetadataUserID(uidParsed.DeviceID, uidParsed.AccountUUID, maskedSessionID, version) slog.Debug("session_id_masking_applied", "account_id", account.ID, diff --git a/backend/internal/service/metadata_userid.go b/backend/internal/service/metadata_userid.go new file mode 100644 index 00000000..ee1ef64a --- /dev/null +++ b/backend/internal/service/metadata_userid.go @@ -0,0 +1,104 @@ +package service + +import ( + "encoding/json" + "regexp" + "strings" +) + +// NewMetadataFormatMinVersion is the minimum Claude Code version that uses +// JSON-formatted metadata.user_id instead of the legacy concatenated string. +const NewMetadataFormatMinVersion = "2.1.78" + +// ParsedUserID represents the components extracted from a metadata.user_id value. +type ParsedUserID struct { + DeviceID string // 64-char hex (or arbitrary client id) + AccountUUID string // may be empty + SessionID string // UUID + IsNewFormat bool // true if the original was JSON format +} + +// legacyUserIDRegex matches the legacy user_id format: +// +// user_{64hex}_account_{optional_uuid}_session_{uuid} +var legacyUserIDRegex = regexp.MustCompile(`^user_([a-fA-F0-9]{64})_account_([a-fA-F0-9-]*)_session_([a-fA-F0-9-]{36})$`) + +// jsonUserID is the JSON structure for the new metadata.user_id format. +type jsonUserID struct { + DeviceID string `json:"device_id"` + AccountUUID string `json:"account_uuid"` + SessionID string `json:"session_id"` +} + +// ParseMetadataUserID parses a metadata.user_id string in either format. +// Returns nil if the input cannot be parsed. +func ParseMetadataUserID(raw string) *ParsedUserID { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil + } + + // Try JSON format first (starts with '{') + if raw[0] == '{' { + var j jsonUserID + if err := json.Unmarshal([]byte(raw), &j); err != nil { + return nil + } + if j.DeviceID == "" || j.SessionID == "" { + return nil + } + return &ParsedUserID{ + DeviceID: j.DeviceID, + AccountUUID: j.AccountUUID, + SessionID: j.SessionID, + IsNewFormat: true, + } + } + + // Try legacy format + matches := legacyUserIDRegex.FindStringSubmatch(raw) + if matches == nil { + return nil + } + return &ParsedUserID{ + DeviceID: matches[1], + AccountUUID: matches[2], + SessionID: matches[3], + IsNewFormat: false, + } +} + +// FormatMetadataUserID builds a metadata.user_id string in the format +// appropriate for the given CLI version. Components are the rewritten values +// (not necessarily the originals). +func FormatMetadataUserID(deviceID, accountUUID, sessionID, uaVersion string) string { + if IsNewMetadataFormatVersion(uaVersion) { + b, _ := json.Marshal(jsonUserID{ + DeviceID: deviceID, + AccountUUID: accountUUID, + SessionID: sessionID, + }) + return string(b) + } + // Legacy format + return "user_" + deviceID + "_account_" + accountUUID + "_session_" + sessionID +} + +// IsNewMetadataFormatVersion returns true if the given CLI version uses the +// new JSON metadata.user_id format (>= 2.1.78). +func IsNewMetadataFormatVersion(version string) bool { + if version == "" { + return false + } + return CompareVersions(version, NewMetadataFormatMinVersion) >= 0 +} + +// ExtractCLIVersion extracts the Claude Code version from a User-Agent string. +// Returns "" if the UA doesn't match the expected pattern. +func ExtractCLIVersion(ua string) string { + matches := claudeCodeUAVersionPattern.FindStringSubmatch(ua) + if len(matches) >= 2 { + return matches[1] + } + return "" +} diff --git a/backend/internal/service/metadata_userid_test.go b/backend/internal/service/metadata_userid_test.go new file mode 100644 index 00000000..40ad7087 --- /dev/null +++ b/backend/internal/service/metadata_userid_test.go @@ -0,0 +1,183 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// ============ ParseMetadataUserID Tests ============ + +func TestParseMetadataUserID_LegacyFormat_WithoutAccountUUID(t *testing.T) { + raw := "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account__session_123e4567-e89b-12d3-a456-426614174000" + parsed := ParseMetadataUserID(raw) + require.NotNil(t, parsed) + require.Equal(t, "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2", parsed.DeviceID) + require.Equal(t, "", parsed.AccountUUID) + require.Equal(t, "123e4567-e89b-12d3-a456-426614174000", parsed.SessionID) + require.False(t, parsed.IsNewFormat) +} + +func TestParseMetadataUserID_LegacyFormat_WithAccountUUID(t *testing.T) { + raw := "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account_550e8400-e29b-41d4-a716-446655440000_session_123e4567-e89b-12d3-a456-426614174000" + parsed := ParseMetadataUserID(raw) + require.NotNil(t, parsed) + require.Equal(t, "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2", parsed.DeviceID) + require.Equal(t, "550e8400-e29b-41d4-a716-446655440000", parsed.AccountUUID) + require.Equal(t, "123e4567-e89b-12d3-a456-426614174000", parsed.SessionID) + require.False(t, parsed.IsNewFormat) +} + +func TestParseMetadataUserID_JSONFormat_WithoutAccountUUID(t *testing.T) { + raw := `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}` + parsed := ParseMetadataUserID(raw) + require.NotNil(t, parsed) + require.Equal(t, "d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677", parsed.DeviceID) + require.Equal(t, "", parsed.AccountUUID) + require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", parsed.SessionID) + require.True(t, parsed.IsNewFormat) +} + +func TestParseMetadataUserID_JSONFormat_WithAccountUUID(t *testing.T) { + raw := `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":"550e8400-e29b-41d4-a716-446655440000","session_id":"c72554f2-1234-5678-abcd-123456789abc"}` + parsed := ParseMetadataUserID(raw) + require.NotNil(t, parsed) + require.Equal(t, "d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677", parsed.DeviceID) + require.Equal(t, "550e8400-e29b-41d4-a716-446655440000", parsed.AccountUUID) + require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", parsed.SessionID) + require.True(t, parsed.IsNewFormat) +} + +func TestParseMetadataUserID_InvalidInputs(t *testing.T) { + tests := []struct { + name string + raw string + }{ + {"empty string", ""}, + {"whitespace only", " "}, + {"random text", "not-a-valid-user-id"}, + {"partial legacy format", "session_123e4567-e89b-12d3-a456-426614174000"}, + {"invalid JSON", `{"device_id":}`}, + {"JSON missing device_id", `{"account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`}, + {"JSON missing session_id", `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":""}`}, + {"JSON empty device_id", `{"device_id":"","account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`}, + {"JSON empty session_id", `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":"","session_id":""}`}, + {"legacy format short hex", "user_a1b2c3d4_account__session_123e4567-e89b-12d3-a456-426614174000"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Nil(t, ParseMetadataUserID(tt.raw), "should return nil for: %s", tt.raw) + }) + } +} + +func TestParseMetadataUserID_HexCaseInsensitive(t *testing.T) { + // Legacy format should accept both upper and lower case hex + rawUpper := "user_A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2_account__session_123e4567-e89b-12d3-a456-426614174000" + parsed := ParseMetadataUserID(rawUpper) + require.NotNil(t, parsed, "legacy format should accept uppercase hex") + require.Equal(t, "A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2", parsed.DeviceID) +} + +// ============ FormatMetadataUserID Tests ============ + +func TestFormatMetadataUserID_LegacyVersion(t *testing.T) { + result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "acc-uuid", "sess-uuid", "2.1.77") + require.Equal(t, "user_deadbeef00112233445566778899aabbccddeeff0011223344556677_account_acc-uuid_session_sess-uuid", result) +} + +func TestFormatMetadataUserID_NewVersion(t *testing.T) { + result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "acc-uuid", "sess-uuid", "2.1.78") + require.Equal(t, `{"device_id":"deadbeef00112233445566778899aabbccddeeff0011223344556677","account_uuid":"acc-uuid","session_id":"sess-uuid"}`, result) +} + +func TestFormatMetadataUserID_EmptyVersion_Legacy(t *testing.T) { + result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "", "sess-uuid", "") + require.Equal(t, "user_deadbeef00112233445566778899aabbccddeeff0011223344556677_account__session_sess-uuid", result) +} + +func TestFormatMetadataUserID_EmptyAccountUUID(t *testing.T) { + // Legacy format with empty account UUID → double underscore + result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "", "sess-uuid", "2.1.22") + require.Contains(t, result, "_account__session_") + + // New format with empty account UUID → empty string in JSON + result = FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "", "sess-uuid", "2.1.78") + require.Contains(t, result, `"account_uuid":""`) +} + +// ============ IsNewMetadataFormatVersion Tests ============ + +func TestIsNewMetadataFormatVersion(t *testing.T) { + tests := []struct { + version string + want bool + }{ + {"", false}, + {"2.1.77", false}, + {"2.1.78", true}, + {"2.1.79", true}, + {"2.2.0", true}, + {"3.0.0", true}, + {"2.0.100", false}, + {"1.9.99", false}, + } + for _, tt := range tests { + t.Run(tt.version, func(t *testing.T) { + require.Equal(t, tt.want, IsNewMetadataFormatVersion(tt.version)) + }) + } +} + +// ============ Round-trip Tests ============ + +func TestParseFormat_RoundTrip_Legacy(t *testing.T) { + deviceID := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" + accountUUID := "550e8400-e29b-41d4-a716-446655440000" + sessionID := "123e4567-e89b-12d3-a456-426614174000" + + formatted := FormatMetadataUserID(deviceID, accountUUID, sessionID, "2.1.22") + parsed := ParseMetadataUserID(formatted) + require.NotNil(t, parsed) + require.Equal(t, deviceID, parsed.DeviceID) + require.Equal(t, accountUUID, parsed.AccountUUID) + require.Equal(t, sessionID, parsed.SessionID) + require.False(t, parsed.IsNewFormat) +} + +func TestParseFormat_RoundTrip_JSON(t *testing.T) { + deviceID := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" + accountUUID := "550e8400-e29b-41d4-a716-446655440000" + sessionID := "123e4567-e89b-12d3-a456-426614174000" + + formatted := FormatMetadataUserID(deviceID, accountUUID, sessionID, "2.1.78") + parsed := ParseMetadataUserID(formatted) + require.NotNil(t, parsed) + require.Equal(t, deviceID, parsed.DeviceID) + require.Equal(t, accountUUID, parsed.AccountUUID) + require.Equal(t, sessionID, parsed.SessionID) + require.True(t, parsed.IsNewFormat) +} + +func TestParseFormat_RoundTrip_EmptyAccountUUID(t *testing.T) { + deviceID := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" + sessionID := "123e4567-e89b-12d3-a456-426614174000" + + // Legacy round-trip with empty account UUID + formatted := FormatMetadataUserID(deviceID, "", sessionID, "2.1.22") + parsed := ParseMetadataUserID(formatted) + require.NotNil(t, parsed) + require.Equal(t, deviceID, parsed.DeviceID) + require.Equal(t, "", parsed.AccountUUID) + require.Equal(t, sessionID, parsed.SessionID) + + // JSON round-trip with empty account UUID + formatted = FormatMetadataUserID(deviceID, "", sessionID, "2.1.78") + parsed = ParseMetadataUserID(formatted) + require.NotNil(t, parsed) + require.Equal(t, deviceID, parsed.DeviceID) + require.Equal(t, "", parsed.AccountUUID) + require.Equal(t, sessionID, parsed.SessionID) +} diff --git a/backend/internal/service/sora_quota_service_test.go b/backend/internal/service/sora_quota_service_test.go index 040e427d..da8efe77 100644 --- a/backend/internal/service/sora_quota_service_test.go +++ b/backend/internal/service/sora_quota_service_test.go @@ -52,8 +52,8 @@ func (r *stubGroupRepoForQuota) ListActiveByPlatform(context.Context, string) ([ func (r *stubGroupRepoForQuota) ExistsByName(context.Context, string) (bool, error) { return false, nil } -func (r *stubGroupRepoForQuota) GetAccountCount(context.Context, int64) (int64, error) { - return 0, nil +func (r *stubGroupRepoForQuota) GetAccountCount(context.Context, int64) (int64, int64, error) { + return 0, 0, nil } func (r *stubGroupRepoForQuota) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { return 0, nil diff --git a/backend/internal/service/subscription_assign_idempotency_test.go b/backend/internal/service/subscription_assign_idempotency_test.go index 0defafba..40bab206 100644 --- a/backend/internal/service/subscription_assign_idempotency_test.go +++ b/backend/internal/service/subscription_assign_idempotency_test.go @@ -40,7 +40,7 @@ func (groupRepoNoop) ListActiveByPlatform(context.Context, string) ([]Group, err func (groupRepoNoop) ExistsByName(context.Context, string) (bool, error) { panic("unexpected ExistsByName call") } -func (groupRepoNoop) GetAccountCount(context.Context, int64) (int64, error) { +func (groupRepoNoop) GetAccountCount(context.Context, int64) (int64, int64, error) { panic("unexpected GetAccountCount call") } func (groupRepoNoop) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { @@ -92,7 +92,7 @@ func (userSubRepoNoop) ListActiveByUserID(context.Context, int64) ([]UserSubscri func (userSubRepoNoop) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error) { panic("unexpected ListByGroupID call") } -func (userSubRepoNoop) List(context.Context, pagination.PaginationParams, *int64, *int64, string, string, string) ([]UserSubscription, *pagination.PaginationResult, error) { +func (userSubRepoNoop) List(context.Context, pagination.PaginationParams, *int64, *int64, string, string, string, string) ([]UserSubscription, *pagination.PaginationResult, error) { panic("unexpected List call") } func (userSubRepoNoop) ExistsByUserIDAndGroupID(context.Context, int64, int64) (bool, error) { diff --git a/backend/internal/service/subscription_service.go b/backend/internal/service/subscription_service.go index af548509..f0a5540e 100644 --- a/backend/internal/service/subscription_service.go +++ b/backend/internal/service/subscription_service.go @@ -634,9 +634,9 @@ func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupI } // List 获取所有订阅(分页,支持筛选和排序) -func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error) { +func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} - subs, pag, err := s.userSubRepo.List(ctx, params, userID, groupID, status, sortBy, sortOrder) + subs, pag, err := s.userSubRepo.List(ctx, params, userID, groupID, status, platform, sortBy, sortOrder) if err != nil { return nil, nil, err } diff --git a/backend/internal/service/user_subscription_port.go b/backend/internal/service/user_subscription_port.go index 2dfc8d02..4484fae8 100644 --- a/backend/internal/service/user_subscription_port.go +++ b/backend/internal/service/user_subscription_port.go @@ -18,7 +18,7 @@ type UserSubscriptionRepository interface { ListByUserID(ctx context.Context, userID int64) ([]UserSubscription, error) ListActiveByUserID(ctx context.Context, userID int64) ([]UserSubscription, error) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error) - List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error) + List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 7da72630..a4c667be 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -486,4 +486,5 @@ var ProviderSet = wire.NewSet( ProvideIdempotencyCleanupService, ProvideScheduledTestService, ProvideScheduledTestRunnerService, + NewGroupCapacityService, ) diff --git a/frontend/src/api/admin/groups.ts b/frontend/src/api/admin/groups.ts index 7c2658fa..5885dc6a 100644 --- a/frontend/src/api/admin/groups.ts +++ b/frontend/src/api/admin/groups.ts @@ -218,6 +218,34 @@ export async function batchSetGroupRateMultipliers( return data } +/** + * Get usage summary (today + cumulative cost) for all groups + * @param timezone - IANA timezone string (e.g. "Asia/Shanghai") + * @returns Array of group usage summaries + */ +export async function getUsageSummary( + timezone?: string +): Promise<{ group_id: number; today_cost: number; total_cost: number }[]> { + const { data } = await apiClient.get< + { group_id: number; today_cost: number; total_cost: number }[] + >('/admin/groups/usage-summary', { + params: timezone ? { timezone } : undefined + }) + return data +} + +/** + * Get capacity summary (concurrency/sessions/RPM) for all active groups + */ +export async function getCapacitySummary(): Promise< + { group_id: number; concurrency_used: number; concurrency_max: number; sessions_used: number; sessions_max: number; rpm_used: number; rpm_max: number }[] +> { + const { data } = await apiClient.get< + { group_id: number; concurrency_used: number; concurrency_max: number; sessions_used: number; sessions_max: number; rpm_used: number; rpm_max: number }[] + >('/admin/groups/capacity-summary') + return data +} + export const groupsAPI = { list, getAll, @@ -232,7 +260,9 @@ export const groupsAPI = { getGroupRateMultipliers, clearGroupRateMultipliers, batchSetGroupRateMultipliers, - updateSortOrder + updateSortOrder, + getUsageSummary, + getCapacitySummary } export default groupsAPI diff --git a/frontend/src/api/admin/subscriptions.ts b/frontend/src/api/admin/subscriptions.ts index 7557e3ad..611f67c2 100644 --- a/frontend/src/api/admin/subscriptions.ts +++ b/frontend/src/api/admin/subscriptions.ts @@ -27,6 +27,7 @@ export async function list( status?: 'active' | 'expired' | 'revoked' user_id?: number group_id?: number + platform?: string sort_by?: string sort_order?: 'asc' | 'desc' }, diff --git a/frontend/src/components/account/AccountUsageCell.vue b/frontend/src/components/account/AccountUsageCell.vue index e548be8c..131d82b2 100644 --- a/frontend/src/components/account/AccountUsageCell.vue +++ b/frontend/src/components/account/AccountUsageCell.vue @@ -82,6 +82,7 @@ :utilization="usageInfo.five_hour.utilization" :resets-at="usageInfo.five_hour.resets_at" :window-stats="usageInfo.five_hour.window_stats" + :show-now-when-idle="true" color="indigo" /> diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index c2f2f7d2..5f3da1b7 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -1980,271 +1980,281 @@ const normalizePoolModeRetryCount = (value: number) => { return normalized } -watch( - () => props.account, - (newAccount) => { - if (newAccount) { - antigravityMixedChannelConfirmed.value = false - showMixedChannelWarning.value = false - mixedChannelWarningDetails.value = null - mixedChannelWarningRawMessage.value = '' - mixedChannelWarningAction.value = null - form.name = newAccount.name - form.notes = newAccount.notes || '' - form.proxy_id = newAccount.proxy_id - form.concurrency = newAccount.concurrency - form.load_factor = newAccount.load_factor ?? null - form.priority = newAccount.priority - form.rate_multiplier = newAccount.rate_multiplier ?? 1 - form.status = (newAccount.status === 'active' || newAccount.status === 'inactive' || newAccount.status === 'error') - ? newAccount.status - : 'active' - form.group_ids = newAccount.group_ids || [] - form.expires_at = newAccount.expires_at ?? null +const syncFormFromAccount = (newAccount: Account | null) => { + if (!newAccount) { + return + } + antigravityMixedChannelConfirmed.value = false + showMixedChannelWarning.value = false + mixedChannelWarningDetails.value = null + mixedChannelWarningRawMessage.value = '' + mixedChannelWarningAction.value = null + form.name = newAccount.name + form.notes = newAccount.notes || '' + form.proxy_id = newAccount.proxy_id + form.concurrency = newAccount.concurrency + form.load_factor = newAccount.load_factor ?? null + form.priority = newAccount.priority + form.rate_multiplier = newAccount.rate_multiplier ?? 1 + form.status = (newAccount.status === 'active' || newAccount.status === 'inactive' || newAccount.status === 'error') + ? newAccount.status + : 'active' + form.group_ids = newAccount.group_ids || [] + form.expires_at = newAccount.expires_at ?? null - // Load intercept warmup requests setting (applies to all account types) - const credentials = newAccount.credentials as Record | undefined - interceptWarmupRequests.value = credentials?.intercept_warmup_requests === true - autoPauseOnExpired.value = newAccount.auto_pause_on_expired === true + // Load intercept warmup requests setting (applies to all account types) + const credentials = newAccount.credentials as Record | undefined + interceptWarmupRequests.value = credentials?.intercept_warmup_requests === true + autoPauseOnExpired.value = newAccount.auto_pause_on_expired === true - // Load mixed scheduling setting (only for antigravity accounts) - mixedScheduling.value = false - allowOverages.value = false - const extra = newAccount.extra as Record | undefined - mixedScheduling.value = extra?.mixed_scheduling === true - allowOverages.value = extra?.allow_overages === true + // Load mixed scheduling setting (only for antigravity accounts) + mixedScheduling.value = false + allowOverages.value = false + const extra = newAccount.extra as Record | undefined + mixedScheduling.value = extra?.mixed_scheduling === true + allowOverages.value = extra?.allow_overages === true - // Load OpenAI passthrough toggle (OpenAI OAuth/API Key) - openaiPassthroughEnabled.value = false - openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF - openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF - codexCLIOnlyEnabled.value = false - anthropicPassthroughEnabled.value = false - if (newAccount.platform === 'openai' && (newAccount.type === 'oauth' || newAccount.type === 'apikey')) { - openaiPassthroughEnabled.value = extra?.openai_passthrough === true || extra?.openai_oauth_passthrough === true - openaiOAuthResponsesWebSocketV2Mode.value = resolveOpenAIWSModeFromExtra(extra, { - modeKey: 'openai_oauth_responses_websockets_v2_mode', - enabledKey: 'openai_oauth_responses_websockets_v2_enabled', - fallbackEnabledKeys: ['responses_websockets_v2_enabled', 'openai_ws_enabled'], - defaultMode: OPENAI_WS_MODE_OFF - }) - openaiAPIKeyResponsesWebSocketV2Mode.value = resolveOpenAIWSModeFromExtra(extra, { - modeKey: 'openai_apikey_responses_websockets_v2_mode', - enabledKey: 'openai_apikey_responses_websockets_v2_enabled', - fallbackEnabledKeys: ['responses_websockets_v2_enabled', 'openai_ws_enabled'], - defaultMode: OPENAI_WS_MODE_OFF - }) - if (newAccount.type === 'oauth') { - codexCLIOnlyEnabled.value = extra?.codex_cli_only === true - } - } - if (newAccount.platform === 'anthropic' && newAccount.type === 'apikey') { - anthropicPassthroughEnabled.value = extra?.anthropic_passthrough === true - } + // Load OpenAI passthrough toggle (OpenAI OAuth/API Key) + openaiPassthroughEnabled.value = false + openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF + openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF + codexCLIOnlyEnabled.value = false + anthropicPassthroughEnabled.value = false + if (newAccount.platform === 'openai' && (newAccount.type === 'oauth' || newAccount.type === 'apikey')) { + openaiPassthroughEnabled.value = extra?.openai_passthrough === true || extra?.openai_oauth_passthrough === true + openaiOAuthResponsesWebSocketV2Mode.value = resolveOpenAIWSModeFromExtra(extra, { + modeKey: 'openai_oauth_responses_websockets_v2_mode', + enabledKey: 'openai_oauth_responses_websockets_v2_enabled', + fallbackEnabledKeys: ['responses_websockets_v2_enabled', 'openai_ws_enabled'], + defaultMode: OPENAI_WS_MODE_OFF + }) + openaiAPIKeyResponsesWebSocketV2Mode.value = resolveOpenAIWSModeFromExtra(extra, { + modeKey: 'openai_apikey_responses_websockets_v2_mode', + enabledKey: 'openai_apikey_responses_websockets_v2_enabled', + fallbackEnabledKeys: ['responses_websockets_v2_enabled', 'openai_ws_enabled'], + defaultMode: OPENAI_WS_MODE_OFF + }) + if (newAccount.type === 'oauth') { + codexCLIOnlyEnabled.value = extra?.codex_cli_only === true + } + } + if (newAccount.platform === 'anthropic' && newAccount.type === 'apikey') { + anthropicPassthroughEnabled.value = extra?.anthropic_passthrough === true + } - // Load quota limit for apikey/bedrock accounts (bedrock quota is also loaded in its own branch above) - if (newAccount.type === 'apikey' || newAccount.type === 'bedrock') { - const quotaVal = extra?.quota_limit as number | undefined - editQuotaLimit.value = (quotaVal && quotaVal > 0) ? quotaVal : null - const dailyVal = extra?.quota_daily_limit as number | undefined - editQuotaDailyLimit.value = (dailyVal && dailyVal > 0) ? dailyVal : null - const weeklyVal = extra?.quota_weekly_limit as number | undefined - editQuotaWeeklyLimit.value = (weeklyVal && weeklyVal > 0) ? weeklyVal : null - // Load quota reset mode config - editDailyResetMode.value = (extra?.quota_daily_reset_mode as 'rolling' | 'fixed') || null - editDailyResetHour.value = (extra?.quota_daily_reset_hour as number) ?? null - editWeeklyResetMode.value = (extra?.quota_weekly_reset_mode as 'rolling' | 'fixed') || null - editWeeklyResetDay.value = (extra?.quota_weekly_reset_day as number) ?? null - editWeeklyResetHour.value = (extra?.quota_weekly_reset_hour as number) ?? null - editResetTimezone.value = (extra?.quota_reset_timezone as string) || null + // Load quota limit for apikey/bedrock accounts (bedrock quota is also loaded in its own branch above) + if (newAccount.type === 'apikey' || newAccount.type === 'bedrock') { + const quotaVal = extra?.quota_limit as number | undefined + editQuotaLimit.value = (quotaVal && quotaVal > 0) ? quotaVal : null + const dailyVal = extra?.quota_daily_limit as number | undefined + editQuotaDailyLimit.value = (dailyVal && dailyVal > 0) ? dailyVal : null + const weeklyVal = extra?.quota_weekly_limit as number | undefined + editQuotaWeeklyLimit.value = (weeklyVal && weeklyVal > 0) ? weeklyVal : null + // Load quota reset mode config + editDailyResetMode.value = (extra?.quota_daily_reset_mode as 'rolling' | 'fixed') || null + editDailyResetHour.value = (extra?.quota_daily_reset_hour as number) ?? null + editWeeklyResetMode.value = (extra?.quota_weekly_reset_mode as 'rolling' | 'fixed') || null + editWeeklyResetDay.value = (extra?.quota_weekly_reset_day as number) ?? null + editWeeklyResetHour.value = (extra?.quota_weekly_reset_hour as number) ?? null + editResetTimezone.value = (extra?.quota_reset_timezone as string) || null + } else { + editQuotaLimit.value = null + editQuotaDailyLimit.value = null + editQuotaWeeklyLimit.value = null + editDailyResetMode.value = null + editDailyResetHour.value = null + editWeeklyResetMode.value = null + editWeeklyResetDay.value = null + editWeeklyResetHour.value = null + editResetTimezone.value = null + } + + // Load antigravity model mapping (Antigravity 只支持映射模式) + if (newAccount.platform === 'antigravity') { + const credentials = newAccount.credentials as Record | undefined + + // Antigravity 始终使用映射模式 + antigravityModelRestrictionMode.value = 'mapping' + antigravityWhitelistModels.value = [] + + // 从 model_mapping 读取映射配置 + const rawAgMapping = credentials?.model_mapping as Record | undefined + if (rawAgMapping && typeof rawAgMapping === 'object') { + const entries = Object.entries(rawAgMapping) + // 无论是白名单样式(key===value)还是真正的映射,都统一转换为映射列表 + antigravityModelMappings.value = entries.map(([from, to]) => ({ from, to })) + } else { + // 兼容旧数据:从 model_whitelist 读取,转换为映射格式 + const rawWhitelist = credentials?.model_whitelist + if (Array.isArray(rawWhitelist) && rawWhitelist.length > 0) { + antigravityModelMappings.value = rawWhitelist + .map((v) => String(v).trim()) + .filter((v) => v.length > 0) + .map((m) => ({ from: m, to: m })) } else { - editQuotaLimit.value = null - editQuotaDailyLimit.value = null - editQuotaWeeklyLimit.value = null - editDailyResetMode.value = null - editDailyResetHour.value = null - editWeeklyResetMode.value = null - editWeeklyResetDay.value = null - editWeeklyResetHour.value = null - editResetTimezone.value = null - } - - // Load antigravity model mapping (Antigravity 只支持映射模式) - if (newAccount.platform === 'antigravity') { - const credentials = newAccount.credentials as Record | undefined - - // Antigravity 始终使用映射模式 - antigravityModelRestrictionMode.value = 'mapping' - antigravityWhitelistModels.value = [] - - // 从 model_mapping 读取映射配置 - const rawAgMapping = credentials?.model_mapping as Record | undefined - if (rawAgMapping && typeof rawAgMapping === 'object') { - const entries = Object.entries(rawAgMapping) - // 无论是白名单样式(key===value)还是真正的映射,都统一转换为映射列表 - antigravityModelMappings.value = entries.map(([from, to]) => ({ from, to })) - } else { - // 兼容旧数据:从 model_whitelist 读取,转换为映射格式 - const rawWhitelist = credentials?.model_whitelist - if (Array.isArray(rawWhitelist) && rawWhitelist.length > 0) { - antigravityModelMappings.value = rawWhitelist - .map((v) => String(v).trim()) - .filter((v) => v.length > 0) - .map((m) => ({ from: m, to: m })) - } else { - antigravityModelMappings.value = [] - } - } - } else { - antigravityModelRestrictionMode.value = 'mapping' - antigravityWhitelistModels.value = [] antigravityModelMappings.value = [] } + } + } else { + antigravityModelRestrictionMode.value = 'mapping' + antigravityWhitelistModels.value = [] + antigravityModelMappings.value = [] + } - // Load quota control settings (Anthropic OAuth/SetupToken only) - loadQuotaControlSettings(newAccount) + // Load quota control settings (Anthropic OAuth/SetupToken only) + loadQuotaControlSettings(newAccount) - loadTempUnschedRules(credentials) + loadTempUnschedRules(credentials) - // Initialize API Key fields for apikey type - if (newAccount.type === 'apikey' && newAccount.credentials) { - const credentials = newAccount.credentials as Record - const platformDefaultUrl = - newAccount.platform === 'openai' || newAccount.platform === 'sora' - ? 'https://api.openai.com' - : newAccount.platform === 'gemini' - ? 'https://generativelanguage.googleapis.com' - : 'https://api.anthropic.com' - editBaseUrl.value = (credentials.base_url as string) || platformDefaultUrl + // Initialize API Key fields for apikey type + if (newAccount.type === 'apikey' && newAccount.credentials) { + const credentials = newAccount.credentials as Record + const platformDefaultUrl = + newAccount.platform === 'openai' || newAccount.platform === 'sora' + ? 'https://api.openai.com' + : newAccount.platform === 'gemini' + ? 'https://generativelanguage.googleapis.com' + : 'https://api.anthropic.com' + editBaseUrl.value = (credentials.base_url as string) || platformDefaultUrl - // Load model mappings and detect mode - const existingMappings = credentials.model_mapping as Record | undefined - if (existingMappings && typeof existingMappings === 'object') { - const entries = Object.entries(existingMappings) + // Load model mappings and detect mode + const existingMappings = credentials.model_mapping as Record | undefined + if (existingMappings && typeof existingMappings === 'object') { + const entries = Object.entries(existingMappings) - // Detect if this is whitelist mode (all from === to) or mapping mode - const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to) + // Detect if this is whitelist mode (all from === to) or mapping mode + const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to) - if (isWhitelistMode) { - // Whitelist mode: populate allowedModels - modelRestrictionMode.value = 'whitelist' - allowedModels.value = entries.map(([from]) => from) - modelMappings.value = [] - } else { - // Mapping mode: populate modelMappings - modelRestrictionMode.value = 'mapping' - modelMappings.value = entries.map(([from, to]) => ({ from, to })) - allowedModels.value = [] - } - } else { - // No mappings: default to whitelist mode with empty selection (allow all) - modelRestrictionMode.value = 'whitelist' - modelMappings.value = [] - allowedModels.value = [] - } - - // Load pool mode - poolModeEnabled.value = credentials.pool_mode === true - poolModeRetryCount.value = normalizePoolModeRetryCount( - Number(credentials.pool_mode_retry_count ?? DEFAULT_POOL_MODE_RETRY_COUNT) - ) - - // Load custom error codes - customErrorCodesEnabled.value = credentials.custom_error_codes_enabled === true - const existingErrorCodes = credentials.custom_error_codes as number[] | undefined - if (existingErrorCodes && Array.isArray(existingErrorCodes)) { - selectedErrorCodes.value = [...existingErrorCodes] - } else { - selectedErrorCodes.value = [] - } - } else if (newAccount.type === 'bedrock' && newAccount.credentials) { - const bedrockCreds = newAccount.credentials as Record - const authMode = (bedrockCreds.auth_mode as string) || 'sigv4' - editBedrockRegion.value = (bedrockCreds.aws_region as string) || '' - editBedrockForceGlobal.value = (bedrockCreds.aws_force_global as string) === 'true' - - if (authMode === 'apikey') { - editBedrockApiKeyValue.value = '' - } else { - editBedrockAccessKeyId.value = (bedrockCreds.aws_access_key_id as string) || '' - editBedrockSecretAccessKey.value = '' - editBedrockSessionToken.value = '' - } - - // Load pool mode for bedrock - poolModeEnabled.value = bedrockCreds.pool_mode === true - const retryCount = bedrockCreds.pool_mode_retry_count - poolModeRetryCount.value = (typeof retryCount === 'number' && retryCount >= 0) ? retryCount : DEFAULT_POOL_MODE_RETRY_COUNT - - // Load quota limits for bedrock - const bedrockExtra = (newAccount.extra as Record) || {} - editQuotaLimit.value = typeof bedrockExtra.quota_limit === 'number' ? bedrockExtra.quota_limit : null - editQuotaDailyLimit.value = typeof bedrockExtra.quota_daily_limit === 'number' ? bedrockExtra.quota_daily_limit : null - editQuotaWeeklyLimit.value = typeof bedrockExtra.quota_weekly_limit === 'number' ? bedrockExtra.quota_weekly_limit : null - - // Load model mappings for bedrock - const existingMappings = bedrockCreds.model_mapping as Record | undefined - if (existingMappings && typeof existingMappings === 'object') { - const entries = Object.entries(existingMappings) - const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to) - if (isWhitelistMode) { - modelRestrictionMode.value = 'whitelist' - allowedModels.value = entries.map(([from]) => from) - modelMappings.value = [] - } else { - modelRestrictionMode.value = 'mapping' - modelMappings.value = entries.map(([from, to]) => ({ from, to })) - allowedModels.value = [] - } - } else { - modelRestrictionMode.value = 'whitelist' - modelMappings.value = [] - allowedModels.value = [] - } - } else if (newAccount.type === 'upstream' && newAccount.credentials) { - const credentials = newAccount.credentials as Record - editBaseUrl.value = (credentials.base_url as string) || '' + if (isWhitelistMode) { + // Whitelist mode: populate allowedModels + modelRestrictionMode.value = 'whitelist' + allowedModels.value = entries.map(([from]) => from) + modelMappings.value = [] } else { - const platformDefaultUrl = - newAccount.platform === 'openai' || newAccount.platform === 'sora' - ? 'https://api.openai.com' - : newAccount.platform === 'gemini' - ? 'https://generativelanguage.googleapis.com' - : 'https://api.anthropic.com' - editBaseUrl.value = platformDefaultUrl + // Mapping mode: populate modelMappings + modelRestrictionMode.value = 'mapping' + modelMappings.value = entries.map(([from, to]) => ({ from, to })) + allowedModels.value = [] + } + } else { + // No mappings: default to whitelist mode with empty selection (allow all) + modelRestrictionMode.value = 'whitelist' + modelMappings.value = [] + allowedModels.value = [] + } - // Load model mappings for OpenAI OAuth accounts - if (newAccount.platform === 'openai' && newAccount.credentials) { - const oauthCredentials = newAccount.credentials as Record - const existingMappings = oauthCredentials.model_mapping as Record | undefined - if (existingMappings && typeof existingMappings === 'object') { - const entries = Object.entries(existingMappings) - const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to) - if (isWhitelistMode) { - modelRestrictionMode.value = 'whitelist' - allowedModels.value = entries.map(([from]) => from) - modelMappings.value = [] - } else { - modelRestrictionMode.value = 'mapping' - modelMappings.value = entries.map(([from, to]) => ({ from, to })) - allowedModels.value = [] - } - } else { - modelRestrictionMode.value = 'whitelist' - modelMappings.value = [] - allowedModels.value = [] - } - } else { + // Load pool mode + poolModeEnabled.value = credentials.pool_mode === true + poolModeRetryCount.value = normalizePoolModeRetryCount( + Number(credentials.pool_mode_retry_count ?? DEFAULT_POOL_MODE_RETRY_COUNT) + ) + + // Load custom error codes + customErrorCodesEnabled.value = credentials.custom_error_codes_enabled === true + const existingErrorCodes = credentials.custom_error_codes as number[] | undefined + if (existingErrorCodes && Array.isArray(existingErrorCodes)) { + selectedErrorCodes.value = [...existingErrorCodes] + } else { + selectedErrorCodes.value = [] + } + } else if (newAccount.type === 'bedrock' && newAccount.credentials) { + const bedrockCreds = newAccount.credentials as Record + const authMode = (bedrockCreds.auth_mode as string) || 'sigv4' + editBedrockRegion.value = (bedrockCreds.aws_region as string) || '' + editBedrockForceGlobal.value = (bedrockCreds.aws_force_global as string) === 'true' + + if (authMode === 'apikey') { + editBedrockApiKeyValue.value = '' + } else { + editBedrockAccessKeyId.value = (bedrockCreds.aws_access_key_id as string) || '' + editBedrockSecretAccessKey.value = '' + editBedrockSessionToken.value = '' + } + + // Load pool mode for bedrock + poolModeEnabled.value = bedrockCreds.pool_mode === true + const retryCount = bedrockCreds.pool_mode_retry_count + poolModeRetryCount.value = (typeof retryCount === 'number' && retryCount >= 0) ? retryCount : DEFAULT_POOL_MODE_RETRY_COUNT + + // Load quota limits for bedrock + const bedrockExtra = (newAccount.extra as Record) || {} + editQuotaLimit.value = typeof bedrockExtra.quota_limit === 'number' ? bedrockExtra.quota_limit : null + editQuotaDailyLimit.value = typeof bedrockExtra.quota_daily_limit === 'number' ? bedrockExtra.quota_daily_limit : null + editQuotaWeeklyLimit.value = typeof bedrockExtra.quota_weekly_limit === 'number' ? bedrockExtra.quota_weekly_limit : null + + // Load model mappings for bedrock + const existingMappings = bedrockCreds.model_mapping as Record | undefined + if (existingMappings && typeof existingMappings === 'object') { + const entries = Object.entries(existingMappings) + const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to) + if (isWhitelistMode) { + modelRestrictionMode.value = 'whitelist' + allowedModels.value = entries.map(([from]) => from) + modelMappings.value = [] + } else { + modelRestrictionMode.value = 'mapping' + modelMappings.value = entries.map(([from, to]) => ({ from, to })) + allowedModels.value = [] + } + } else { + modelRestrictionMode.value = 'whitelist' + modelMappings.value = [] + allowedModels.value = [] + } + } else if (newAccount.type === 'upstream' && newAccount.credentials) { + const credentials = newAccount.credentials as Record + editBaseUrl.value = (credentials.base_url as string) || '' + } else { + const platformDefaultUrl = + newAccount.platform === 'openai' || newAccount.platform === 'sora' + ? 'https://api.openai.com' + : newAccount.platform === 'gemini' + ? 'https://generativelanguage.googleapis.com' + : 'https://api.anthropic.com' + editBaseUrl.value = platformDefaultUrl + + // Load model mappings for OpenAI OAuth accounts + if (newAccount.platform === 'openai' && newAccount.credentials) { + const oauthCredentials = newAccount.credentials as Record + const existingMappings = oauthCredentials.model_mapping as Record | undefined + if (existingMappings && typeof existingMappings === 'object') { + const entries = Object.entries(existingMappings) + const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to) + if (isWhitelistMode) { modelRestrictionMode.value = 'whitelist' + allowedModels.value = entries.map(([from]) => from) modelMappings.value = [] + } else { + modelRestrictionMode.value = 'mapping' + modelMappings.value = entries.map(([from, to]) => ({ from, to })) allowedModels.value = [] } - poolModeEnabled.value = false - poolModeRetryCount.value = DEFAULT_POOL_MODE_RETRY_COUNT - customErrorCodesEnabled.value = false - selectedErrorCodes.value = [] + } else { + modelRestrictionMode.value = 'whitelist' + modelMappings.value = [] + allowedModels.value = [] } - editApiKey.value = '' + } else { + modelRestrictionMode.value = 'whitelist' + modelMappings.value = [] + allowedModels.value = [] + } + poolModeEnabled.value = false + poolModeRetryCount.value = DEFAULT_POOL_MODE_RETRY_COUNT + customErrorCodesEnabled.value = false + selectedErrorCodes.value = [] + } + editApiKey.value = '' +} + +watch( + [() => props.show, () => props.account], + ([show, newAccount], [wasShow, previousAccount]) => { + if (!show || !newAccount) { + return + } + if (!wasShow || newAccount !== previousAccount) { + syncFormFromAccount(newAccount) } }, { immediate: true } diff --git a/frontend/src/components/account/UsageProgressBar.vue b/frontend/src/components/account/UsageProgressBar.vue index 506071fa..52f0ecbb 100644 --- a/frontend/src/components/account/UsageProgressBar.vue +++ b/frontend/src/components/account/UsageProgressBar.vue @@ -48,7 +48,7 @@ - + {{ formatResetTime }} @@ -68,6 +68,7 @@ const props = defineProps<{ resetsAt?: string | null color: 'indigo' | 'emerald' | 'purple' | 'amber' windowStats?: WindowStats | null + showNowWhenIdle?: boolean }>() const { t } = useI18n() @@ -139,9 +140,20 @@ const displayPercent = computed(() => { return percent > 999 ? '>999%' : `${percent}%` }) +const shouldShowResetTime = computed(() => { + if (props.resetsAt) return true + return Boolean(props.showNowWhenIdle && props.utilization <= 0) +}) + // Format reset time const formatResetTime = computed(() => { + // For rolling windows, when utilization is 0%, treat as immediately available. + if (props.showNowWhenIdle && props.utilization <= 0) { + return '现在' + } + if (!props.resetsAt) return '-' + const date = new Date(props.resetsAt) const diffMs = date.getTime() - now.value.getTime() diff --git a/frontend/src/components/account/__tests__/EditAccountModal.spec.ts b/frontend/src/components/account/__tests__/EditAccountModal.spec.ts new file mode 100644 index 00000000..e3260168 --- /dev/null +++ b/frontend/src/components/account/__tests__/EditAccountModal.spec.ts @@ -0,0 +1,159 @@ +import { describe, expect, it, vi } from 'vitest' +import { defineComponent } from 'vue' +import { mount } from '@vue/test-utils' + +const { updateAccountMock, checkMixedChannelRiskMock } = vi.hoisted(() => ({ + updateAccountMock: vi.fn(), + checkMixedChannelRiskMock: vi.fn() +})) + +vi.mock('@/stores/app', () => ({ + useAppStore: () => ({ + showError: vi.fn(), + showSuccess: vi.fn(), + showInfo: vi.fn() + }) +})) + +vi.mock('@/stores/auth', () => ({ + useAuthStore: () => ({ + isSimpleMode: true + }) +})) + +vi.mock('@/api/admin', () => ({ + adminAPI: { + accounts: { + update: updateAccountMock, + checkMixedChannelRisk: checkMixedChannelRiskMock + } + } +})) + +vi.mock('@/api/admin/accounts', () => ({ + getAntigravityDefaultModelMapping: vi.fn() +})) + +vi.mock('vue-i18n', async () => { + const actual = await vi.importActual('vue-i18n') + return { + ...actual, + useI18n: () => ({ + t: (key: string) => key + }) + } +}) + +import EditAccountModal from '../EditAccountModal.vue' + +const BaseDialogStub = defineComponent({ + name: 'BaseDialog', + props: { + show: { + type: Boolean, + default: false + } + }, + template: '
' +}) + +const ModelWhitelistSelectorStub = defineComponent({ + name: 'ModelWhitelistSelector', + props: { + modelValue: { + type: Array, + default: () => [] + } + }, + emits: ['update:modelValue'], + template: ` +
+ + + {{ Array.isArray(modelValue) ? modelValue.join(',') : '' }} + +
+ ` +}) + +function buildAccount() { + return { + id: 1, + name: 'OpenAI Key', + notes: '', + platform: 'openai', + type: 'apikey', + credentials: { + api_key: 'sk-test', + base_url: 'https://api.openai.com', + model_mapping: { + 'gpt-5.2': 'gpt-5.2' + } + }, + extra: {}, + proxy_id: null, + concurrency: 1, + priority: 1, + rate_multiplier: 1, + status: 'active', + group_ids: [], + expires_at: null, + auto_pause_on_expired: false + } as any +} + +function mountModal(account = buildAccount()) { + return mount(EditAccountModal, { + props: { + show: true, + account, + proxies: [], + groups: [] + }, + global: { + stubs: { + BaseDialog: BaseDialogStub, + Select: true, + Icon: true, + ProxySelector: true, + GroupSelector: true, + ModelWhitelistSelector: ModelWhitelistSelectorStub + } + } + }) +} + +describe('EditAccountModal', () => { + it('reopening the same account rehydrates the OpenAI whitelist from props', async () => { + const account = buildAccount() + updateAccountMock.mockReset() + checkMixedChannelRiskMock.mockReset() + checkMixedChannelRiskMock.mockResolvedValue({ has_risk: false }) + updateAccountMock.mockResolvedValue(account) + + const wrapper = mountModal(account) + + expect(wrapper.get('[data-testid="model-whitelist-value"]').text()).toBe('gpt-5.2') + + await wrapper.get('[data-testid="rewrite-to-snapshot"]').trigger('click') + expect(wrapper.get('[data-testid="model-whitelist-value"]').text()).toBe('gpt-5.2-2025-12-11') + + await wrapper.setProps({ show: false }) + await wrapper.setProps({ show: true }) + + expect(wrapper.get('[data-testid="model-whitelist-value"]').text()).toBe('gpt-5.2') + + await wrapper.get('form#edit-account-form').trigger('submit.prevent') + + expect(updateAccountMock).toHaveBeenCalledTimes(1) + expect(updateAccountMock.mock.calls[0]?.[1]?.credentials?.model_mapping).toEqual({ + 'gpt-5.2': 'gpt-5.2' + }) + }) +}) diff --git a/frontend/src/components/account/__tests__/UsageProgressBar.spec.ts b/frontend/src/components/account/__tests__/UsageProgressBar.spec.ts new file mode 100644 index 00000000..9def052c --- /dev/null +++ b/frontend/src/components/account/__tests__/UsageProgressBar.spec.ts @@ -0,0 +1,69 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { mount } from '@vue/test-utils' +import UsageProgressBar from '../UsageProgressBar.vue' + +vi.mock('vue-i18n', async () => { + const actual = await vi.importActual('vue-i18n') + return { + ...actual, + useI18n: () => ({ + t: (key: string) => key + }) + } +}) + +describe('UsageProgressBar', () => { + beforeEach(() => { + vi.useFakeTimers() + vi.setSystemTime(new Date('2026-03-17T00:00:00Z')) + }) + + afterEach(() => { + vi.useRealTimers() + }) + + it('showNowWhenIdle=true 且利用率为 0 时显示“现在”', () => { + const wrapper = mount(UsageProgressBar, { + props: { + label: '5h', + utilization: 0, + resetsAt: '2026-03-17T02:30:00Z', + showNowWhenIdle: true, + color: 'indigo' + } + }) + + expect(wrapper.text()).toContain('现在') + expect(wrapper.text()).not.toContain('2h 30m') + }) + + it('showNowWhenIdle=true 但利用率大于 0 时显示倒计时', () => { + const wrapper = mount(UsageProgressBar, { + props: { + label: '7d', + utilization: 12, + resetsAt: '2026-03-17T02:30:00Z', + showNowWhenIdle: true, + color: 'emerald' + } + }) + + expect(wrapper.text()).toContain('2h 30m') + expect(wrapper.text()).not.toContain('现在') + }) + + it('showNowWhenIdle=false 时保持原有倒计时行为', () => { + const wrapper = mount(UsageProgressBar, { + props: { + label: '1d', + utilization: 0, + resetsAt: '2026-03-17T02:30:00Z', + showNowWhenIdle: false, + color: 'indigo' + } + }) + + expect(wrapper.text()).toContain('2h 30m') + expect(wrapper.text()).not.toContain('现在') + }) +}) diff --git a/frontend/src/components/common/GroupCapacityBadge.vue b/frontend/src/components/common/GroupCapacityBadge.vue new file mode 100644 index 00000000..a8580b54 --- /dev/null +++ b/frontend/src/components/common/GroupCapacityBadge.vue @@ -0,0 +1,84 @@ + + + diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 95d903a0..13a6c1b1 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -1508,6 +1508,8 @@ export default { rateMultiplier: 'Rate Multiplier', type: 'Type', accounts: 'Accounts', + capacity: 'Capacity', + usage: 'Usage', status: 'Status', actions: 'Actions', billingType: 'Billing Type', @@ -1516,6 +1518,12 @@ export default { userNotes: 'Notes', userStatus: 'Status' }, + usageToday: 'Today', + usageTotal: 'Total', + accountsAvailable: 'Avail:', + accountsRateLimited: 'Limited:', + accountsTotal: 'Total:', + accountsUnit: '', rateAndAccounts: '{rate}x rate · {count} accounts', accountsCount: '{count} accounts', form: { @@ -1697,6 +1705,7 @@ export default { revokeSubscription: 'Revoke Subscription', allStatus: 'All Status', allGroups: 'All Groups', + allPlatforms: 'All Platforms', daily: 'Daily', weekly: 'Weekly', monthly: 'Monthly', @@ -1762,7 +1771,37 @@ export default { pleaseSelectGroup: 'Please select a group', validityDaysRequired: 'Please enter a valid number of days (at least 1)', revokeConfirm: - "Are you sure you want to revoke the subscription for '{user}'? This action cannot be undone." + "Are you sure you want to revoke the subscription for '{user}'? This action cannot be undone.", + guide: { + title: 'Subscription Management Guide', + subtitle: 'Subscription mode lets you assign time-based usage quotas to users, with daily/weekly/monthly limits. Follow these steps to get started.', + showGuide: 'Usage Guide', + step1: { + title: 'Create a Subscription Group', + line1: 'Go to "Group Management" page, click "Create Group"', + line2: 'Set billing type to "Subscription", configure daily/weekly/monthly quota limits', + line3: 'Save the group and ensure its status is "Active"', + link: 'Go to Group Management' + }, + step2: { + title: 'Assign Subscription to User', + line1: 'Click the "Assign Subscription" button in the top right', + line2: 'Search for a user by email and select them', + line3: 'Choose a subscription group, set validity days, then click "Assign"' + }, + step3: { + title: 'Manage Existing Subscriptions' + }, + actions: { + adjust: 'Adjust', + adjustDesc: 'Extend or shorten the subscription validity period', + resetQuota: 'Reset Quota', + resetQuotaDesc: 'Reset daily/weekly/monthly usage to zero', + revoke: 'Revoke', + revokeDesc: 'Immediately terminate the subscription (irreversible)' + }, + tip: 'Tip: Only groups with billing type "Subscription" and status "Active" appear in the group dropdown. If no options are available, create one in Group Management first.' + } }, // Accounts diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 4a1cd058..3cfc7953 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -1564,6 +1564,8 @@ export default { priority: '优先级', apiKeys: 'API 密钥数', accounts: '账号数', + capacity: '容量', + usage: '用量', status: '状态', actions: '操作', billingType: '计费类型', @@ -1572,6 +1574,12 @@ export default { userNotes: '备注', userStatus: '状态' }, + usageToday: '今日', + usageTotal: '累计', + accountsAvailable: '可用:', + accountsRateLimited: '限流:', + accountsTotal: '总量:', + accountsUnit: '个账号', form: { name: '名称', description: '描述', @@ -1777,6 +1785,7 @@ export default { revokeSubscription: '撤销订阅', allStatus: '全部状态', allGroups: '全部分组', + allPlatforms: '全部平台', daily: '每日', weekly: '每周', monthly: '每月', @@ -1841,7 +1850,37 @@ export default { pleaseSelectUser: '请选择用户', pleaseSelectGroup: '请选择分组', validityDaysRequired: '请输入有效的天数(至少1天)', - revokeConfirm: "确定要撤销 '{user}' 的订阅吗?此操作无法撤销。" + revokeConfirm: "确定要撤销 '{user}' 的订阅吗?此操作无法撤销。", + guide: { + title: '订阅管理教程', + subtitle: '订阅模式允许你按时间周期为用户分配使用额度,支持日/周/月配额限制。按照以下步骤即可完成配置。', + showGuide: '使用指南', + step1: { + title: '创建订阅分组', + line1: '前往「分组管理」页面,点击「创建分组」', + line2: '将计费类型设为「订阅」,配置日/周/月额度限制', + line3: '保存分组,确保状态为「正常」', + link: '前往分组管理' + }, + step2: { + title: '分配订阅给用户', + line1: '点击本页右上角「分配订阅」按钮', + line2: '在弹窗中搜索用户邮箱并选择目标用户', + line3: '选择订阅分组、设置有效期天数,点击「分配」' + }, + step3: { + title: '管理已有订阅' + }, + actions: { + adjust: '调整', + adjustDesc: '延长或缩短订阅有效期', + resetQuota: '重置配额', + resetQuotaDesc: '将日/周/月用量归零,重新开始计算', + revoke: '撤销', + revokeDesc: '立即终止该用户的订阅,不可恢复' + }, + tip: '提示:订阅分组下拉列表中只会显示计费类型为「订阅」且状态为「正常」的分组。如果没有可选项,请先到分组管理中创建。' + } }, // Accounts Management diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index b8dd695e..88d6e994 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -411,6 +411,8 @@ export interface AdminGroup extends Group { // 分组下账号数量(仅管理员可见) account_count?: number + active_account_count?: number + rate_limited_account_count?: number // OpenAI Messages 调度配置(仅 openai 平台使用) default_mapped_model?: string diff --git a/frontend/src/views/admin/GroupsView.vue b/frontend/src/views/admin/GroupsView.vue index f8ee39e9..ddd7e672 100644 --- a/frontend/src/views/admin/GroupsView.vue +++ b/frontend/src/views/admin/GroupsView.vue @@ -158,12 +158,51 @@
-