Files
sub2api/backend/internal/service/admin_service.go

1142 lines
34 KiB
Go
Raw Normal View History

2025-12-18 13:50:39 +08:00
package service
import (
"context"
"errors"
"fmt"
"log"
"strings"
2025-12-18 13:50:39 +08:00
"time"
2025-12-24 21:07:21 +08:00
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
2025-12-18 13:50:39 +08:00
)
// AdminService interface defines admin management operations
type AdminService interface {
// User management
ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error)
GetUser(ctx context.Context, id int64) (*User, error)
CreateUser(ctx context.Context, input *CreateUserInput) (*User, error)
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
2025-12-18 13:50:39 +08:00
DeleteUser(ctx context.Context, id int64) error
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error)
2025-12-20 16:19:40 +08:00
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
2025-12-18 13:50:39 +08:00
// Group management
ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error)
GetAllGroups(ctx context.Context) ([]Group, error)
GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error)
GetGroup(ctx context.Context, id int64) (*Group, error)
CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error)
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error)
2025-12-18 13:50:39 +08:00
DeleteGroup(ctx context.Context, id int64) error
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
2025-12-18 13:50:39 +08:00
// Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
GetAccount(ctx context.Context, id int64) (*Account, error)
GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error)
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error)
2025-12-18 13:50:39 +08:00
DeleteAccount(ctx context.Context, id int64) error
RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error)
ClearAccountError(ctx context.Context, id int64) (*Account, error)
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
2025-12-18 13:50:39 +08:00
// Proxy management
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error)
GetAllProxies(ctx context.Context) ([]Proxy, error)
GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error)
GetProxy(ctx context.Context, id int64) (*Proxy, error)
CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error)
UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error)
2025-12-18 13:50:39 +08:00
DeleteProxy(ctx context.Context, id int64) error
GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error)
2025-12-18 13:50:39 +08:00
CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error)
TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error)
// Redeem code management
ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error)
GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error)
2025-12-18 13:50:39 +08:00
DeleteRedeemCode(ctx context.Context, id int64) error
BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error)
ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
2025-12-18 13:50:39 +08:00
}
// CreateUserInput represents input for creating a new user via admin operations.
type CreateUserInput struct {
	Email string
	// Password is the plain-text password; CreateUser hands it to User.SetPassword.
	Password      string
	Username      string
	Notes         string
	Balance       float64
	Concurrency   int
	AllowedGroups []int64
}
// UpdateUserInput carries the fields an admin may change on an existing user.
//
// Plain string fields are treated as "not provided" when empty; pointer fields
// are treated as "not provided" when nil, which allows explicitly setting a
// zero value or an empty list.
//
// NOTE(review): UpdateUser, as visible in this file, does not read Balance;
// balance changes appear to go through UpdateUserBalance — confirm.
type UpdateUserInput struct {
	Email         string
	Password      string
	Username      *string
	Notes         *string
	Balance       *float64 // pointer distinguishes "not provided" from "set to 0"
	Concurrency   *int     // pointer distinguishes "not provided" from "set to 0"
	Status        string
	AllowedGroups *[]int64 // pointer distinguishes "not provided" from "set to an empty list"
}
// CreateGroupInput carries the fields used to create a new group.
//
// CreateGroup substitutes defaults for an empty Platform or SubscriptionType
// and treats a nil, zero or negative limit as "no limit".
type CreateGroupInput struct {
	Name           string
	Description    string
	Platform       string
	RateMultiplier float64
	IsExclusive    bool

	// SubscriptionType is either "standard" or "subscription".
	SubscriptionType string

	// Spending limits in USD.
	DailyLimitUSD   *float64 // daily limit (USD)
	WeeklyLimitUSD  *float64 // weekly limit (USD)
	MonthlyLimitUSD *float64 // monthly limit (USD)
}
// UpdateGroupInput carries the fields an admin may change on an existing group.
//
// Empty strings and nil pointers mean "leave unchanged".
type UpdateGroupInput struct {
	Name        string
	Description string
	Platform    string

	// RateMultiplier is a pointer so that it can be explicitly set to 0.
	RateMultiplier *float64
	IsExclusive    *bool
	Status         string

	// SubscriptionType is either "standard" or "subscription".
	SubscriptionType string

	// Spending limits in USD.
	DailyLimitUSD   *float64 // daily limit (USD)
	WeeklyLimitUSD  *float64 // weekly limit (USD)
	MonthlyLimitUSD *float64 // monthly limit (USD)
}
// CreateAccountInput carries the fields used to create a new account.
type CreateAccountInput struct {
	Name        string
	Platform    string
	Type        string
	Credentials map[string]any
	Extra       map[string]any
	ProxyID     *int64
	Concurrency int
	Priority    int
	// GroupIDs lists the groups to bind. When empty, CreateAccount looks for
	// an active group named "<Platform>-default" and binds that instead.
	GroupIDs []int64
	// SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
	// This should only be set when the caller has explicitly confirmed the risk.
	SkipMixedChannelCheck bool
}
// UpdateAccountInput carries the fields an admin may change on an existing
// account. Empty strings, empty maps and nil pointers mean "leave unchanged".
type UpdateAccountInput struct {
	Name        string
	Type        string // Account type: oauth, setup-token, apikey
	Credentials map[string]any
	Extra       map[string]any
	ProxyID     *int64
	Concurrency *int // pointer distinguishes "not provided" from "set to 0"
	Priority    *int // pointer distinguishes "not provided" from "set to 0"
	Status      string
	GroupIDs    *[]int64
	// SkipMixedChannelCheck skips the mixed channel check (the user has
	// already confirmed the risk).
	SkipMixedChannelCheck bool
}
// BulkUpdateAccountsInput is the request payload for updating several accounts
// at once. Every account listed in AccountIDs receives the same changes; empty
// strings and nil pointers leave the corresponding field untouched.
type BulkUpdateAccountsInput struct {
	// AccountIDs selects the accounts to update.
	AccountIDs []int64

	Name        string
	ProxyID     *int64
	Concurrency *int
	Priority    *int
	Status      string
	GroupIDs    *[]int64
	Credentials map[string]any
	Extra       map[string]any

	// SkipMixedChannelCheck disables the mixed channel risk check performed
	// when binding groups. Set it only after the caller has explicitly
	// confirmed the risk.
	SkipMixedChannelCheck bool
}
// BulkUpdateAccountResult reports the outcome of a bulk update for one account.
type BulkUpdateAccountResult struct {
	// AccountID identifies the account this entry refers to.
	AccountID int64 `json:"account_id"`
	// Success is true when the account was processed without error.
	Success bool `json:"success"`
	// Error holds the failure message; it is omitted from JSON when empty.
	Error string `json:"error,omitempty"`
}
// BulkUpdateAccountsResult aggregates the per-account outcomes of a bulk update.
type BulkUpdateAccountsResult struct {
	// Success counts the accounts that were processed without error.
	Success int `json:"success"`
	// Failed counts the accounts for which processing failed.
	Failed int `json:"failed"`
	// Results holds one entry per processed account.
	Results []BulkUpdateAccountResult `json:"results"`
}
// CreateProxyInput carries the fields used to create a new proxy.
type CreateProxyInput struct {
	Name     string
	Protocol string
	Host     string
	Port     int
	Username string
	Password string
}
// UpdateProxyInput carries the fields an admin may change on an existing proxy.
//
// UpdateProxy applies a field only when it is non-zero (non-empty string,
// non-zero port), so a zero value means "leave unchanged".
type UpdateProxyInput struct {
	Name     string
	Protocol string
	Host     string
	Port     int

	Username string
	Password string

	Status string
}
// GenerateRedeemCodesInput describes a batch of redeem codes to generate.
type GenerateRedeemCodesInput struct {
	// Count is the number of codes to generate.
	Count int
	Type  string
	Value float64

	// GroupID is only used for subscription-type codes: the associated group ID.
	GroupID *int64
	// ValidityDays is only used for subscription-type codes: validity in days.
	ValidityDays int
}
// ProxyTestResult is the outcome of a proxy connectivity test.
type ProxyTestResult struct {
	// Success reports whether the probe succeeded.
	Success bool `json:"success"`
	// Message is a human-readable status or the probe error text.
	Message string `json:"message"`
	// LatencyMs is the latency value reported by the prober.
	LatencyMs int64 `json:"latency_ms,omitempty"`

	// Exit information reported by the prober.
	IPAddress string `json:"ip_address,omitempty"`
	City      string `json:"city,omitempty"`
	Region    string `json:"region,omitempty"`
	Country   string `json:"country,omitempty"`
}
// ProxyExitInfo represents proxy exit information from ipinfo.io
type ProxyExitInfo struct {
	IP      string
	City    string
	Region  string
	Country string
}
// ProxyExitInfoProber probes a proxy for connectivity and exit information.
type ProxyExitInfoProber interface {
	// ProbeProxy probes the proxy at proxyURL. It returns the exit
	// information and a latency value that TestProxy reports as milliseconds.
	ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error)
}
2025-12-18 13:50:39 +08:00
// adminServiceImpl implements AdminService
type adminServiceImpl struct {
2025-12-25 17:15:01 +08:00
userRepo UserRepository
groupRepo GroupRepository
accountRepo AccountRepository
proxyRepo ProxyRepository
apiKeyRepo APIKeyRepository
2025-12-25 17:15:01 +08:00
redeemCodeRepo RedeemCodeRepository
2025-12-18 13:50:39 +08:00
billingCacheService *BillingCacheService
2025-12-20 11:56:11 +08:00
proxyProber ProxyExitInfoProber
2025-12-18 13:50:39 +08:00
}
// NewAdminService creates a new AdminService
func NewAdminService(
2025-12-25 17:15:01 +08:00
userRepo UserRepository,
groupRepo GroupRepository,
accountRepo AccountRepository,
proxyRepo ProxyRepository,
apiKeyRepo APIKeyRepository,
2025-12-25 17:15:01 +08:00
redeemCodeRepo RedeemCodeRepository,
billingCacheService *BillingCacheService,
2025-12-20 11:56:11 +08:00
proxyProber ProxyExitInfoProber,
) AdminService {
2025-12-18 13:50:39 +08:00
return &adminServiceImpl{
userRepo: userRepo,
groupRepo: groupRepo,
accountRepo: accountRepo,
proxyRepo: proxyRepo,
apiKeyRepo: apiKeyRepo,
redeemCodeRepo: redeemCodeRepo,
billingCacheService: billingCacheService,
2025-12-20 11:56:11 +08:00
proxyProber: proxyProber,
2025-12-18 13:50:39 +08:00
}
}
// User management implementations
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
users, result, err := s.userRepo.ListWithFilters(ctx, params, filters)
2025-12-18 13:50:39 +08:00
if err != nil {
return nil, 0, err
}
return users, result.Total, nil
}
func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) {
2025-12-18 13:50:39 +08:00
return s.userRepo.GetByID(ctx, id)
}
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) {
user := &User{
Email: input.Email,
Username: input.Username,
Notes: input.Notes,
Role: RoleUser, // Always create as regular user, never admin
Balance: input.Balance,
Concurrency: input.Concurrency,
Status: StatusActive,
AllowedGroups: input.AllowedGroups,
2025-12-18 13:50:39 +08:00
}
if err := user.SetPassword(input.Password); err != nil {
return nil, err
}
if err := s.userRepo.Create(ctx, user); err != nil {
return nil, err
}
return user, nil
}
func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) {
2025-12-18 13:50:39 +08:00
user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
// Protect admin users: cannot disable admin accounts
if user.Role == "admin" && input.Status == "disabled" {
return nil, errors.New("cannot disable admin user")
}
oldConcurrency := user.Concurrency
if input.Email != "" {
user.Email = input.Email
}
if input.Password != "" {
if err := user.SetPassword(input.Password); err != nil {
return nil, err
}
}
if input.Username != nil {
user.Username = *input.Username
}
if input.Notes != nil {
user.Notes = *input.Notes
}
2025-12-18 13:50:39 +08:00
if input.Status != "" {
user.Status = input.Status
}
if input.Concurrency != nil {
user.Concurrency = *input.Concurrency
}
if input.AllowedGroups != nil {
user.AllowedGroups = *input.AllowedGroups
}
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
concurrencyDiff := user.Concurrency - oldConcurrency
if concurrencyDiff != 0 {
code, err := GenerateRedeemCode()
if err != nil {
log.Printf("failed to generate adjustment redeem code: %v", err)
return user, nil
}
adjustmentRecord := &RedeemCode{
Code: code,
Type: AdjustmentTypeAdminConcurrency,
2025-12-18 13:50:39 +08:00
Value: float64(concurrencyDiff),
Status: StatusUsed,
2025-12-18 13:50:39 +08:00
UsedBy: &user.ID,
}
now := time.Now()
adjustmentRecord.UsedAt = &now
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
log.Printf("failed to create concurrency adjustment redeem code: %v", err)
2025-12-18 13:50:39 +08:00
}
}
return user, nil
}
func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
// Protect admin users: cannot delete admin accounts
user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
return err
}
if user.Role == "admin" {
return errors.New("cannot delete admin user")
}
if err := s.userRepo.Delete(ctx, id); err != nil {
log.Printf("delete user failed: user_id=%d err=%v", id, err)
return err
}
return nil
2025-12-18 13:50:39 +08:00
}
func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) {
2025-12-18 13:50:39 +08:00
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, err
}
oldBalance := user.Balance
2025-12-18 13:50:39 +08:00
switch operation {
case "set":
user.Balance = balance
case "add":
user.Balance += balance
case "subtract":
user.Balance -= balance
}
if user.Balance < 0 {
return nil, fmt.Errorf("balance cannot be negative, current balance: %.2f, requested operation would result in: %.2f", oldBalance, user.Balance)
}
2025-12-18 13:50:39 +08:00
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
if s.billingCacheService != nil {
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.billingCacheService.InvalidateUserBalance(cacheCtx, userID); err != nil {
log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err)
}
2025-12-18 13:50:39 +08:00
}()
}
balanceDiff := user.Balance - oldBalance
if balanceDiff != 0 {
code, err := GenerateRedeemCode()
if err != nil {
log.Printf("failed to generate adjustment redeem code: %v", err)
return user, nil
}
adjustmentRecord := &RedeemCode{
Code: code,
Type: AdjustmentTypeAdminBalance,
Value: balanceDiff,
Status: StatusUsed,
UsedBy: &user.ID,
Notes: notes,
}
now := time.Now()
adjustmentRecord.UsedAt = &now
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
log.Printf("failed to create balance adjustment redeem code: %v", err)
}
}
2025-12-18 13:50:39 +08:00
return user, nil
}
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
2025-12-18 13:50:39 +08:00
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil {
return nil, 0, err
}
return keys, result.Total, nil
}
2025-12-20 16:19:40 +08:00
func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) {
2025-12-18 13:50:39 +08:00
// Return mock data for now
2025-12-20 16:19:40 +08:00
return map[string]any{
2025-12-18 13:50:39 +08:00
"period": period,
"total_requests": 0,
"total_cost": 0.0,
"total_tokens": 0,
"avg_duration_ms": 0,
}, nil
}
// Group management implementations
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
2025-12-18 13:50:39 +08:00
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive)
if err != nil {
return nil, 0, err
}
return groups, result.Total, nil
}
func (s *adminServiceImpl) GetAllGroups(ctx context.Context) ([]Group, error) {
2025-12-18 13:50:39 +08:00
return s.groupRepo.ListActive(ctx)
}
func (s *adminServiceImpl) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) {
2025-12-18 13:50:39 +08:00
return s.groupRepo.ListActiveByPlatform(ctx, platform)
}
func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*Group, error) {
2025-12-18 13:50:39 +08:00
return s.groupRepo.GetByID(ctx, id)
}
func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) {
2025-12-18 13:50:39 +08:00
platform := input.Platform
if platform == "" {
platform = PlatformAnthropic
2025-12-18 13:50:39 +08:00
}
subscriptionType := input.SubscriptionType
if subscriptionType == "" {
subscriptionType = SubscriptionTypeStandard
2025-12-18 13:50:39 +08:00
}
// 限额字段0 和 nil 都表示"无限制"
dailyLimit := normalizeLimit(input.DailyLimitUSD)
weeklyLimit := normalizeLimit(input.WeeklyLimitUSD)
monthlyLimit := normalizeLimit(input.MonthlyLimitUSD)
group := &Group{
2025-12-18 13:50:39 +08:00
Name: input.Name,
Description: input.Description,
Platform: platform,
RateMultiplier: input.RateMultiplier,
IsExclusive: input.IsExclusive,
Status: StatusActive,
2025-12-18 13:50:39 +08:00
SubscriptionType: subscriptionType,
DailyLimitUSD: dailyLimit,
WeeklyLimitUSD: weeklyLimit,
MonthlyLimitUSD: monthlyLimit,
2025-12-18 13:50:39 +08:00
}
if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err
}
return group, nil
}
// normalizeLimit maps a nil, zero or negative limit to nil, which callers
// treat as "no limit". Any other value is returned unchanged (the same
// pointer, not a copy).
func normalizeLimit(limit *float64) *float64 {
	if limit == nil {
		return nil
	}
	if value := *limit; value <= 0 {
		return nil
	}
	return limit
}
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
2025-12-18 13:50:39 +08:00
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
if input.Name != "" {
group.Name = input.Name
}
if input.Description != "" {
group.Description = input.Description
}
if input.Platform != "" {
group.Platform = input.Platform
}
if input.RateMultiplier != nil {
group.RateMultiplier = *input.RateMultiplier
}
if input.IsExclusive != nil {
group.IsExclusive = *input.IsExclusive
}
if input.Status != "" {
group.Status = input.Status
}
// 订阅相关字段
if input.SubscriptionType != "" {
group.SubscriptionType = input.SubscriptionType
}
// 限额字段0 和 nil 都表示"无限制",正数表示具体限额
2025-12-18 13:50:39 +08:00
if input.DailyLimitUSD != nil {
group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD)
2025-12-18 13:50:39 +08:00
}
if input.WeeklyLimitUSD != nil {
group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD)
2025-12-18 13:50:39 +08:00
}
if input.MonthlyLimitUSD != nil {
group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
2025-12-18 13:50:39 +08:00
}
if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, err
}
return group, nil
}
func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
2025-12-25 20:52:47 +08:00
affectedUserIDs, err := s.groupRepo.DeleteCascade(ctx, id)
2025-12-18 13:50:39 +08:00
if err != nil {
return err
}
// 事务成功后,异步失效受影响用户的订阅缓存
if len(affectedUserIDs) > 0 && s.billingCacheService != nil {
groupID := id
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
for _, userID := range affectedUserIDs {
if err := s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID); err != nil {
log.Printf("invalidate subscription cache failed: user_id=%d group_id=%d err=%v", userID, groupID, err)
}
2025-12-18 13:50:39 +08:00
}
}()
}
return nil
}
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
2025-12-18 13:50:39 +08:00
keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params)
if err != nil {
return nil, 0, err
}
return keys, result.Total, nil
}
// Account management implementations
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
2025-12-18 13:50:39 +08:00
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search)
if err != nil {
return nil, 0, err
}
return accounts, result.Total, nil
}
func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*Account, error) {
2025-12-18 13:50:39 +08:00
return s.accountRepo.GetByID(ctx, id)
}
// GetAccountsByIDs loads the accounts for the given ids. An empty ids slice
// yields an empty, non-nil result without touching the repository.
func (s *adminServiceImpl) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) {
	if len(ids) > 0 {
		found, lookupErr := s.accountRepo.GetByIDs(ctx, ids)
		if lookupErr != nil {
			return nil, fmt.Errorf("failed to get accounts by IDs: %w", lookupErr)
		}
		return found, nil
	}
	return []*Account{}, nil
}
func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) {
2025-12-18 13:50:39 +08:00
// 绑定分组
groupIDs := input.GroupIDs
// 如果没有指定分组,自动绑定对应平台的默认分组
if len(groupIDs) == 0 {
defaultGroupName := input.Platform + "-default"
groups, err := s.groupRepo.ListActiveByPlatform(ctx, input.Platform)
if err == nil {
for _, g := range groups {
if g.Name == defaultGroupName {
groupIDs = []int64{g.ID}
break
}
}
}
}
// 检查混合渠道风险(除非用户已确认)
if len(groupIDs) > 0 && !input.SkipMixedChannelCheck {
if err := s.checkMixedChannelRisk(ctx, 0, input.Platform, groupIDs); err != nil {
return nil, err
}
}
account := &Account{
Name: input.Name,
Platform: input.Platform,
Type: input.Type,
Credentials: input.Credentials,
Extra: input.Extra,
ProxyID: input.ProxyID,
Concurrency: input.Concurrency,
Priority: input.Priority,
Status: StatusActive,
Schedulable: true,
}
if err := s.accountRepo.Create(ctx, account); err != nil {
return nil, err
}
// 绑定分组
if len(groupIDs) > 0 {
if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil {
2025-12-18 13:50:39 +08:00
return nil, err
}
}
2025-12-18 13:50:39 +08:00
return account, nil
}
func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error) {
2025-12-18 13:50:39 +08:00
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
if input.Name != "" {
account.Name = input.Name
}
if input.Type != "" {
account.Type = input.Type
}
if len(input.Credentials) > 0 {
account.Credentials = input.Credentials
2025-12-18 13:50:39 +08:00
}
if len(input.Extra) > 0 {
account.Extra = input.Extra
2025-12-18 13:50:39 +08:00
}
if input.ProxyID != nil {
account.ProxyID = input.ProxyID
account.Proxy = nil // 清除关联对象,防止 GORM Save 时根据 Proxy.ID 覆盖 ProxyID
2025-12-18 13:50:39 +08:00
}
// 只在指针非 nil 时更新 Concurrency支持设置为 0
if input.Concurrency != nil {
account.Concurrency = *input.Concurrency
}
// 只在指针非 nil 时更新 Priority支持设置为 0
if input.Priority != nil {
account.Priority = *input.Priority
}
if input.Status != "" {
account.Status = input.Status
}
// 先验证分组是否存在(在任何写操作之前)
if input.GroupIDs != nil {
for _, groupID := range *input.GroupIDs {
if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil {
return nil, fmt.Errorf("get group: %w", err)
}
}
// 检查混合渠道风险(除非用户已确认)
if !input.SkipMixedChannelCheck {
if err := s.checkMixedChannelRisk(ctx, account.ID, account.Platform, *input.GroupIDs); err != nil {
return nil, err
}
}
}
2025-12-18 13:50:39 +08:00
if err := s.accountRepo.Update(ctx, account); err != nil {
return nil, err
}
// 绑定分组
2025-12-18 13:50:39 +08:00
if input.GroupIDs != nil {
if err := s.accountRepo.BindGroups(ctx, account.ID, *input.GroupIDs); err != nil {
return nil, err
}
}
// 重新查询以确保返回完整数据(包括正确的 Proxy 关联对象)
return s.accountRepo.GetByID(ctx, id)
2025-12-18 13:50:39 +08:00
}
// BulkUpdateAccounts updates multiple accounts in one request.
// It merges credentials/extra keys instead of overwriting the whole object.
func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) {
result := &BulkUpdateAccountsResult{
Results: make([]BulkUpdateAccountResult, 0, len(input.AccountIDs)),
}
if len(input.AccountIDs) == 0 {
return result, nil
}
// Preload account platforms for mixed channel risk checks if group bindings are requested.
platformByID := map[int64]string{}
if input.GroupIDs != nil && !input.SkipMixedChannelCheck {
accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs)
if err != nil {
return nil, err
}
for _, account := range accounts {
if account != nil {
platformByID[account.ID] = account.Platform
}
}
}
// Prepare bulk updates for columns and JSONB fields.
2025-12-25 17:15:01 +08:00
repoUpdates := AccountBulkUpdate{
Credentials: input.Credentials,
Extra: input.Extra,
}
if input.Name != "" {
repoUpdates.Name = &input.Name
}
if input.ProxyID != nil {
repoUpdates.ProxyID = input.ProxyID
}
if input.Concurrency != nil {
repoUpdates.Concurrency = input.Concurrency
}
if input.Priority != nil {
repoUpdates.Priority = input.Priority
}
if input.Status != "" {
repoUpdates.Status = &input.Status
}
// Run bulk update for column/jsonb fields first.
if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil {
return nil, err
}
// Handle group bindings per account (requires individual operations).
for _, accountID := range input.AccountIDs {
entry := BulkUpdateAccountResult{AccountID: accountID}
if input.GroupIDs != nil {
// 检查混合渠道风险(除非用户已确认)
if !input.SkipMixedChannelCheck {
platform := platformByID[accountID]
if platform == "" {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err != nil {
entry.Success = false
entry.Error = err.Error()
result.Failed++
result.Results = append(result.Results, entry)
continue
}
platform = account.Platform
}
if err := s.checkMixedChannelRisk(ctx, accountID, platform, *input.GroupIDs); err != nil {
entry.Success = false
entry.Error = err.Error()
result.Failed++
result.Results = append(result.Results, entry)
continue
}
}
if err := s.accountRepo.BindGroups(ctx, accountID, *input.GroupIDs); err != nil {
entry.Success = false
entry.Error = err.Error()
result.Failed++
result.Results = append(result.Results, entry)
continue
}
}
entry.Success = true
result.Success++
result.Results = append(result.Results, entry)
}
return result, nil
}
2025-12-18 13:50:39 +08:00
func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
return s.accountRepo.Delete(ctx, id)
}
func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) {
2025-12-18 13:50:39 +08:00
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
// TODO: Implement refresh logic
return account, nil
}
func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Account, error) {
2025-12-18 13:50:39 +08:00
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
account.Status = StatusActive
2025-12-18 13:50:39 +08:00
account.ErrorMessage = ""
if err := s.accountRepo.Update(ctx, account); err != nil {
return nil, err
}
return account, nil
}
func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) {
2025-12-18 13:50:39 +08:00
if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil {
return nil, err
}
return s.accountRepo.GetByID(ctx, id)
}
// Proxy management implementations
func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
2025-12-18 13:50:39 +08:00
proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search)
if err != nil {
return nil, 0, err
}
return proxies, result.Total, nil
}
func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]Proxy, error) {
2025-12-18 13:50:39 +08:00
return s.proxyRepo.ListActive(ctx)
}
func (s *adminServiceImpl) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) {
2025-12-18 13:50:39 +08:00
return s.proxyRepo.ListActiveWithAccountCount(ctx)
}
func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*Proxy, error) {
2025-12-18 13:50:39 +08:00
return s.proxyRepo.GetByID(ctx, id)
}
func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error) {
proxy := &Proxy{
2025-12-18 13:50:39 +08:00
Name: input.Name,
Protocol: input.Protocol,
Host: input.Host,
Port: input.Port,
Username: input.Username,
Password: input.Password,
Status: StatusActive,
2025-12-18 13:50:39 +08:00
}
if err := s.proxyRepo.Create(ctx, proxy); err != nil {
return nil, err
}
return proxy, nil
}
func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error) {
2025-12-18 13:50:39 +08:00
proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
if input.Name != "" {
proxy.Name = input.Name
}
if input.Protocol != "" {
proxy.Protocol = input.Protocol
}
if input.Host != "" {
proxy.Host = input.Host
}
if input.Port != 0 {
proxy.Port = input.Port
}
if input.Username != "" {
proxy.Username = input.Username
}
if input.Password != "" {
proxy.Password = input.Password
}
if input.Status != "" {
proxy.Status = input.Status
}
if err := s.proxyRepo.Update(ctx, proxy); err != nil {
return nil, err
}
return proxy, nil
}
// DeleteProxy removes the proxy with the given id via the proxy repository and
// reports the repository's error, if any.
func (s *adminServiceImpl) DeleteProxy(ctx context.Context, id int64) error {
	deleteErr := s.proxyRepo.Delete(ctx, id)
	return deleteErr
}
func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error) {
2025-12-18 13:50:39 +08:00
// Return mock data for now - would need a dedicated repository method
return []Account{}, 0, nil
2025-12-18 13:50:39 +08:00
}
// CheckProxyExists asks the proxy repository whether a proxy with the given
// host, port, username and password already exists.
func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) {
	exists, lookupErr := s.proxyRepo.ExistsByHostPortAuth(ctx, host, port, username, password)
	return exists, lookupErr
}
// Redeem code management implementations
func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
2025-12-18 13:50:39 +08:00
codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search)
if err != nil {
return nil, 0, err
}
return codes, result.Total, nil
}
func (s *adminServiceImpl) GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) {
2025-12-18 13:50:39 +08:00
return s.redeemCodeRepo.GetByID(ctx, id)
}
func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error) {
2025-12-18 13:50:39 +08:00
// 如果是订阅类型,验证必须有 GroupID
if input.Type == RedeemTypeSubscription {
2025-12-18 13:50:39 +08:00
if input.GroupID == nil {
return nil, errors.New("group_id is required for subscription type")
}
// 验证分组存在且为订阅类型
group, err := s.groupRepo.GetByID(ctx, *input.GroupID)
if err != nil {
return nil, fmt.Errorf("group not found: %w", err)
}
if !group.IsSubscriptionType() {
return nil, errors.New("group must be subscription type")
}
}
codes := make([]RedeemCode, 0, input.Count)
2025-12-18 13:50:39 +08:00
for i := 0; i < input.Count; i++ {
codeValue, err := GenerateRedeemCode()
if err != nil {
return nil, err
}
code := RedeemCode{
Code: codeValue,
2025-12-18 13:50:39 +08:00
Type: input.Type,
Value: input.Value,
Status: StatusUnused,
2025-12-18 13:50:39 +08:00
}
// 订阅类型专用字段
if input.Type == RedeemTypeSubscription {
2025-12-18 13:50:39 +08:00
code.GroupID = input.GroupID
code.ValidityDays = input.ValidityDays
if code.ValidityDays <= 0 {
code.ValidityDays = 30 // 默认30天
}
}
if err := s.redeemCodeRepo.Create(ctx, &code); err != nil {
return nil, err
}
codes = append(codes, code)
}
return codes, nil
}
// DeleteRedeemCode removes the redeem code with the given id via the redeem
// code repository and reports the repository's error, if any.
func (s *adminServiceImpl) DeleteRedeemCode(ctx context.Context, id int64) error {
	deleteErr := s.redeemCodeRepo.Delete(ctx, id)
	return deleteErr
}
// BatchDeleteRedeemCodes deletes the redeem codes with the given ids one by
// one and returns how many deletions succeeded.
//
// The operation is best effort: a failed deletion is logged and skipped, and
// the returned error is always nil.
func (s *adminServiceImpl) BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) {
	var deleted int64
	for _, id := range ids {
		if err := s.redeemCodeRepo.Delete(ctx, id); err != nil {
			// Keep going, but leave a trace instead of dropping the error.
			log.Printf("batch delete redeem code failed: id=%d err=%v", id, err)
			continue
		}
		deleted++
	}
	return deleted, nil
}
func (s *adminServiceImpl) ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) {
2025-12-18 13:50:39 +08:00
code, err := s.redeemCodeRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
code.Status = StatusExpired
2025-12-18 13:50:39 +08:00
if err := s.redeemCodeRepo.Update(ctx, code); err != nil {
return nil, err
}
return code, nil
}
func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
proxyURL := proxy.URL()
2025-12-20 11:56:11 +08:00
exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL)
2025-12-18 13:50:39 +08:00
if err != nil {
return &ProxyTestResult{
Success: false,
2025-12-20 11:56:11 +08:00
Message: err.Error(),
2025-12-18 13:50:39 +08:00
}, nil
}
return &ProxyTestResult{
Success: true,
Message: "Proxy is accessible",
LatencyMs: latencyMs,
2025-12-20 11:56:11 +08:00
IPAddress: exitInfo.IP,
City: exitInfo.City,
Region: exitInfo.Region,
Country: exitInfo.Country,
2025-12-18 13:50:39 +08:00
}, nil
}
// checkMixedChannelRisk reports whether binding an account of the given
// platform to groupIDs would mix channels (Antigravity + Anthropic) inside one
// of those groups.
//
// It returns a *MixedChannelError for the first group in which an account of
// the other channel is found, so that the caller can ask the user to confirm.
// Platforms that map to neither channel are ignored. currentAccountID, when
// positive, identifies the account being updated so that it is not compared
// against itself.
func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
	// The channel is derived from the platform field, not the type field.
	ownChannel := getAccountPlatform(currentAccountPlatform)
	if ownChannel == "" {
		// Neither Antigravity nor Anthropic: nothing to check.
		return nil
	}

	// Inspect the other accounts of every group.
	for _, gid := range groupIDs {
		members, listErr := s.accountRepo.ListByGroup(ctx, gid)
		if listErr != nil {
			return fmt.Errorf("get accounts in group %d: %w", gid, listErr)
		}

		for _, member := range members {
			isSelf := currentAccountID > 0 && member.ID == currentAccountID
			if isSelf {
				continue // skip the current account
			}

			memberChannel := getAccountPlatform(member.Platform)
			if memberChannel == "" || memberChannel == ownChannel {
				// Unrelated platform or same channel: no conflict.
				continue
			}

			// Mixed channels detected. The group lookup is only used for a
			// nicer name; its error is deliberately ignored and a generic
			// name is used when no group comes back.
			displayName := fmt.Sprintf("Group %d", gid)
			if grp, _ := s.groupRepo.GetByID(ctx, gid); grp != nil {
				displayName = grp.Name
			}
			return &MixedChannelError{
				GroupID:         gid,
				GroupName:       displayName,
				CurrentPlatform: ownChannel,
				OtherPlatform:   memberChannel,
			}
		}
	}

	return nil
}
// getAccountPlatform maps an account's platform value to the channel label
// used by the mixed channel check. The input is trimmed and lower-cased first.
// It returns "Antigravity", "Anthropic" (also for "claude"), or "" when the
// platform takes no part in the check.
func getAccountPlatform(accountPlatform string) string {
	normalized := strings.ToLower(strings.TrimSpace(accountPlatform))
	if normalized == PlatformAntigravity {
		return "Antigravity"
	}
	if normalized == PlatformAnthropic || normalized == "claude" {
		return "Anthropic"
	}
	return ""
}
// MixedChannelError is returned when a group would contain accounts from two
// different channels.
type MixedChannelError struct {
	// GroupID and GroupName identify the group in which the mix was found.
	GroupID   int64
	GroupName string
	// CurrentPlatform is the channel of the account being bound.
	CurrentPlatform string
	// OtherPlatform is the conflicting channel already present in the group.
	OtherPlatform string
}

// Error implements the error interface. The message starts with the
// "mixed_channel_warning:" prefix followed by a human-readable explanation.
func (e *MixedChannelError) Error() string {
	const format = "mixed_channel_warning: Group '%s' contains both %s and %s accounts. Using mixed channels in the same context may cause thinking block signature validation issues, which will fallback to non-thinking mode for historical messages."
	message := fmt.Sprintf(format, e.GroupName, e.CurrentPlatform, e.OtherPlatform)
	return message
}