Files
sub2api/backend/internal/service/account_service.go

398 lines
13 KiB
Go
Raw Permalink Normal View History

2025-12-18 13:50:39 +08:00
package service
import (
"context"
"fmt"
2025-12-25 17:15:01 +08:00
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
2025-12-24 21:07:21 +08:00
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
2025-12-18 13:50:39 +08:00
)
var (
2025-12-25 20:52:47 +08:00
ErrAccountNotFound = infraerrors.NotFound("ACCOUNT_NOT_FOUND", "account not found")
ErrAccountNilInput = infraerrors.BadRequest("ACCOUNT_NIL_INPUT", "account input cannot be nil")
2025-12-18 13:50:39 +08:00
)
2025-12-25 17:15:01 +08:00
type AccountRepository interface {
Create(ctx context.Context, account *Account) error
GetByID(ctx context.Context, id int64) (*Account, error)
// GetByIDs fetches accounts by IDs in a single query.
// It should return all accounts found (missing IDs are ignored).
GetByIDs(ctx context.Context, ids []int64) ([]*Account, error)
// ExistsByID 检查账号是否存在,仅返回布尔值,用于删除前的轻量级存在性检查
ExistsByID(ctx context.Context, id int64) (bool, error)
2025-12-25 17:15:01 +08:00
// GetByCRSAccountID finds an account previously synced from CRS.
// Returns (nil, nil) if not found.
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error)
// FindByExtraField 根据 extra 字段中的键值对查找账号(限定 platform='sora'
// 用于查找通过 linked_openai_account_id 关联的 Sora 账号
FindByExtraField(ctx context.Context, key string, value any) ([]Account, error)
// ListCRSAccountIDs returns a map of crs_account_id -> local account ID
// for all accounts that have been synced from CRS.
ListCRSAccountIDs(ctx context.Context) (map[string]int64, error)
Update(ctx context.Context, account *Account) error
2025-12-25 17:15:01 +08:00
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error)
ListByGroup(ctx context.Context, groupID int64) ([]Account, error)
ListActive(ctx context.Context) ([]Account, error)
ListByPlatform(ctx context.Context, platform string) ([]Account, error)
2025-12-25 17:15:01 +08:00
UpdateLastUsed(ctx context.Context, id int64) error
BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
2025-12-25 17:15:01 +08:00
SetError(ctx context.Context, id int64, errorMsg string) error
ClearError(ctx context.Context, id int64) error
2025-12-25 17:15:01 +08:00
SetSchedulable(ctx context.Context, id int64, schedulable bool) error
2026-01-07 16:59:35 +08:00
AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error)
2025-12-25 17:15:01 +08:00
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
ListSchedulable(ctx context.Context) ([]Account, error)
ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error)
ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error)
ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error)
ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error)
ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error)
ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error)
2025-12-25 17:15:01 +08:00
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error
2025-12-25 17:15:01 +08:00
SetOverloaded(ctx context.Context, id int64, until time.Time) error
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
ClearTempUnschedulable(ctx context.Context, id int64) error
2025-12-25 17:15:01 +08:00
ClearRateLimit(ctx context.Context, id int64) error
2026-01-09 17:35:02 +08:00
ClearAntigravityQuotaScopes(ctx context.Context, id int64) error
ClearModelRateLimits(ctx context.Context, id int64) error
2025-12-25 17:15:01 +08:00
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
// IncrementQuotaUsed 原子递增 API Key 账号的配额用量(总/日/周)
IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error
// ResetQuotaUsed 重置 API Key 账号所有维度的配额用量为 0
ResetQuotaUsed(ctx context.Context, id int64) error
2025-12-25 17:15:01 +08:00
}
// AccountBulkUpdate describes the fields that can be updated in a bulk operation.
// Nil pointers mean "do not change".
type AccountBulkUpdate struct {
	Name           *string
	ProxyID        *int64
	Concurrency    *int
	Priority       *int
	RateMultiplier *float64
	LoadFactor     *int
	Status         *string
	Schedulable    *bool
	Credentials    map[string]any
	Extra          map[string]any
}
2025-12-18 13:50:39 +08:00
// CreateAccountRequest is the input for AccountService.Create.
//
// GroupIDs may be empty, in which case no groups are validated or bound.
// AutoPauseOnExpired is optional; Create treats nil as true.
type CreateAccountRequest struct {
	Name               string         `json:"name"`
	Notes              *string        `json:"notes"`
	Platform           string         `json:"platform"`
	Type               string         `json:"type"`
	Credentials        map[string]any `json:"credentials"`
	Extra              map[string]any `json:"extra"`
	ProxyID            *int64         `json:"proxy_id"`
	Concurrency        int            `json:"concurrency"`
	Priority           int            `json:"priority"`
	GroupIDs           []int64        `json:"group_ids"`
	ExpiresAt          *time.Time     `json:"expires_at"`
	AutoPauseOnExpired *bool          `json:"auto_pause_on_expired"`
}
// UpdateAccountRequest is the input for AccountService.Update.
//
// Every field is a pointer: AccountService.Update only applies the fields
// that are non-nil and leaves the others untouched.
type UpdateAccountRequest struct {
	Name               *string         `json:"name"`
	Notes              *string         `json:"notes"`
	Credentials        *map[string]any `json:"credentials"`
	Extra              *map[string]any `json:"extra"`
	ProxyID            *int64          `json:"proxy_id"`
	Concurrency        *int            `json:"concurrency"`
	Priority           *int            `json:"priority"`
	Status             *string         `json:"status"`
	GroupIDs           *[]int64        `json:"group_ids"`
	ExpiresAt          *time.Time      `json:"expires_at"`
	AutoPauseOnExpired *bool           `json:"auto_pause_on_expired"`
}
// AccountService 账号管理服务
type AccountService struct {
2025-12-25 17:15:01 +08:00
accountRepo AccountRepository
groupRepo GroupRepository
2025-12-18 13:50:39 +08:00
}
// groupExistenceBatchChecker is an optional capability of a group repository:
// checking the existence of several group IDs in one call. The result maps
// each ID to whether it exists. validateGroupIDsExist type-asserts the
// configured GroupRepository against this interface.
type groupExistenceBatchChecker interface {
	ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
}
2025-12-18 13:50:39 +08:00
// NewAccountService 创建账号服务实例
2025-12-25 17:15:01 +08:00
func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository) *AccountService {
2025-12-18 13:50:39 +08:00
return &AccountService{
accountRepo: accountRepo,
groupRepo: groupRepo,
}
}
// Create 创建账号
func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*Account, error) {
2025-12-18 13:50:39 +08:00
// 验证分组是否存在(如果指定了分组)
if len(req.GroupIDs) > 0 {
if err := s.validateGroupIDsExist(ctx, req.GroupIDs); err != nil {
return nil, err
2025-12-18 13:50:39 +08:00
}
}
// 创建账号
account := &Account{
2025-12-18 13:50:39 +08:00
Name: req.Name,
2026-01-05 14:07:33 +08:00
Notes: normalizeAccountNotes(req.Notes),
2025-12-18 13:50:39 +08:00
Platform: req.Platform,
Type: req.Type,
Credentials: req.Credentials,
Extra: req.Extra,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
Status: StatusActive,
2026-01-07 16:59:35 +08:00
ExpiresAt: req.ExpiresAt,
}
if req.AutoPauseOnExpired != nil {
account.AutoPauseOnExpired = *req.AutoPauseOnExpired
} else {
account.AutoPauseOnExpired = true
2025-12-18 13:50:39 +08:00
}
if err := s.accountRepo.Create(ctx, account); err != nil {
return nil, fmt.Errorf("create account: %w", err)
}
// 绑定分组
if len(req.GroupIDs) > 0 {
if err := s.accountRepo.BindGroups(ctx, account.ID, req.GroupIDs); err != nil {
return nil, fmt.Errorf("bind groups: %w", err)
}
}
return account, nil
}
// GetByID 根据ID获取账号
func (s *AccountService) GetByID(ctx context.Context, id int64) (*Account, error) {
2025-12-18 13:50:39 +08:00
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get account: %w", err)
}
return account, nil
}
// List 获取账号列表
func (s *AccountService) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
2025-12-18 13:50:39 +08:00
accounts, pagination, err := s.accountRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list accounts: %w", err)
}
return accounts, pagination, nil
}
// ListByPlatform 根据平台获取账号列表
func (s *AccountService) ListByPlatform(ctx context.Context, platform string) ([]Account, error) {
2025-12-18 13:50:39 +08:00
accounts, err := s.accountRepo.ListByPlatform(ctx, platform)
if err != nil {
return nil, fmt.Errorf("list accounts by platform: %w", err)
}
return accounts, nil
}
// ListByGroup 根据分组获取账号列表
func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
2025-12-18 13:50:39 +08:00
accounts, err := s.accountRepo.ListByGroup(ctx, groupID)
if err != nil {
return nil, fmt.Errorf("list accounts by group: %w", err)
}
return accounts, nil
}
// Update 更新账号
func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccountRequest) (*Account, error) {
2025-12-18 13:50:39 +08:00
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get account: %w", err)
}
// 更新字段
if req.Name != nil {
account.Name = *req.Name
}
2026-01-05 14:07:33 +08:00
if req.Notes != nil {
account.Notes = normalizeAccountNotes(req.Notes)
}
2025-12-18 13:50:39 +08:00
if req.Credentials != nil {
account.Credentials = *req.Credentials
}
if req.Extra != nil {
account.Extra = *req.Extra
}
if req.ProxyID != nil {
account.ProxyID = req.ProxyID
}
if req.Concurrency != nil {
account.Concurrency = *req.Concurrency
}
if req.Priority != nil {
account.Priority = *req.Priority
}
if req.Status != nil {
account.Status = *req.Status
}
2026-01-07 16:59:35 +08:00
if req.ExpiresAt != nil {
account.ExpiresAt = req.ExpiresAt
}
if req.AutoPauseOnExpired != nil {
account.AutoPauseOnExpired = *req.AutoPauseOnExpired
}
2025-12-18 13:50:39 +08:00
// 先验证分组是否存在(在任何写操作之前)
2025-12-18 13:50:39 +08:00
if req.GroupIDs != nil {
if err := s.validateGroupIDsExist(ctx, *req.GroupIDs); err != nil {
return nil, err
2025-12-18 13:50:39 +08:00
}
}
2025-12-18 13:50:39 +08:00
// 执行更新
if err := s.accountRepo.Update(ctx, account); err != nil {
return nil, fmt.Errorf("update account: %w", err)
}
// 绑定分组
if req.GroupIDs != nil {
2025-12-18 13:50:39 +08:00
if err := s.accountRepo.BindGroups(ctx, account.ID, *req.GroupIDs); err != nil {
return nil, fmt.Errorf("bind groups: %w", err)
}
}
return account, nil
}
// Delete 删除账号
// 优化:使用 ExistsByID 替代 GetByID 进行存在性检查,
// 避免加载完整账号对象及其关联数据,提升删除操作的性能
2025-12-18 13:50:39 +08:00
func (s *AccountService) Delete(ctx context.Context, id int64) error {
// 使用轻量级的存在性检查,而非加载完整账号对象
exists, err := s.accountRepo.ExistsByID(ctx, id)
2025-12-18 13:50:39 +08:00
if err != nil {
return fmt.Errorf("check account: %w", err)
}
// 明确返回账号不存在错误,便于调用方区分错误类型
if !exists {
return ErrAccountNotFound
2025-12-18 13:50:39 +08:00
}
if err := s.accountRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete account: %w", err)
}
return nil
}
// validateGroupIDsExist checks that every ID in groupIDs refers to an
// existing group.
//
// An empty list is valid. When the configured group repository implements
// groupExistenceBatchChecker, all IDs are verified with a single ExistsByIDs
// call; otherwise each ID is looked up with GetByID. In the batch path a
// non-positive or missing ID yields an error wrapping ErrGroupNotFound.
func (s *AccountService) validateGroupIDsExist(ctx context.Context, groupIDs []int64) error {
	if len(groupIDs) == 0 {
		return nil
	}
	if s.groupRepo == nil {
		return fmt.Errorf("group repository not configured")
	}

	batchChecker, ok := s.groupRepo.(groupExistenceBatchChecker)
	if !ok {
		// Fallback: the repository cannot batch, so look up one ID at a time.
		for _, groupID := range groupIDs {
			if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil {
				return fmt.Errorf("get group: %w", err)
			}
		}
		return nil
	}

	// Non-positive IDs are always reported as not found, so reject them
	// before spending a repository round trip.
	for _, groupID := range groupIDs {
		if groupID <= 0 {
			return fmt.Errorf("get group: %w", ErrGroupNotFound)
		}
	}

	existsByID, err := batchChecker.ExistsByIDs(ctx, groupIDs)
	if err != nil {
		return fmt.Errorf("check groups exists: %w", err)
	}
	for _, groupID := range groupIDs {
		// IDs absent from the map read as false and count as missing.
		if !existsByID[groupID] {
			return fmt.Errorf("get group: %w", ErrGroupNotFound)
		}
	}
	return nil
}
2025-12-18 13:50:39 +08:00
// UpdateStatus loads the account, overwrites its Status and ErrorMessage with
// the supplied values and persists the account through the repository's
// Update method.
func (s *AccountService) UpdateStatus(ctx context.Context, id int64, status string, errorMessage string) error {
	acc, loadErr := s.accountRepo.GetByID(ctx, id)
	if loadErr != nil {
		return fmt.Errorf("get account: %w", loadErr)
	}

	acc.Status, acc.ErrorMessage = status, errorMessage

	saveErr := s.accountRepo.Update(ctx, acc)
	if saveErr != nil {
		return fmt.Errorf("update account: %w", saveErr)
	}
	return nil
}
// UpdateLastUsed asks the repository to refresh the account's last-used
// timestamp and wraps any repository error.
func (s *AccountService) UpdateLastUsed(ctx context.Context, id int64) error {
	repoErr := s.accountRepo.UpdateLastUsed(ctx, id)
	if repoErr != nil {
		return fmt.Errorf("update last used: %w", repoErr)
	}
	return nil
}
// GetCredential loads the account and returns the credential stored under
// key, as reported by Account.GetCredential. The lookup semantics for a
// missing key are defined by that method, not here.
func (s *AccountService) GetCredential(ctx context.Context, id int64, key string) (string, error) {
	acc, loadErr := s.accountRepo.GetByID(ctx, id)
	if loadErr != nil {
		return "", fmt.Errorf("get account: %w", loadErr)
	}
	value := acc.GetCredential(key)
	return value, nil
}
// TestCredentials 测试账号凭证是否有效(需要实现具体平台的测试逻辑)
func (s *AccountService) TestCredentials(ctx context.Context, id int64) error {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return fmt.Errorf("get account: %w", err)
}
// 根据平台执行不同的测试逻辑
switch account.Platform {
case PlatformAnthropic:
2025-12-18 13:50:39 +08:00
// TODO: 测试Anthropic API凭证
return nil
case PlatformOpenAI:
2025-12-18 13:50:39 +08:00
// TODO: 测试OpenAI API凭证
return nil
case PlatformGemini:
2025-12-18 13:50:39 +08:00
// TODO: 测试Gemini API凭证
return nil
default:
return fmt.Errorf("unsupported platform: %s", account.Platform)
}
}