mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-18 13:54:46 +08:00
114 lines
3.3 KiB
Go
114 lines
3.3 KiB
Go
|
|
package repository
|
|||
|
|
|
|||
|
|
import (
	"context"
	"database/sql"
	"errors"
	"sort"
	"strconv"
	"strings"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/service"
)
|
|||
|
|
|
|||
|
|
// userGroupRateRepository stores per-user rate multiplier overrides for groups
// in the user_group_rate_multipliers table.
type userGroupRateRepository struct {
	// sql runs every query; NewUserGroupRateRepository populates it with a *sql.DB.
	sql sqlExecutor
}
|
|||
|
|
|
|||
|
|
// NewUserGroupRateRepository 创建用户专属分组倍率仓储
|
|||
|
|
func NewUserGroupRateRepository(sqlDB *sql.DB) service.UserGroupRateRepository {
|
|||
|
|
return &userGroupRateRepository{sql: sqlDB}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetByUserID 获取用户的所有专属分组倍率
|
|||
|
|
func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) {
|
|||
|
|
query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1`
|
|||
|
|
rows, err := r.sql.QueryContext(ctx, query, userID)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
defer func() { _ = rows.Close() }()
|
|||
|
|
|
|||
|
|
result := make(map[int64]float64)
|
|||
|
|
for rows.Next() {
|
|||
|
|
var groupID int64
|
|||
|
|
var rate float64
|
|||
|
|
if err := rows.Scan(&groupID, &rate); err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
result[groupID] = rate
|
|||
|
|
}
|
|||
|
|
if err := rows.Err(); err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
return result, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetByUserAndGroup 获取用户在特定分组的专属倍率
|
|||
|
|
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
|
|||
|
|
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
|
|||
|
|
var rate float64
|
|||
|
|
err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rate)
|
|||
|
|
if err == sql.ErrNoRows {
|
|||
|
|
return nil, nil
|
|||
|
|
}
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
return &rate, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// SyncUserGroupRates 同步用户的分组专属倍率
|
|||
|
|
func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error {
|
|||
|
|
if len(rates) == 0 {
|
|||
|
|
// 如果传入空 map,删除该用户的所有专属倍率
|
|||
|
|
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 分离需要删除和需要 upsert 的记录
|
|||
|
|
var toDelete []int64
|
|||
|
|
toUpsert := make(map[int64]float64)
|
|||
|
|
for groupID, rate := range rates {
|
|||
|
|
if rate == nil {
|
|||
|
|
toDelete = append(toDelete, groupID)
|
|||
|
|
} else {
|
|||
|
|
toUpsert[groupID] = *rate
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 删除指定的记录
|
|||
|
|
for _, groupID := range toDelete {
|
|||
|
|
_, err := r.sql.ExecContext(ctx,
|
|||
|
|
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`,
|
|||
|
|
userID, groupID)
|
|||
|
|
if err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Upsert 记录
|
|||
|
|
now := time.Now()
|
|||
|
|
for groupID, rate := range toUpsert {
|
|||
|
|
_, err := r.sql.ExecContext(ctx, `
|
|||
|
|
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
|
|||
|
|
VALUES ($1, $2, $3, $4, $4)
|
|||
|
|
ON CONFLICT (user_id, group_id) DO UPDATE SET rate_multiplier = $3, updated_at = $4
|
|||
|
|
`, userID, groupID, rate, now)
|
|||
|
|
if err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// DeleteByGroupID 删除指定分组的所有用户专属倍率
|
|||
|
|
func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error {
|
|||
|
|
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID)
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// DeleteByUserID 删除指定用户的所有专属倍率
|
|||
|
|
func (r *userGroupRateRepository) DeleteByUserID(ctx context.Context, userID int64) error {
|
|||
|
|
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
|
|||
|
|
return err
|
|||
|
|
}
|